Unverified Commit 3aecab8f authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Example] Disable TMA and enable FastMath for NSA Examples (#941)

* tma disable

* int64 cast fix.
parent 557589ff
......@@ -38,9 +38,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
# if USE_BLOCK_COUNTS:
# NS = tl.load(block_counts + (bos + i_t) * H + i_h)
# else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
......@@ -452,7 +449,12 @@ def get_configs():
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def tilelang_sparse_attention(batch,
heads,
seq_len,
......
......@@ -17,9 +17,12 @@ from einops import rearrange
import tilelang
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def tilelang_kernel_fwd(
batch,
heads,
......
......@@ -9,8 +9,11 @@ tilelang.testing.set_random_seed(0)
@tilelang.jit(
out_idx=[-1], pass_configs={
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def native_sparse_attention(batch,
heads,
......
......@@ -16,9 +16,12 @@ from reference import naive_nsa
from einops import rearrange
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def native_sparse_attention_varlen(batch,
heads,
c_seq_len,
......
......@@ -62,6 +62,43 @@ private:
using IRMutatorWithAnalyzer::VisitStmt;
using IRMutatorWithAnalyzer::VisitStmt_;
class Int64Promoter : public tir::IndexDataTypeRewriter {
public:
using Parent = IndexDataTypeRewriter;
PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op));
}
return GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
// Force indices to be int64
auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
return std::move(node);
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
return std::move(node);
}
};
explicit BufferFlattener(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}
Stmt VisitStmt_(const BlockNode *op) final {
......@@ -244,7 +281,29 @@ private:
Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer,
const Array<PrimExpr> &indices) {
auto flattened_indices = buffer->ElemOffset(indices);
return this->IterMapSimplifyWithContext(flattened_indices, false);
Array<PrimExpr> safe_indices;
for (auto index : flattened_indices) {
auto int_bound = analyzer_->const_int_bound(index);
DataType dtype = index->dtype;
if (dtype.is_int() && dtype.bits() < 64) {
int64_t max_value = int_bound->max_value;
int64_t min_value = int_bound->min_value;
const int64_t type_max = (1LL << (dtype.bits() - 1));
const int64_t type_min = -(1LL << (dtype.bits() - 1));
if (max_value >= (type_max - 1) || min_value < type_min) {
Int64Promoter promoter;
for (auto &index : flattened_indices) {
safe_indices.push_back(promoter(index));
}
} else {
safe_indices.push_back(index);
}
} else {
safe_indices.push_back(index);
}
}
return this->IterMapSimplifyWithContext(safe_indices, false);
}
template <typename Node> Node VisitBufferAccess(Node node) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment