Unverified Commit fe6cdc9d authored by Zhengju Tang's avatar Zhengju Tang Committed by GitHub
Browse files

[BugFix] Do not modify strict layout in common or relax level of layout...


[BugFix] Do not modify strict layout in common or relax level of layout inference. More conditions on layout checking (#653)

* [BugFix] Do not modify strict layout in common or relax level of layout inference. More conditions on layout checking

* Lint

* test fix

* Update CI workflow to install dependencies without user site packages

- Modified the installation commands in the CI workflow to include the `--no-user` flag for both `requirements-dev.txt` and `requirements-test.txt`, ensuring that packages are installed in the virtual environment rather than the user site directory.

* Update CI workflow to install pip without user site packages

- Added the `--no-user` flag to the pip installation command in the CI workflow for both development and testing dependencies, ensuring that packages are installed within the virtual environment.

* Update requirements-test.txt

* Reduce CI problem size

* Refactor example_mla_decode.py for consistent formatting and remove unused imports in test_example_mla_decode.py

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 8361eb5c
...@@ -276,16 +276,14 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): ...@@ -276,16 +276,14 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
return out return out
def main(): def main(
parser = argparse.ArgumentParser() batch=1,
parser.add_argument('--batch', type=int, default=132, help='batch size') heads=128,
parser.add_argument('--heads', type=int, default=128, help='q heads number') kv_heads=1,
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number') kv_ctx=8192,
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length') dim=512,
parser.add_argument('--dim', type=int, default=512, help='head dim') pe_dim=64,
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim') ):
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops total_flops = qk_flops + pv_flops
...@@ -302,4 +300,13 @@ def main(): ...@@ -302,4 +300,13 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=132, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
import tilelang.testing import tilelang.testing
import example_mla_decode import example_mla_decode
from unittest import mock
import sys
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mla_decode(): def test_example_mla_decode():
with mock.patch.object(sys, 'argv', ["example_mla_decode.py"]): example_mla_decode.main()
example_mla_decode.main()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -302,9 +302,9 @@ def ref_program(Q, K, V, is_causal, groups=1): ...@@ -302,9 +302,9 @@ def ref_program(Q, K, V, is_causal, groups=1):
return output return output
def main(BATCH: int = 8, def main(BATCH: int = 1,
H: int = 32, H: int = 32,
N_CTX: int = 1024, N_CTX: int = 256,
D_HEAD_QK: int = 192, D_HEAD_QK: int = 192,
D_HEAD_V: int = 128, D_HEAD_V: int = 128,
groups: int = 16, groups: int = 16,
......
...@@ -170,10 +170,10 @@ def ref_program(Q, K, V, is_causal): ...@@ -170,10 +170,10 @@ def ref_program(Q, K, V, is_causal):
def main( def main(
batch: int = 8, batch: int = 1,
heads: int = 32, heads: int = 32,
seq_q: int = 4096, seq_q: int = 256,
seq_kv: int = 4096, seq_kv: int = 256,
dim: int = 128, dim: int = 128,
is_causal: bool = False, is_causal: bool = False,
tune: bool = False, tune: bool = False,
......
...@@ -29,4 +29,4 @@ attrs ...@@ -29,4 +29,4 @@ attrs
decorator decorator
flash-attn<=2.2.0 flash-attn<=2.2.0
scipy scipy
tornado tornado
\ No newline at end of file
...@@ -294,8 +294,12 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -294,8 +294,12 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
T.thread_bounds)); T.thread_bounds));
} }
// Layout infer conflict for local.fragment can noy be handled here // Layout infer conflict for local.fragment can not be handled here
// because the source_buffer is not always available // because the source_buffer is not always available
// (zhengju) Do not modify a strict layout even if it conflicts with the
// dst layout. This will not influence the result because a strict
// layout usually has rep = 1. Since the real layout map is
// controlled by layout_inference.cc, we should add this check there.
if (buffer.scope() == "local.fragment" && source_buffer.defined() && if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
source_buffer.scope() == "local.fragment") { source_buffer.scope() == "local.fragment") {
if (T.layout_map.count(buffer)) { if (T.layout_map.count(buffer)) {
......
...@@ -153,10 +153,17 @@ public: ...@@ -153,10 +153,17 @@ public:
} }
} }
// If already in map, ensure they are structurally equal // If already in map, ensure they are structurally equal
ICHECK(StructuralEqual()(layout, layout_map[buffer])) // (zhengju) We cannot modify the strict layout map when the current
<< "Get different layout for " << buffer // level is not strict. This check should only be done under certain
<< "\n current layout: " << layout->DebugOutput() // conditions, since the strict layout map is not updated by the
<< "\n previous layout: " << layout_map[buffer]->DebugOutput(); // code above when the current level is not strict.
if (level == InferLevel::kStrict ||
!strict_layout_map.count(buffer)) {
ICHECK(StructuralEqual()(layout, layout_map[buffer]))
<< "Get different layout for " << buffer
<< "\n current layout: " << layout->DebugOutput()
<< "\n previous layout: " << layout_map[buffer]->DebugOutput();
}
} else { } else {
// Otherwise, update map // Otherwise, update map
layout_map.Set(buffer, layout); layout_map.Set(buffer, layout);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment