Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
# Programming Guides Overview
This section provides a practical guide to writing high‑performance kernels with Tile Language (tile‑lang).
It mirrors the structure of a similar guide in another project and adapts it to tile‑lang concepts and APIs.
- Audience: Developers implementing custom GPU/CPU kernels with tile‑lang
- Prereqs: Basic Python, NumPy/Tensor concepts, and familiarity with GPU programming notions
- Scope: Language basics, control flow, instructions, autotuning, and type system
## What You’ll Learn
- How to structure kernels with TileLang’s core DSL constructs
- How to move data across global/shared/fragment and pipeline compute
- How to apply autotuning to tile sizes and schedules
- How to specify and work with dtypes in kernels
## Suggested Reading Order
1. Language Basics
2. Control Flow
3. Instructions
4. Autotuning
5. Type System
## Related Docs
- Tutorials: see existing guides in `tutorials/`
- Operators: examples in `deeplearning_operators/`
> NOTE: This is a draft scaffold. Fill in code snippets and benchmarks as APIs evolve.
# Type System
This page lists the data types supported by TileLang and how to specify them in
kernels. For full details and the authoritative list, see the API Reference
(`autoapi/tilelang/index`) and `tilelang.language.v2.dtypes`.
How to specify dtypes
- Use any of the following forms; TileLang normalizes them internally:
- String: `'float32'`, `'int8'`, `'bfloat16'`, ...
- TileLang dtype object: `T.float32`, `T.int8`, `T.bfloat16`, ...
- Framework dtype: `torch.float32`, `torch.int8`, `torch.bfloat16`, ...
Common scalar types
- Boolean: `bool`
- Signed integers: `int8`, `int16`, `int32`, `int64`
- Unsigned integers: `uint8`, `uint16`, `uint32`, `uint64`
- Floating‑point: `float16` (half), `bfloat16`, `float32`, `float64`
Float8 and low‑precision families
- Float8: `float8_e3m4`, `float8_e4m3`, `float8_e4m3b11fnuz`, `float8_e4m3fn`,
`float8_e4m3fnuz`, `float8_e5m2`, `float8_e5m2fnuz`, `float8_e8m0fnu`
- Float6: `float6_e2m3fn`, `float6_e3m2fn`
- Float4: `float4_e2m1fn`
Vectorized element types (SIMD packs)
- For many base types, vector‑packed variants are available by lane count:
`x2`, `x4`, `x8`, `x16`, `x32`, `x64`.
- Examples:
- Integers: `int8x2`, `int8x4`, ..., `int32x2`, `int32x4`, ...
- Unsigned: `uint8x2`, `uint8x4`, ...
- Floats: `float16x2`, `float16x4`, `float32x2`, `float32x4`, ...
- Float8/6/4 families also provide `x2/x4/x8/x16/x32/x64` where applicable,
e.g., `float8_e4m3x2`, `float8_e4m3x4`, `float6_e2m3fnx8`, `float4_e2m1fnx16`.
Notes
- Availability of certain low‑precision formats (float8/6/4) depends on target
architecture and backend support.
- Choose accumulation dtypes explicitly for mixed‑precision compute (e.g.,
GEMM with `float16` inputs and `float32` accumulators).
- The complete, up‑to‑date list is exposed in
`tilelang.language.v2.dtypes` and rendered in the API Reference.
......@@ -171,6 +171,32 @@ The output messages will include something like:
msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0
```
### Visual Layout Inference For TileLang
The **Visual Layout Inference** tool automatically generates visual diagrams that illustrate the mapping between logical indices, thread IDs, and register file locations.
When TileLang performs layout inference, it determines how fragment buffers are distributed across threads. The visual layout tool captures this information and generates:
1. **Textual output**: A human-readable description of the layout mapping
2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping
The visual layout inference tool is controlled through the `TL_LAYOUT_VISUALIZATION_ENABLE` and `TL_LAYOUT_VISUALIZATION_FORMATS` pass configuration. By default, `TL_LAYOUT_VISUALIZATION_ENABLE` is **disabled** to avoid performance overhead during compilation.
When enabled, `TL_LAYOUT_VISUALIZATION_FORMATS` accepts string values to control output formats:
- "txt": Text output only (same as default)
- "all": Generates all formats (TXT, PDF, PNG, SVG)
- "png": Generate PNG format only
- "pdf": Generate PDF format only
- "svg": Generate SVG format only
- "txt,svg": Generate multiple formats (comma-separated) in addition to text output
The output messages of "txt" will include something like:
```
C_local inferenced layout:
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
```
## Conclusion
By carefully examining intermediate representations (IR) before final code generation—and by leveraging runtime printing through `T.print`—one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs.
......
Logging in Tilelang/TVM
===================================================
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/SiriusNEO">SiriusNEO</a>
</div>
## TVM Logging Overview
Tilelang currently utilizes the logging system from TVM. The implementation can be found in:
- [include/tvm/runtime/logging.h](https://github.com/apache/tvm/blob/main/include/tvm/runtime/logging.h): Macro definitions
- [src/runtime/logging.cc](https://github.com/apache/tvm/blob/main/src/runtime/logging.cc): Logging logic implementation
The design style is inspired by [Google's glog](https://google.github.io/glog/stable/).
## Logging Categories
There are three primary macro types:
```c++
LOG(INFO) << "aaa";
DLOG(INFO) << "aaa";
VLOG(1) << "aaa";
```
- **LOG**: Standard logging preserved in code for displaying necessary information at different levels during runtime. Most Tilelang C++ error reporting is implemented via `LOG(FATAL) << "error msg"`.
- **DLOG**: Debug logging for developer debugging output. DLOG is controlled at build time by the TVM_LOG_DEBUG environment variable and is **eliminated in Release builds through dead code elimination**.
- The key difference between LOG(DEBUG) and DLOG is this build-time elimination. We recommend using DLOG over LOG(DEBUG), as the latter has overlapping functionality and gets compiled into the release runtime.
- **VLOG**: [Verbose logging](https://google.github.io/glog/stable/logging/#verbose-logging), primarily for debugging. Its main feature is customizable verbosity levels. For example, VLOG(n) where n can be 1, 2, 3, 4, 5, or 6, enabling complex tracing requirements. In contrast, LOG and DLOG typically use predefined verbose levels like INFO and DEBUG.
- In practical Tilelang development, VLOG is used less frequently.
- TVM's VLOG is implemented using DLOG, thus inheriting DLOG's characteristics.
Additional useful macros include various **CHECK** variants:
```c++
CHECK(cond) << "error msg";
DCHECK(cond) << "error msg";
ICHECK(cond) << "error msg";
```
The implementation routes errors to LogFatal:
```c++
#define CHECK(x) \
if (!(x)) \
::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< "Check failed: (" #x << ") is false: "
```
- **DCHECK**: Debug mode CHECK, only compiled in debug builds
- **ICHECK**: Internal Check that should exist in Release builds. When ICHECK fails, the entire system should report an error.
## Logging Verbose Levels
TVM defines 5 levels for LOG and DLOG (adding DEBUG compared to glog):
```c++
#define TVM_LOG_LEVEL_DEBUG 0
#define TVM_LOG_LEVEL_INFO 1
#define TVM_LOG_LEVEL_WARNING 2
#define TVM_LOG_LEVEL_ERROR 3
#define TVM_LOG_LEVEL_FATAL 4
```
## Using Logging in TileLang Development
### Guidelines
For temporary debugging output in your code, there are no restrictions (you can even use std::cout). Just remember to remove it before submitting a PR.
For meaningful logging that should remain in the Tilelang codebase:
- Critical correctness checks: Use ICHECK with sufficient error messages to facilitate debugging when issues arise.
- Complex Pass debugging: For passes requiring intermediate output that may need future review (e.g., LayoutInference), use DLOG.
- General INFO/WARNING messages: Use standard LOG.
### Enabling Log Output in Tilelang
To specify current log level at runtime, we need to set the environment variable `TVM_LOG_LEVEL`. An example usage is:
```c++
TVM_LOG_DEBUG=1 python3 code.py
```
which enables all DEBUG/INFO (level <= 1) logs for all files.
#### Detailed Rules for TVM_LOG_DEBUG Specification
The parsing logic is in `logging.cc`. Reference: [HyperAI Zhihu Article](https://zhuanlan.zhihu.com/p/1933106843468665163).
Launch Python with `TVM_LOG_DEBUG=<spec>`, where `<spec>` is a comma-separated list of level assignments in the form `<file_name>=<level>`. Important notes:
- The special filename DEFAULT sets the LOG level for all files.
- `<level>` can be set to -1 to disable LOG for that file.
- `<file_name>` is the C++ source filename (e.g., .cc, not .h) relative to the `src/` directory in the TVM repository. The `src/` prefix is optional when specifying file paths.
### Enabling Debug Mode
To enable DLOG/DCHECK, developers need to first build Tilelang in Debug mode:
```bash
cmake .. -DCMAKE_BUILD_TYPE=Debug -DUSE_CUDA=ON
```
Tilelang's CMake logic automatically adds the `TVM_LOG_DEBUG` macro, compiling all DLOG statements:
```cmake
target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG")
```
Then you also need to specify the runtime environment variables. For example, to use `DLOG(INFO) << "xxx"` for debugging, run your code with INFO level (1): `TVM_LOG_DEBUG=1`.
:::{note}
**Important**: There are two TVM_LOG_DEBUG variables. (1) Compile-time macro: Determines whether debug content (like DLOG) is compiled into the .so file. Referenced in C++ source via #ifdef TVM_LOG_DEBUG. This is automatically enabled when using Debug build mode in CMake. (2) Runtime environment variable: Controls logging level at runtime. TVM provides a specification for this variable, allowing control over per-file logging levels.
These two should ideally have different names, but TVM uses the same name for both, which can cause confusion.
:::
......@@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.primitives.gemm.base import GemmWarpPolicy
from tilelang.tileop.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial
......@@ -11,22 +11,20 @@ import time
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K_ref = K.repeat_interleave(groups, dim=2)
V_ref = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref)
lse = torch.logsumexp(scores, dim=-1).float()
return output, lse
......@@ -45,23 +43,23 @@ def get_fwd_configs():
valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
threads, num_stages,
enable_rasterization, k_pack,
panel_size, qk_coalesced_width,
v_coalesced_width):
valid_configs.append({
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
})
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width
):
valid_configs.append(
{
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
}
)
return valid_configs
......@@ -85,23 +83,23 @@ def fast_flashattn(
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim)**0.5
scale = (1.0 / dim) ** 0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
LSE: T.Tensor([batch, heads, seq_len], accum_dtype),
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
LSE: T.Tensor([batch, heads, seq_len], accum_dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization)
......@@ -111,7 +109,7 @@ def fast_flashattn(
num_q_blocks = T.ceildiv(seq_len, block_M)
bx_loop_var = T.alloc_var("int32")
bx_loop_var = T.alloc_var(T.int32)
bx_loop_var = b_split
with T.While(bx_loop_var < num_q_blocks):
......@@ -135,33 +133,21 @@ def fast_flashattn(
m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy(
Q[bz, q_block_offset:q_block_offset + block_M, by, :],
Q_shared,
coalesced_width=vec_size)
T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
loop_end_k = (
T.ceildiv(q_block_offset +
block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
T.copy(
K[bz, kv_idx:kv_idx + block_N, by // groups, :],
K_shared,
coalesced_width=vec_size)
T.copy(
V[bz, kv_idx:kv_idx + block_N, by // groups, :],
V_shared,
coalesced_width=v_vec_size)
T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
......@@ -178,6 +164,8 @@ def fast_flashattn(
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
m_i[i] = T.max(m_i[i], m_prev[i])
for i in T.Parallel(block_M):
if m_prev[i] == -T.infinity(accum_dtype):
......@@ -214,8 +202,7 @@ def fast_flashattn(
for i in T.Parallel(block_M):
if q_block_offset + i < seq_len:
lse_val = T.if_then_else(l_i[i] > 0,
T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
LSE[bz, by, q_block_offset + i] = lse_val
bx_loop_var = current_bx + num_split_q
......@@ -232,30 +219,30 @@ def get_bwd_configs():
panel_size = [7, 8, 9, 10]
configs = []
for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads,
enable_rasterization, panel_size):
configs.append({
"block_M": m,
"block_N": n,
"num_stages": stages,
"threads": t,
"enable_rasterization": r,
"panel_size": p,
})
for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size):
configs.append(
{
"block_M": m,
"block_N": n,
"num_stages": stages,
"threads": t,
"enable_rasterization": r,
"panel_size": p,
}
)
return configs
@tilelang.jit(out_idx=[2])
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
......@@ -263,36 +250,51 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
@tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True)
@tilelang.jit
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, block_N: int,
num_stages: int, threads: int, enable_rasterization: bool, panel_size: int):
sm_scale = (1.0 / dim)**0.5
def flashattn_bwd(
batch,
heads,
seq_len,
dim,
is_causal,
groups,
block_M: int,
block_N: int,
num_stages: int,
threads: int,
enable_rasterization: bool,
panel_size: int,
):
sm_scale = (1.0 / dim) ** 0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd_kernel(Q: T.Tensor(q_shape,
dtype), K: T.Tensor(kv_shape,
dtype), V: T.Tensor(kv_shape, dtype),
dO: T.Tensor(q_shape, dtype), lse: T.Tensor([batch, heads, seq_len],
accum_dtype),
Delta: T.Tensor([batch, heads, seq_len],
accum_dtype), dQ: T.Tensor(q_shape, accum_dtype),
dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)):
def flash_bwd_kernel(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
dO: T.Tensor(q_shape, dtype),
lse: T.Tensor([batch, heads, seq_len], accum_dtype),
Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
dQ: T.Tensor(q_shape, accum_dtype),
dK: T.Tensor(kv_shape, accum_dtype),
dV: T.Tensor(kv_shape, accum_dtype),
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
T.use_swizzle(panel_size, enable=enable_rasterization)
......@@ -313,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
......@@ -322,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared)
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared)
T.clear(qkT)
T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j,
P_acc[i, j], 0.0)
P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared)
T.clear(dP)
T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -345,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
T.copy(P_acc, p_cast)
T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared)
for i, j in T.Parallel(block_M, block_N):
p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale
......@@ -367,8 +368,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
@tilelang.jit(out_idx=[1])
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 64
......@@ -376,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.copy(
dQ_in[bz, bx * blk:(bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
dQ_in[bz, bx * blk : (bx + 1) * blk, by, :],
dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
)
return flash_bwd_post
......@@ -444,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100):
return np.median(times)
def main(batch: int = 1,
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 1):
def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
device = "cuda"
dtype = torch.float16
torch.manual_seed(42)
torch.cuda.manual_seed(42)
print(
f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}"
)
print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}")
flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 5 * flops_per_gemm
......@@ -515,22 +508,19 @@ def main(batch: int = 1,
o_ref.backward(dO)
print("Verifying backward pass correctness...")
dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(
dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
if dq_close:
print("dQ is correct.")
else:
print("dQ mismatch detected.")
dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(
dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
if dk_close:
print("dK is correct.")
else:
print("dK mismatch detected.")
dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(
dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
if dv_close:
print("dV is correct.")
else:
......@@ -551,9 +541,7 @@ def main(batch: int = 1,
torch.cuda.synchronize()
ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100)
print(
f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops"
)
print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops")
def run_complete_fwd_bwd():
o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v)
......@@ -591,12 +579,12 @@ def main(batch: int = 1,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=8, help='heads')
parser.add_argument('--seq_len', type=int, default=1024, help='sequence length')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=8, help="heads")
parser.add_argument("--seq_len", type=int, default=1024, help="sequence length")
parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--groups", type=int, default=1, help="groups")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
......@@ -2,29 +2,42 @@ import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.primitives.gemm.base import GemmWarpPolicy
from tilelang.tileop.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial
# Custom supply function to ensure tensors are created on GPU
def supply_tensors_gpu(params):
"""Supply function that creates tensors on GPU for ROCm/HIP."""
tensors = []
for param in params:
if hasattr(param, "shape") and hasattr(param, "dtype"):
# Force creation on GPU device
shape = [int(s) for s in param.shape]
tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
tensors.append(tensor)
else:
tensors.append(param)
return tensors
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
......@@ -43,27 +56,27 @@ def get_configs():
valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
threads, num_stages,
enable_rasterization, k_pack,
panel_size, qk_coalesced_width,
v_coalesced_width):
valid_configs.append({
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
})
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width
):
valid_configs.append(
{
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
}
)
return valid_configs
@tilelang.autotune(configs=get_configs(), cache_input_tensors=True)
@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
batch,
......@@ -83,22 +96,22 @@ def fast_flashattn(
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim)**0.5
scale = (1.0 / dim) ** 0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization)
......@@ -108,7 +121,7 @@ def fast_flashattn(
num_q_blocks = T.ceildiv(seq_len, block_M)
bx = T.alloc_var("int32")
bx = T.alloc_var(T.int32)
bx = b_split
with T.While(bx < num_q_blocks):
......@@ -132,32 +145,21 @@ def fast_flashattn(
m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy(
Q[bz, q_block_offset:q_block_offset + block_M, by, :],
Q_shared,
coalesced_width=vec_size)
T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
loop_end_k = T.ceildiv(q_block_offset + block_M,
block_N) if is_causal else T.ceildiv(seq_len, block_N)
loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
T.copy(
K[bz, kv_idx:kv_idx + block_N, by // groups, :],
K_shared,
coalesced_width=vec_size)
T.copy(
V[bz, kv_idx:kv_idx + block_N, by // groups, :],
V_shared,
coalesced_width=v_vec_size)
T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
......@@ -171,6 +173,8 @@ def fast_flashattn(
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
m_i[i] = T.max(m_i[i], m_prev[i])
for i in T.Parallel(block_M):
sf = T.exp(m_prev[i] * scale - m_i[i] * scale)
......@@ -205,13 +209,7 @@ def fast_flashattn(
return main
def main(batch: int = 1,
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 1):
def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
......@@ -233,18 +231,16 @@ def main(batch: int = 1,
print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=100)
print(
f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops"
)
print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=8, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=8, help="heads")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--groups", type=int, default=1, help="groups")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
......@@ -21,9 +21,9 @@ M = N = K = 1024
def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128):
@T.prim_func
def main(A: T.Tensor((M, K), "float16"),
B: T.Tensor((N, K), "float16"),
C: T.Tensor((M, N), "float")):
def main(A: T.Tensor((M, K), T.float16),
B: T.Tensor((N, K), T.float16),
C: T.Tensor((M, N), T.float)):
# ... (kernel definition)
return main
......@@ -40,9 +40,9 @@ from tilelang.carver.arch import CUDA
def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128):
@T.prim_func
def main(data: T.Tensor((N, H, W, C), "float16"),
kernel: T.Tensor((K, K, C, F), "float16"),
out: T.Tensor((N, (H-K+1), (W-K+1), F), "float")):
def main(data: T.Tensor((N, H, W, C), T.float16),
kernel: T.Tensor((K, K, C, F), T.float16),
out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)):
# ... (convolution kernel definition)
return main
......
......@@ -25,38 +25,21 @@ def check_hopper():
return False
def kernel(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
is_hopper = check_hopper()
@T.prim_func
def conv(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -65,11 +48,13 @@ def kernel(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: make_swizzled_layout(out_shared),
data_shared: make_swizzled_layout(data_shared),
kernel_shared: make_swizzled_layout(kernel_shared),
})
T.annotate_layout(
{
out_shared: make_swizzled_layout(out_shared),
data_shared: make_swizzled_layout(data_shared),
kernel_shared: make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
......@@ -81,10 +66,8 @@ def kernel(N,
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
......
......@@ -15,14 +15,14 @@ def kernel(
thread_num=None,
enable_rasteration=None,
):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
......
import torch
import argparse
from tilelang.profiler import do_bench
from tilelang import language as T
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
......@@ -51,8 +52,7 @@ def triton_kernel(
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
......@@ -120,7 +120,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
start_q=seq_kv - seq_q,
)
return o
......@@ -135,14 +136,14 @@ def main(
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print('Using sliding window attention.')
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
......@@ -170,15 +171,14 @@ def main(
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
if torch.allclose(
triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2):
triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")
......@@ -198,20 +198,14 @@ def main(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size,
args.dtype, args.tune)
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
import torch
import argparse
from tilelang.profiler import do_bench
from tilelang import language as T
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
......@@ -50,8 +51,7 @@ def triton_kernel(
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
......@@ -117,26 +117,29 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
start_q=seq_kv - seq_q,
)
return o
def main(batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print('Using sliding window attention.')
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
......@@ -163,15 +166,14 @@ def main(batch: int = 1,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
......@@ -184,19 +186,13 @@ def main(batch: int = 1,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
......@@ -12,8 +12,7 @@ bitblas.set_log_level("INFO")
def generate_text_batch(model, tokenizer, prompts, max_length=100):
# Encode the input prompts as a batch
input_ids = tokenizer(
prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
# Generate cos and sin values (commented out as not used in generation)
seq_length = input_ids.size(1)
......@@ -37,9 +36,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
end_time = time.time()
# Decode the output ids to text
generated_texts = [
tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids
]
generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids]
generation_time = end_time - start_time
num_tokens = sum(len(output_id) for output_id in output_ids)
......@@ -52,8 +49,8 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100):
def profile(model, input_data):
import numpy as np
model = model.cuda()
model.eval()
......@@ -74,25 +71,29 @@ def profile(model, input_data):
return np.mean(times)
model_path = '1bitLLM/bitnet_b1_58-3B'
model_path = "1bitLLM/bitnet_b1_58-3B"
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--bs', default=16, type=int)
parser.add_argument('--in_seq_len', default=32, type=int)
parser.add_argument('--out_seq_len', default=128, type=int)
parser.add_argument('--bitblas', action='store_true')
parser.add_argument("--bs", default=16, type=int)
parser.add_argument("--in_seq_len", default=32, type=int)
parser.add_argument("--out_seq_len", default=128, type=int)
parser.add_argument("--bitblas", action="store_true")
args = parser.parse_args()
bs = args.bs
in_seq_len = args.in_seq_len
out_seq_len = args.out_seq_len
is_bitblas = args.bitblas
model = BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).cuda().half()
model = (
BitnetForCausalLM.from_pretrained(
model_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
)
.cuda()
.half()
)
if is_bitblas:
with torch.no_grad():
model.quantize()
......@@ -109,5 +110,5 @@ def main():
print(generate_text_batch(model, tokenizer, prompts, max_length=max_length))
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -6,13 +6,14 @@ from modeling_bitnet import BitnetForCausalLM
torch.set_grad_enabled(False)
parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
def profile(model, input_data):
import time
import numpy as np
model = model.cuda()
model.eval()
......@@ -35,8 +36,8 @@ def profile(model, input_data):
def main():
model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B',
device_map='auto',
"1bitLLM/bitnet_b1_58-3B",
device_map="auto",
low_cpu_mem_usage=True,
use_flash_attention_2=True,
torch_dtype=torch.float16,
......@@ -52,5 +53,5 @@ def main():
print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}")
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LLaMA model configuration"""
"""LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
......@@ -180,16 +180,10 @@ class BitnetConfig(PretrainedConfig):
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}")
raise ValueError(f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor,
float) or rope_scaling_factor <= 1.0:
raise ValueError(
f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment