Unverified commit fc4bd452 authored by Lei Wang, committed by GitHub

[Layout] Strict annotate completed replicated layout for fragment with constant index (#929)

* [Layout] Add IsCompletedReplicated method and enhance layout inference in ParallelOpNode

- Introduced IsCompletedReplicated method in FragmentNode to check if a buffer is fully replicated.
- Enhanced InferLayout in ParallelOpNode to handle layout inference for replicated buffers, ensuring only fragment[0] access is allowed (a minimal sketch of this pattern follows the commit header).
- Updated error handling for non-zero index access in fragment buffers to improve robustness.

* [Layout] Improve code formatting and readability in layout.cc and parallel.cc

- Enhanced formatting in FragmentNode's IsCompletedReplicated method for better clarity.
- Updated InferLayout method in ParallelOpNode to improve code readability by adjusting line breaks and indentation.
- Ensured consistent formatting across conditional statements and comments for improved maintainability.

* updt

* optimize const index related op

* bug fix

* reduce gdn test

* test fix

* lintfix

* lint fix

* test fix
parent f09e91e3
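
The core pattern this change targets is a fragment buffer that is only ever addressed with the constant index 0 inside a T.Parallel loop. Strict-level layout inference now annotates such a buffer as fully replicated (every thread keeps its own copy) and rejects any non-zero constant index. Below is a minimal TileLang-style sketch of that pattern; the kernel scaffolding, shapes, and dtype are illustrative assumptions rather than code from this commit.

import tilelang
import tilelang.language as T

@tilelang.jit
def copy_through_fragment(M=1024, threads=128, dtype="float32"):

    @T.prim_func
    def main(x: T.Tensor((M,), dtype), y: T.Tensor((M,), dtype)):
        with T.Kernel(1, threads=threads) as bx:  # single block; bx unused
            acc = T.alloc_fragment((1,), dtype)  # one-element fragment
            # `acc` is only ever addressed as acc[0] (a constant index), so
            # the only consistent layout is a fully replicated one: each of
            # the `threads` threads holds its own copy of acc[0].
            for i in T.Parallel(M):
                acc[0] = x[i]  # the commit's example: fragment[0] = x[i]
                y[i] = acc[0]

    return main
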
import tilelang.testing
import torch
tilelang.disable_cache()
B = 1
S = 32768
S = 1024 # small size, for testing only.
H = 32
DK = 128
DV = 128
......@@ -26,7 +24,7 @@ num_stages = 1
def test_example_wy_fast_compilation():
from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input, prepare_output
from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input
K, V, Beta, G, A = prepare_input(
B,
S,
......@@ -37,7 +35,6 @@ def test_example_wy_fast_compilation():
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
......@@ -97,13 +94,12 @@ def test_example_wy_fast_bwd_split_compilation():
def test_example_chunk_o_compilation():
from example_chunk_o import tilelang_chunk_fwd_o, prepare_input, prepare_output
from example_chunk_o import tilelang_chunk_fwd_o, prepare_input
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
scale = 1.0 / DK**0.5
block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype))
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
threads, num_stages)
......@@ -111,16 +107,13 @@ def test_example_chunk_o_compilation():
def test_example_chunk_o_bwd_compilation():
from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input, prepare_output
from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype), block_DK)
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
block_DK, block_DV, threads, num_stages)
......@@ -131,10 +124,9 @@ def test_example_chunk_o_bwd_compilation():
def test_example_chunk_scaled_dot_kkt_compilation():
from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input, prepare_output
from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
accum_dtype, use_g, block_S, block_DK, threads,
......@@ -164,15 +156,12 @@ def test_example_cumsum_compilation():
def test_example_chunk_delta_h_compilation():
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input, prepare_output
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
......@@ -183,17 +172,13 @@ def test_example_chunk_delta_h_compilation():
def test_example_chunk_delta_bwd_compilation():
from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input, prepare_output
from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype,
chunk_size, 1.0, use_g, use_initial_state,
......
......@@ -236,12 +236,11 @@ def matmul(M,
return gemm_autotune
def main(m: int = 4096,
n: int = 4096,
k: int = 4096,
def main(M: int = 4096,
N: int = 4096,
K: int = 4096,
use_autotune: bool = False,
with_roller: bool = False):
M, N, K = m, n, k
use_autotune = True
if use_autotune:
result = get_best_config(M, N, K, with_roller)
......
......@@ -162,8 +162,7 @@ def ref_program(A, B):
return A @ B.T
def main():
M, N, K = 16384, 16384, 16384
def main(M=4096, N=4096, K=4096):
in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
......@@ -183,4 +182,4 @@ def main():
if __name__ == "__main__":
main()
main(M=4096, N=4096, K=4096)
......@@ -118,13 +118,7 @@ def ref_program(A, B):
return A @ B
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--M', type=int, default=8192, help='M dimension')
parser.add_argument('--N', type=int, default=8192, help='N dimension')
parser.add_argument('--K', type=int, default=8192, help='K dimension')
args = parser.parse_args()
M, N, K = args.M, args.N, args.K
def main(M=4096, N=4096, K=4096):
total_flops = 2 * M * N * K
BLOCK_M = 128
......@@ -156,4 +150,10 @@ def main():
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser.add_argument('--M', type=int, default=8192, help='M dimension')
parser.add_argument('--N', type=int, default=8192, help='N dimension')
parser.add_argument('--K', type=int, default=8192, help='K dimension')
args = parser.parse_args()
M, N, K = args.M, args.N, args.K
main(M, N, K)
......@@ -7,11 +7,11 @@ import example_gemm
def test_example_gemm_autotune():
# enable roller for fast tuning
example_gemm_autotune.main(with_roller=True)
example_gemm_autotune.main(M=1024, N=1024, K=1024, with_roller=True)
def test_example_gemm_intrinsics():
example_gemm_intrinsics.main()
example_gemm_intrinsics.main(M=1024, N=1024, K=1024)
def test_example_gemm_schedule():
......
......@@ -51,10 +51,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return main
def main():
M = 16384
N = 16384
K = 16384
def main(M=16384, N=16384, K=16384):
block_M = 128
block_N = 128
block_K = 64
......
......@@ -48,10 +48,7 @@ def matmul_warp_specialize_copy_0_gemm_1(M,
return main
def main():
M = 16384
N = 16384
K = 16384
def main(M=1024, N=1024, K=1024):
block_M = 128
block_N = 128
block_K = 64
......
......@@ -49,10 +49,7 @@ def matmul_warp_specialize_copy_1_gemm_0(M,
return main
def main():
M = 16384
N = 16384
K = 16384
def main(M=16384, N=16384, K=16384):
block_M = 128
block_N = 128
block_K = 64
......
......@@ -7,10 +7,8 @@ tilelang.disable_cache()
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(
out_idx=[2],
pass_configs={
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
# tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
......
......@@ -43,10 +43,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return main
def main():
M = 16384
N = 16384
K = 16384
def main(M=16384, N=16384, K=16384):
block_M = 128
block_N = 128
block_K = 64
......
......@@ -16,25 +16,25 @@ def test_example_warp_specialize_flashmla():
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_barrierpipe_stage2():
example_warp_specialize_gemm_barrierpipe_stage2.main()
example_warp_specialize_gemm_barrierpipe_stage2.main(M=1024, N=1024, K=1024)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_copy_0_gemm_1():
example_warp_specialize_gemm_copy_0_gemm_1.main()
example_warp_specialize_gemm_copy_0_gemm_1.main(M=1024, N=1024, K=1024)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_copy_1_gemm_0():
example_warp_specialize_gemm_copy_1_gemm_0.main()
example_warp_specialize_gemm_copy_1_gemm_0.main(M=1024, N=1024, K=1024)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_softpipe_stage2():
example_warp_specialize_gemm_softpipe_stage2.main()
example_warp_specialize_gemm_softpipe_stage2.main(M=1024, N=1024, K=1024)
if __name__ == "__main__":
......
......@@ -326,6 +326,13 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
data_ = std::move(n);
}
// "Completed replicated" means forward_thread is just the replication variable, i.e. lambda i, rep: rep.
bool FragmentNode::IsCompletedReplicated() const {
arith::Analyzer analyzer;
return ExprDeepEqual()(analyzer.Simplify(forward_thread_),
ReplicationPlaceholder());
}
PrimExpr FragmentNode::ThreadExtent() const {
Array<PrimExpr> ret(OutputDim(), 1);
arith::Analyzer analyzer;
......
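
As the comment above the new method notes, a fragment layout can be read as a mapping from (spatial indices, replication index) to a thread id, and IsCompletedReplicated holds exactly when the simplified thread mapping is the replication placeholder itself, i.e. thread = rep no matter which element is addressed. The toy below restates that idea in plain Python; it does not use any TVM or TileLang API and is only meant to show what a fully replicated layout looks like as a function.

def replicated_fragment_layout():
    # A fragment layout maps (spatial indices, replication index) -> thread.
    # forward_index: every replica addresses the same spatial element.
    def forward_index(indices, rep):
        return tuple(indices)

    # forward_thread: the thread id is the replication variable itself and
    # ignores the spatial indices entirely; this is the property that
    # FragmentNode::IsCompletedReplicated checks on the simplified expression.
    def forward_thread(indices, rep):
        return rep

    return forward_index, forward_thread

fwd_idx, fwd_thr = replicated_fragment_layout()
assert fwd_thr((0,), rep=5) == 5      # replica 5 lives on thread 5 ...
assert fwd_idx((0,), rep=5) == (0,)   # ... and holds the same element 0
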
......@@ -101,6 +101,8 @@ public:
bool IsEqual(const FragmentNode *other, bool skip_index = false) const;
bool IsCompletedReplicated() const;
static void RegisterReflection();
bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
......
......@@ -213,11 +213,107 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (loop_layout_.defined())
return {};
if (level == InferLevel::kStrict)
return {};
if (level == InferLevel::kStrict) {
LayoutMap results;
// Deduce buffers that should be completely replicated.
// For example:
// for i in T.Parallel(m):
// fragment[0] = x[i]
// then fragment[0] must be replicated on all threads.
for (const auto &[buffer, indices] : indice_map_) {
if (T.layout_map.count(buffer)) {
continue;
}
if (buffer.scope() != "local.fragment")
continue;
// Check if all indices are zero
bool all_indices_zero = true;
for (const auto &index : indices) {
if (const auto *imm = index.as<IntImmNode>()) {
if (imm->value != 0) {
all_indices_zero = false;
LOG(FATAL)
<< "Fragment buffer access with non-zero index [" << imm->value
<< "] is not supported. "
<< "Only fragment[0] access is allowed within T.Parallel loop.";
}
} else {
// Non-constant index, not all zero
all_indices_zero = false;
}
}
// Only set layout if all indices are zero
if (all_indices_zero) {
Array<IterVar> forward_vars;
for (const auto &s : buffer->shape) {
forward_vars.push_back(
IterVar(Range(0, s), Var(), IterVarType::kDataPar));
}
Array<PrimExpr> forward_index;
for (const auto &iv : forward_vars) {
forward_index.push_back(iv->var);
}
Var rep;
auto rep_iter =
IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar);
const PrimExpr &forward_thread = rep;
results.Set(buffer, Fragment(forward_vars, forward_index,
forward_thread, rep_iter));
}
}
return results;
}
auto buffer_is_completed_replicated = [&](const Buffer &buffer) {
if (buffer.scope() != "local.fragment")
return false;
auto frag = T.layout_map[buffer].as<Fragment>().value();
// buffer indices should be IntImm
for (const auto &index : indice_map_[buffer]) {
if (!index.as<IntImmNode>()) {
return false;
} else if (index.as<IntImmNode>()->value != 0) {
LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
}
}
return frag->IsCompletedReplicated();
};
// Collect fragment buffers with const index and all fragment_buffers
std::vector<Buffer> const_index_fragment_buffer, fragment_buffers;
for (const auto &[buffer, indices] : indice_map_) {
if (buffer.scope() != "local.fragment")
continue;
fragment_buffers.push_back(buffer);
bool is_const_index = true;
for (const auto &index : indices) {
if (!index.as<IntImmNode>()) {
is_const_index = false;
break;
}
}
if (is_const_index) {
const_index_fragment_buffer.push_back(buffer);
}
}
// Determine if common layout propagation should be applied.
// If there are fragment buffers with non-constant indices, we need to
// propagate the common layout pattern to ensure consistency across all
// fragments. Example cases:
// - Need propagation: frag_a[0] = T.min(frag_a[0], frag_b[i])
// (const index frag_a interacts with non-const index frag_b)
// - No propagation needed: shared_a[i] = frag_a[0]
// (const index frag_a with non-fragment buffer)
bool allow_layout_propgate =
fragment_buffers.size() > const_index_fragment_buffer.size();
// Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer;
Buffer replicated_write_buffer; // Backup: fully replicated write buffer
for (const auto &[buffer, indices] : indice_map_) {
if (T.layout_map.count(buffer)) {
// skip reducers with rep=ALL
......@@ -226,15 +322,19 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
continue;
auto frag = T.layout_map[buffer].as<Fragment>().value();
bool is_fully_replicated = buffer_is_completed_replicated(buffer);
if (buffer_is_write_.count(buffer)) {
source_buffer = buffer;
} else {
// Keep the buffer with largest number of indices
// (which means the inference based on that buffer is more accurate)
// as read_source_buffer to get more accurate layout
if (!read_source_buffer.defined() ||
// if the buffer is completed replicated, we don't need to infer the
// layout from this buffer.
if ((!read_source_buffer.defined() ||
indice_map_[buffer].size() >
indice_map_[read_source_buffer].size()) {
indice_map_[read_source_buffer].size())) {
read_source_buffer = buffer;
}
// If the buffer is not replicated and shape is equal to the
......@@ -250,6 +350,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
<< buffer << "` of layout " << src_layout->DebugOutput() << '\n';
Fragment result;
if (IsCommonAccessIndice(buffer)) {
result = src_layout;
......@@ -260,15 +361,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
if (auto opt_var = objref.as<Var>();
opt_var && inner_vars_.count(*opt_var)) {
std::ostringstream oss;
oss << "loop_var_to_thread = " << loop_var_to_thread
<< "contains inner var" << *opt_var;
throw LayoutConflictException(oss.str());
}
});
result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
->BindThreadRange(T.thread_bounds);
}
......@@ -276,10 +369,17 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
<< result->DebugOutput() << '\n';
return result;
};
if (source_buffer.defined()) {
// Try to infer loop layout from buffers in order of preference:
// 1. Non-replicated write buffer (most reliable)
// 2. Non-replicated read buffer
// 3. Fully replicated write buffer (backup, may cause issues)
// 4. Free inference mode (no source buffer)
if (source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
} else if (level == InferLevel::kFree) {
if (read_source_buffer.defined()) {
if (read_source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// // Loop doesn't need to be replicated.
// if (!is_one(loop_layout_->ReplicateExtent()))
......@@ -330,7 +430,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
} else {
}
if (!loop_layout_.defined()) {
// No source buffer available, use free mode inference
// Vectorize Size must be aware of the buffer_remap
// As the pass will do post processing to the layout
auto maybe_remapped_root_ =
......
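
The propagation decision introduced in this hunk can be summarized as: classify every local.fragment buffer by whether all of its indices are compile-time constants, and only let a source buffer drive the loop layout when at least one fragment is accessed through a loop-dependent index. The following is a simplified Python restatement of that logic; plain integers and strings stand in for TIR index expressions and buffers, and no TVM API is used.

def allow_layout_propagate(indice_map, scopes):
    fragment_buffers = []
    const_index_fragment_buffers = []
    for buffer, indices in indice_map.items():
        if scopes[buffer] != "local.fragment":
            continue
        fragment_buffers.append(buffer)
        if all(isinstance(i, int) for i in indices):
            const_index_fragment_buffers.append(buffer)
    # Propagate a common layout only when some fragment is accessed through a
    # non-constant index; a lone frag[0] access does not constrain the loop
    # partition on its own.
    return len(fragment_buffers) > len(const_index_fragment_buffers)

# frag_a[0] = T.min(frag_a[0], frag_b[i])  -> propagation is needed
assert allow_layout_propagate(
    {"frag_a": [0], "frag_b": ["i"]},
    {"frag_a": "local.fragment", "frag_b": "local.fragment"})

# shared_a[i] = frag_a[0]                  -> no propagation is needed
assert not allow_layout_propagate(
    {"frag_a": [0], "shared_a": ["i"]},
    {"frag_a": "local.fragment", "shared_a": "shared"})
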
......@@ -134,7 +134,8 @@ def AddWrapperForSingleBufStore():
# Validate fragment buffer indices - only index 0 is supported
buffer_indices = collect_buffer_indices(statement)
for buffer, indices in buffer_indices.items():
if buffer.scope() == "local.fragment":
if buffer.scope() != "local.fragment":
continue
for index in indices:
if isinstance(index, IntImm) and index != 0:
raise ValueError(
......