Commit 5ee58ec7, authored by Zhengju Tang, committed by LeiWang1999
Browse files

[Dynamic Symbolic] Adaptively vectorize with different condition expressions (#326)



* [Dynamic Symbolic] Adaptively vectorize with different condition expressions

* Format

* Format

* Format

* Format

* Add MIT License headers to Python files

* Simplify return statement in loop vectorization

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent eab47249
This diff is collapsed.
import math
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head,
                         is_causal_or_local, max_splits):
    """
    Determines the number of splits for maximizing GPU occupancy while balancing memory efficiency.

    Parameters:
    - total_mblocks (int): Total number of m_blocks.
    - num_SMs (int): Number of Streaming Multiprocessors (SMs) in the GPU.
    - num_n_blocks (int): Number of n_blocks.
    - num_m_blocks (int): Number of m_blocks.
    - size_one_kv_head (int): Size of one KV head in bytes.
    - is_causal_or_local (bool): Indicates whether the operation is causal or local.
    - max_splits (int): Maximum number of allowed splits.

    Returns:
    - int: The chosen number of splits. Falls back to 1 when no candidate split
      is evaluated (e.g. the clamped ``max_splits`` is below 1).
    """
    # If we have enough m_blocks to almost fill the SMs, prefer 1 split unless
    # memory constraints apply.
    if total_mblocks >= 0.8 * num_SMs:
        size_l2 = 50 * 1024 * 1024  # L2 cache size assumption (50MB)
        # Only split if each KV head is too large for L2 and there are enough m_blocks.
        if size_one_kv_head > size_l2 and num_m_blocks >= num_SMs * 2 and not is_causal_or_local:
            # Ceil-divide the KV head size by the assumed L2 size, capped by max_splits.
            return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits)
        return 1

    # If num_n_blocks is too small, we don't split.
    if num_n_blocks <= 4:
        return 1

    # Limit max_splits to a reasonable range.
    max_splits = min(max_splits, num_SMs, num_n_blocks)

    max_efficiency = 0.0
    efficiency = []

    # Compute the wave efficiency (fraction of the last wave that is used) for
    # every candidate split count.
    for num_splits in range(1, max_splits + 1):
        n_waves = (total_mblocks * num_splits) / num_SMs
        full_waves = math.ceil(n_waves)
        # With no work at all (n_waves == 0) the ratio is undefined; treat it as
        # zero efficiency instead of dividing by zero.
        eff = n_waves / full_waves if full_waves else 0.0
        # Track max efficiency.
        if eff > max_efficiency:
            max_efficiency = eff
        efficiency.append(eff)

    # Find the smallest number of splits that achieves at least 85% of max efficiency.
    threshold = 0.85 * max_efficiency
    for num_splits, eff in enumerate(efficiency, start=1):
        if eff >= threshold:
            return num_splits

    return 1
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <tvm/arith/iter_affine_map.h> #include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <numeric> #include <numeric>
...@@ -262,6 +263,65 @@ private: ...@@ -262,6 +263,65 @@ private:
int loop_num_; int loop_num_;
}; };
// Rewrites every ordering comparison (>=, >, <=, <) found in a condition so
// that the vectorized loop variable no longer appears in its operands.
//
// For `a >= b` and `a > b`, the variable is replaced by 0 in `a` and by
// `vector_size_ - 1` in `b`. For `a <= b` and `a < b`, it is replaced by
// `vector_size_ - 1` in `a` and by 0 in `b`.
//
// NOTE(review): this picks the tightest bound only if each operand is
// non-decreasing in the loop variable -- confirm against callers.
// NOTE(review): only GE/GT/LE/LT are overridden here; any other node kind that
// mentions the loop variable goes through the base StmtExprMutator handling.
class VectorizedConditionMutator : public StmtExprMutator {
public:
  // inner_var: the loop variable to eliminate from comparisons.
  // extent: the vector width; the variable is pinned to 0 or extent - 1.
  VectorizedConditionMutator(Var inner_var, int extent)
      : inner_var_(inner_var), vector_size_(extent) {}

private:
  // a >= b  ->  a[var := 0] >= b[var := vector_size_ - 1]
  PrimExpr VisitExpr_(const GENode *node) final {
    PrimExpr lhs_bound = BindInnerVar(node->a, 0);
    PrimExpr rhs_bound = BindInnerVar(node->b, vector_size_ - 1);
    return GE(lhs_bound, rhs_bound, node->span);
  }

  // a > b  ->  a[var := 0] > b[var := vector_size_ - 1]
  PrimExpr VisitExpr_(const GTNode *node) final {
    PrimExpr lhs_bound = BindInnerVar(node->a, 0);
    PrimExpr rhs_bound = BindInnerVar(node->b, vector_size_ - 1);
    return GT(lhs_bound, rhs_bound, node->span);
  }

  // a <= b  ->  a[var := vector_size_ - 1] <= b[var := 0]
  PrimExpr VisitExpr_(const LENode *node) final {
    PrimExpr lhs_bound = BindInnerVar(node->a, vector_size_ - 1);
    PrimExpr rhs_bound = BindInnerVar(node->b, 0);
    return LE(lhs_bound, rhs_bound, node->span);
  }

  // a < b  ->  a[var := vector_size_ - 1] < b[var := 0]
  PrimExpr VisitExpr_(const LTNode *node) final {
    PrimExpr lhs_bound = BindInnerVar(node->a, vector_size_ - 1);
    PrimExpr rhs_bound = BindInnerVar(node->b, 0);
    return LT(lhs_bound, rhs_bound, node->span);
  }

  // Mutates `operand` recursively (so nested comparisons are rewritten first),
  // then substitutes `value` for the loop variable in the result.
  PrimExpr BindInnerVar(const PrimExpr &operand, int value) {
    PrimExpr visited = StmtExprMutator::VisitExpr(operand);
    Map<Var, PrimExpr> vmap;
    vmap.Set(inner_var_, value);
    return Substitute(visited, vmap);
  }

  Var inner_var_;
  int vector_size_;
};
class VectorizeRewriterDynamic : public StmtExprMutator { class VectorizeRewriterDynamic : public StmtExprMutator {
public: public:
VectorizeRewriterDynamic(VectorizePlanResult plan) VectorizeRewriterDynamic(VectorizePlanResult plan)
...@@ -297,22 +357,19 @@ private: ...@@ -297,22 +357,19 @@ private:
VectorizedConditionExtracter extracter; VectorizedConditionExtracter extracter;
std::vector<PrimExpr> conditions = extracter.GetConditions(body); std::vector<PrimExpr> conditions = extracter.GetConditions(body);
// Set vectorize variable to the max value of the extent (i.e. VectorizedConditionMutator condition_mutator(inner_var, vector_size_);
// vector_size_ - 1)
PrimExpr condition = conditions[0]; // Adaptively set vectorized variable to the min/max value of the extent
PrimExpr condition_bound = condition_mutator(conditions[0]);
for (int i = 1; i < conditions.size(); ++i) { for (int i = 1; i < conditions.size(); ++i) {
condition = condition && conditions[i]; condition_bound = condition_bound && condition_mutator(conditions[i]);
} }
// add condition ifthenelse here
Map<Var, PrimExpr> vmap_condition;
vmap_condition.Set(inner_var, vector_size_ - 1);
PrimExpr condition_bound = Substitute(condition, vmap_condition);
// modify body in the vectorized loop // modify body in the vectorized loop
VectorizedBodyMutator mutator(inner_var, vector_size_, conditions); VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
Stmt vectorize_body = mutator(body); Stmt vectorize_body = mutator(body);
// add condition ifthenelse here
For vectorize_for = For vectorize_for =
For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body); For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body); For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment