Unverified Commit fc4bd452 authored by Lei Wang, committed by GitHub

[Layout] Strict annotate completed replicated layout for fragment with constant index (#929)

* [Layout] Add IsCompletedReplicated method and enhance layout inference in ParallelOpNode

- Introduced an IsCompletedReplicated method in FragmentNode to check whether a buffer's layout is fully replicated across threads.
- Enhanced InferLayout in ParallelOpNode to handle layout inference for replicated buffers, ensuring only fragment[0] access is allowed (see the sketch below).
- Improved error handling for non-zero constant-index access to fragment buffers.
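For context, the pattern this change targets looks like the following. This is a minimal TileLang sketch under stated assumptions (the kernel scaffolding and the name `scalar_max` are illustrative, not taken from this commit): a single-element fragment written through constant index 0 inside a `T.Parallel` loop, which the strict pass now annotates as fully replicated. Previously, strict-level inference returned no layouts at all (see the old `return {};` in parallel.cc below).

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def scalar_max(M=1024, dtype="float32"):

    @T.prim_func
    def main(x: T.Tensor((M,), dtype), y: T.Tensor((1,), dtype)):
        with T.Kernel(1, threads=128):
            x_frag = T.alloc_fragment((M,), dtype)
            acc = T.alloc_fragment((1,), dtype)  # single-element fragment
            T.copy(x, x_frag)
            T.fill(acc, -T.infinity(dtype))
            for i in T.Parallel(M):
                # Constant index 0: every thread touches the same element, so
                # `acc` must be laid out as replicated on all threads.
                acc[0] = T.max(acc[0], x_frag[i])
            T.copy(acc, y)

    return main
```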

* [Layout] Improve code formatting and readability in layout.cc and parallel.cc

- Enhanced formatting in FragmentNode's IsCompletedReplicated method for better clarity.
- Updated InferLayout method in ParallelOpNode to improve code readability by adjusting line breaks and indentation.
- Ensured consistent formatting across conditional statements and comments for improved maintainability.

* updt

* optimize const index related op

* bug fix

* reduce gdn test

* test fix

* lintfix

* lint fix

* test fix
parent f09e91e3
 import tilelang.testing
 import torch
-tilelang.disable_cache()
 B = 1
-S = 32768
+S = 1024  # small but for test only.
 H = 32
 DK = 128
 DV = 128
@@ -26,7 +24,7 @@ num_stages = 1
 def test_example_wy_fast_compilation():
-    from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input, prepare_output
+    from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input
     K, V, Beta, G, A = prepare_input(
         B,
         S,
@@ -37,7 +35,6 @@ def test_example_wy_fast_compilation():
         getattr(torch, input_dtype),
         getattr(torch, output_dtype),
         gate_dtype=getattr(torch, gate_dtype))
-    W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
     # tilelang
     block_S = chunk_size
     kernel = tilelang_recompute_w_u_fwd(
@@ -97,13 +94,12 @@ def test_example_wy_fast_bwd_split_compilation():
 def test_example_chunk_o_compilation():
-    from example_chunk_o import tilelang_chunk_fwd_o, prepare_input, prepare_output
+    from example_chunk_o import tilelang_chunk_fwd_o, prepare_input
     Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype),
                                        getattr(torch, output_dtype), getattr(torch, accum_dtype),
                                        getattr(torch, gate_dtype))
     scale = 1.0 / DK**0.5
     block_S = chunk_size
-    O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype))
     kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
                                   gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
                                   threads, num_stages)
@@ -111,16 +107,13 @@ def test_example_chunk_o_compilation():
 def test_example_chunk_o_bwd_compilation():
-    from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input, prepare_output
+    from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input
     Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
                                                  getattr(torch, input_dtype),
                                                  getattr(torch, output_dtype),
                                                  getattr(torch, accum_dtype),
                                                  getattr(torch, gate_dtype),
                                                  getattr(torch, state_dtype))
-    dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
-        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
-        getattr(torch, state_dtype), block_DK)
     kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
                                         gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
                                         block_DK, block_DV, threads, num_stages)
@@ -131,10 +124,9 @@ def test_example_chunk_o_bwd_compilation():
 def test_example_chunk_scaled_dot_kkt_compilation():
-    from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input, prepare_output
+    from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input
     K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
                                getattr(torch, output_dtype), getattr(torch, accum_dtype))
-    A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
     block_S = chunk_size
     kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
                                                accum_dtype, use_g, block_S, block_DK, threads,
@@ -164,15 +156,12 @@ def test_example_cumsum_compilation():
 def test_example_chunk_delta_h_compilation():
-    from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input, prepare_output
+    from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input
     K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
                                               getattr(torch, input_dtype),
                                               getattr(torch, output_dtype),
                                               getattr(torch, accum_dtype),
                                               getattr(torch, gate_dtype))
-    h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
-                                                                      getattr(torch, output_dtype),
-                                                                      getattr(torch, state_dtype))
     kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
                                                    accum_dtype, gate_dtype, state_dtype, chunk_size,
                                                    use_g, use_initial_state, store_final_state,
@@ -183,17 +172,13 @@ def test_example_chunk_delta_h_compilation():
 def test_example_chunk_delta_bwd_compilation():
-    from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input, prepare_output
+    from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input
     Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
                                                 getattr(torch, input_dtype),
                                                 getattr(torch, output_dtype),
                                                 getattr(torch, accum_dtype),
                                                 getattr(torch, gate_dtype),
                                                 getattr(torch, state_dtype))
-    dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
-                                                             getattr(torch, output_dtype),
-                                                             getattr(torch, gate_dtype),
-                                                             getattr(torch, state_dtype))
     kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype,
                                                      accum_dtype, gate_dtype, state_dtype,
                                                      chunk_size, 1.0, use_g, use_initial_state,
...
@@ -236,12 +236,11 @@ def matmul(M,
     return gemm_autotune


-def main(m: int = 4096,
-         n: int = 4096,
-         k: int = 4096,
+def main(M: int = 4096,
+         N: int = 4096,
+         K: int = 4096,
          use_autotune: bool = False,
          with_roller: bool = False):
-    M, N, K = m, n, k
     use_autotune = True
     if use_autotune:
         result = get_best_config(M, N, K, with_roller)
...
@@ -162,8 +162,7 @@ def ref_program(A, B):
     return A @ B.T


-def main():
-    M, N, K = 16384, 16384, 16384
+def main(M=4096, N=4096, K=4096):
     in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
     kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
     src_code = kernel.get_kernel_source()
@@ -183,4 +182,4 @@ def main():
 if __name__ == "__main__":
-    main()
+    main(M=4096, N=4096, K=4096)
@@ -118,13 +118,7 @@ def ref_program(A, B):
     return A @ B


-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--M', type=int, default=8192, help='M dimension')
-    parser.add_argument('--N', type=int, default=8192, help='N dimension')
-    parser.add_argument('--K', type=int, default=8192, help='K dimension')
-    args = parser.parse_args()
-    M, N, K = args.M, args.N, args.K
+def main(M=4096, N=4096, K=4096):
     total_flops = 2 * M * N * K

     BLOCK_M = 128
@@ -156,4 +150,10 @@ def main():
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--M', type=int, default=8192, help='M dimension')
+    parser.add_argument('--N', type=int, default=8192, help='N dimension')
+    parser.add_argument('--K', type=int, default=8192, help='K dimension')
+    args = parser.parse_args()
+    M, N, K = args.M, args.N, args.K
+    main(M, N, K)
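The same refactor recurs across the example scripts: `main` takes the problem sizes as parameters with defaults, and argument parsing moves under the `__main__` guard. CLI behavior is preserved, while the test suite can now call `main` directly with small shapes, as the test diff below does. A hedged illustration (the module name is taken from that test file):

```python
# Direct call from a test: small shapes, no CLI parsing involved.
import example_gemm_autotune

example_gemm_autotune.main(M=1024, N=1024, K=1024, with_roller=True)
```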
@@ -7,11 +7,11 @@ import example_gemm
 def test_example_gemm_autotune():
     # enable roller for fast tuning
-    example_gemm_autotune.main(with_roller=True)
+    example_gemm_autotune.main(M=1024, N=1024, K=1024, with_roller=True)


 def test_example_gemm_intrinsics():
-    example_gemm_intrinsics.main()
+    example_gemm_intrinsics.main(M=1024, N=1024, K=1024)


 def test_example_gemm_schedule():
...
@@ -51,10 +51,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
     return main


-def main():
-    M = 16384
-    N = 16384
-    K = 16384
+def main(M=16384, N=16384, K=16384):
     block_M = 128
     block_N = 128
     block_K = 64
...
@@ -48,10 +48,7 @@ def matmul_warp_specialize_copy_0_gemm_1(M,
     return main


-def main():
-    M = 16384
-    N = 16384
-    K = 16384
+def main(M=1024, N=1024, K=1024):
     block_M = 128
     block_N = 128
     block_K = 64
...
@@ -49,10 +49,7 @@ def matmul_warp_specialize_copy_1_gemm_0(M,
     return main


-def main():
-    M = 16384
-    N = 16384
-    K = 16384
+def main(M=16384, N=16384, K=16384):
     block_M = 128
     block_N = 128
     block_K = 64
...
@@ -7,10 +7,8 @@ tilelang.disable_cache()
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(
-    out_idx=[2],
-    pass_configs={
+    out_idx=[2], pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-        # tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     })
 def matmul_warp_specialize_copy_1_gemm_0(M,
                                          N,
...
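One note on the decorator shown above: with `out_idx=[2]`, the compiled kernel allocates its third buffer (the output) itself and returns it to the caller. A hedged usage sketch, assuming the builder's positional signature (M, N, K, block_M, block_N, block_K):

```python
import torch

M = N = K = 1024
# Builder signature assumed for illustration; see its definition above.
kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, 128, 128, 64)

a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = kernel(a, b)  # out_idx=[2]: the output tensor is allocated and returned
```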
@@ -43,10 +43,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
     return main


-def main():
-    M = 16384
-    N = 16384
-    K = 16384
+def main(M=16384, N=16384, K=16384):
     block_M = 128
     block_N = 128
     block_K = 64
...
@@ -16,25 +16,25 @@ def test_example_warp_specialize_flashmla():
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_eq(9, 0)
 def test_example_warp_specialize_gemm_barrierpipe_stage2():
-    example_warp_specialize_gemm_barrierpipe_stage2.main()
+    example_warp_specialize_gemm_barrierpipe_stage2.main(M=1024, N=1024, K=1024)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_eq(9, 0)
 def test_example_warp_specialize_gemm_copy_0_gemm_1():
-    example_warp_specialize_gemm_copy_0_gemm_1.main()
+    example_warp_specialize_gemm_copy_0_gemm_1.main(M=1024, N=1024, K=1024)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_eq(9, 0)
 def test_example_warp_specialize_gemm_copy_1_gemm_0():
-    example_warp_specialize_gemm_copy_1_gemm_0.main()
+    example_warp_specialize_gemm_copy_1_gemm_0.main(M=1024, N=1024, K=1024)


 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_eq(9, 0)
 def test_example_warp_specialize_gemm_softpipe_stage2():
-    example_warp_specialize_gemm_softpipe_stage2.main()
+    example_warp_specialize_gemm_softpipe_stage2.main(M=1024, N=1024, K=1024)


 if __name__ == "__main__":
...
@@ -326,6 +326,13 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
   data_ = std::move(n);
 }

+// Completely replicated means the forward_thread is just the replication
+// placeholder, i.e. lambda i, rep: rep.
+bool FragmentNode::IsCompletedReplicated() const {
+  arith::Analyzer analyzer;
+  return ExprDeepEqual()(analyzer.Simplify(forward_thread_),
+                         ReplicationPlaceholder());
+}
+
 PrimExpr FragmentNode::ThreadExtent() const {
   Array<PrimExpr> ret(OutputDim(), 1);
   arith::Analyzer analyzer;
...
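The check above compares the simplified `forward_thread_` against the replication placeholder. A toy restatement in Python using TVM's arithmetic analyzer (hedged: `ReplicationPlaceholder` is internal to the C++ side, so a plain `rep` variable stands in for it here):

```python
import tvm
from tvm import tir

rep = tir.Var("rep", "int32")  # stand-in for the replication placeholder
i = tir.Var("i", "int32")      # a spatial (data) index

analyzer = tvm.arith.Analyzer()

# A forward_thread that simplifies to `rep` alone: fully replicated.
replicated = analyzer.simplify(rep + i - i)
# A forward_thread that still depends on the data index: partitioned.
partitioned = analyzer.simplify(i % 32 + rep * 32)

assert tvm.ir.structural_equal(replicated, rep)
assert not tvm.ir.structural_equal(partitioned, rep)
```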
@@ -101,6 +101,8 @@ public:
   bool IsEqual(const FragmentNode *other, bool skip_index = false) const;

+  bool IsCompletedReplicated() const;
+
   static void RegisterReflection();

   bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
...
@@ -213,11 +213,107 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
                                       InferLevel level) const {
   if (loop_layout_.defined())
     return {};
-  if (level == InferLevel::kStrict)
-    return {};
+  if (level == InferLevel::kStrict) {
+    LayoutMap results;
+    // Deduce buffers that should be completely replicated.
+    // For example:
+    //   for i in T.Parallel(m):
+    //     fragment[0] = x[i]
+    // then fragment[0] must be replicated on all threads.
+    for (const auto &[buffer, indices] : indice_map_) {
+      if (T.layout_map.count(buffer)) {
+        continue;
+      }
+      if (buffer.scope() != "local.fragment")
+        continue;
+      // Check if all indices are zero
+      bool all_indices_zero = true;
+      for (const auto &index : indices) {
+        if (const auto *imm = index.as<IntImmNode>()) {
+          if (imm->value != 0) {
+            all_indices_zero = false;
+            LOG(FATAL)
+                << "Fragment buffer access with non-zero index [" << imm->value
+                << "] is not supported. "
+                << "Only fragment[0] access is allowed within T.Parallel loop.";
+          }
+        } else {
+          // Non-constant index, not all zero
+          all_indices_zero = false;
+        }
+      }
+      // Only set layout if all indices are zero
+      if (all_indices_zero) {
+        Array<IterVar> forward_vars;
+        for (const auto &s : buffer->shape) {
+          forward_vars.push_back(
+              IterVar(Range(0, s), Var(), IterVarType::kDataPar));
+        }
+        Array<PrimExpr> forward_index;
+        for (const auto &iv : forward_vars) {
+          forward_index.push_back(iv->var);
+        }
+        Var rep;
+        auto rep_iter =
+            IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar);
+        const PrimExpr &forward_thread = rep;
+        results.Set(buffer, Fragment(forward_vars, forward_index,
+                                     forward_thread, rep_iter));
+      }
+    }
+    return results;
+  }
+
+  auto buffer_is_completed_replicated = [&](const Buffer &buffer) {
+    if (buffer.scope() != "local.fragment")
+      return false;
+    auto frag = T.layout_map[buffer].as<Fragment>().value();
+    // buffer indices should be IntImm
+    for (const auto &index : indice_map_[buffer]) {
+      if (!index.as<IntImmNode>()) {
+        return false;
+      } else if (index.as<IntImmNode>()->value != 0) {
+        LOG(FATAL) << "buffer " << buffer << " is not completely replicated";
+      }
+    }
+    return frag->IsCompletedReplicated();
+  };
+
+  // Collect fragment buffers with const indices, and all fragment buffers
+  std::vector<Buffer> const_index_fragment_buffer, fragment_buffers;
+  for (const auto &[buffer, indices] : indice_map_) {
+    if (buffer.scope() != "local.fragment")
+      continue;
+    fragment_buffers.push_back(buffer);
+    bool is_const_index = true;
+    for (const auto &index : indices) {
+      if (!index.as<IntImmNode>()) {
+        is_const_index = false;
+        break;
+      }
+    }
+    if (is_const_index) {
+      const_index_fragment_buffer.push_back(buffer);
+    }
+  }
+
+  // Determine if common layout propagation should be applied.
+  // If there are fragment buffers with non-constant indices, we need to
+  // propagate the common layout pattern to ensure consistency across all
+  // fragments. Example cases:
+  // - Need propagation: frag_a[0] = T.min(frag_a[0], frag_b[i])
+  //   (const-index frag_a interacts with non-const-index frag_b)
+  // - No propagation needed: shared_a[i] = frag_a[0]
+  //   (const-index frag_a with a non-fragment buffer)
+  bool allow_layout_propagate =
+      fragment_buffers.size() > const_index_fragment_buffer.size();
+
   // Step 1: try to infer loop's partition from a source fragment
   Buffer source_buffer, read_source_buffer;
+  Buffer replicated_write_buffer; // Backup: fully replicated write buffer
   for (const auto &[buffer, indices] : indice_map_) {
     if (T.layout_map.count(buffer)) {
       // skip reducers with rep=ALL
@@ -226,15 +322,19 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       continue;
     auto frag = T.layout_map[buffer].as<Fragment>().value();
+    bool is_fully_replicated = buffer_is_completed_replicated(buffer);
     if (buffer_is_write_.count(buffer)) {
       source_buffer = buffer;
     } else {
       // Keep the buffer with largest number of indices
       // (which means the inference based on that buffer is more accurate)
       // as read_source_buffer to get more accurate layout
-      if (!read_source_buffer.defined() ||
+      // If the buffer is completely replicated, we don't need to infer the
+      // layout from this buffer.
+      if ((!read_source_buffer.defined() ||
           indice_map_[buffer].size() >
-              indice_map_[read_source_buffer].size()) {
+              indice_map_[read_source_buffer].size())) {
         read_source_buffer = buffer;
       }
       // If the buffer is not replicated and shape is equal to the
@@ -250,6 +350,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
     DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
                << buffer << "` of layout " << src_layout->DebugOutput() << '\n';
+
     Fragment result;
     if (IsCommonAccessIndice(buffer)) {
       result = src_layout;
@@ -260,15 +361,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       PrimExpr loop_var_to_thread =
           src_layout->ForwardThread(indice_map_[buffer], rep);
       loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
-      PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
-        if (auto opt_var = objref.as<Var>();
-            opt_var && inner_vars_.count(*opt_var)) {
-          std::ostringstream oss;
-          oss << "loop_var_to_thread = " << loop_var_to_thread
-              << "contains inner var" << *opt_var;
-          throw LayoutConflictException(oss.str());
-        }
-      });
       result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                    ->BindThreadRange(T.thread_bounds);
     }
@@ -276,10 +369,17 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
                << result->DebugOutput() << '\n';
     return result;
   };
-  if (source_buffer.defined()) {
+
+  // Try to infer loop layout from buffers in order of preference:
+  // 1. Non-replicated write buffer (most reliable)
+  // 2. Non-replicated read buffer
+  // 3. Fully replicated write buffer (backup, may cause issues)
+  // 4. Free inference mode (no source buffer)
+  if (source_buffer.defined() && allow_layout_propagate) {
     loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
   } else if (level == InferLevel::kFree) {
-    if (read_source_buffer.defined()) {
+    if (read_source_buffer.defined() && allow_layout_propagate) {
       loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
       // // Loop don't need to be replicated.
       // if (!is_one(loop_layout_->ReplicateExtent()))
@@ -330,7 +430,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       auto rep = inv->Forward(fwd).back();
       AddPredicate(EQ(rep, 0));
     }
-  } else {
+  }
+
+  if (!loop_layout_.defined()) {
+    // No source buffer available, use free mode inference
     // Vectorize Size must be aware of the buffer_remap
     // As the pass will do post processing to the layout
     auto maybe_remapped_root_ =
...
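The two comment examples above can be seen in one kernel. A hedged TileLang sketch (kernel scaffolding assumed, not from the commit): in the first loop a constant-index fragment interacts with a non-constant-index fragment, so common layout propagation is required; in the second loop the constant-index fragment only meets a shared buffer, so no propagation is needed and `frag_a` is simply annotated as fully replicated.

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def min_then_broadcast(M=256, dtype="float32"):

    @T.prim_func
    def main(x: T.Tensor((M,), dtype), y: T.Tensor((M,), dtype)):
        with T.Kernel(1, threads=128):
            frag_a = T.alloc_fragment((1,), dtype)
            frag_b = T.alloc_fragment((M,), dtype)
            shared_a = T.alloc_shared((M,), dtype)
            T.copy(x, frag_b)
            T.fill(frag_a, T.infinity(dtype))
            for i in T.Parallel(M):
                # const-index frag_a meets non-const-index frag_b:
                # the common layout must propagate across fragments
                frag_a[0] = T.min(frag_a[0], frag_b[i])
            for i in T.Parallel(M):
                # const-index fragment with a non-fragment buffer:
                # no propagation; frag_a stays fully replicated
                shared_a[i] = frag_a[0]
            T.copy(shared_a, y)

    return main
```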
@@ -134,7 +134,8 @@ def AddWrapperForSingleBufStore():
     # Validate fragment buffer indices - only index 0 is supported
     buffer_indices = collect_buffer_indices(statement)
     for buffer, indices in buffer_indices.items():
-        if buffer.scope() == "local.fragment":
-            for index in indices:
-                if isinstance(index, IntImm) and index != 0:
-                    raise ValueError(
+        if buffer.scope() != "local.fragment":
+            continue
+        for index in indices:
+            if isinstance(index, IntImm) and index != 0:
+                raise ValueError(
...
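For completeness, a hedged sketch of what this validation rejects (scaffolding assumed, not from the commit): any constant fragment index other than 0 inside the parallelized store is expected to raise the ValueError above, mirroring the LOG(FATAL) on the C++ side.

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def bad_kernel(M=128, dtype="float32"):

    @T.prim_func
    def main(x: T.Tensor((M,), dtype)):
        with T.Kernel(1, threads=128):
            frag = T.alloc_fragment((2,), dtype)
            for i in T.Parallel(M):
                frag[1] = x[i]  # constant index != 0: rejected at compile time

    return main


# Compiling, e.g. bad_kernel(), is expected to raise the ValueError shown above.
```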