"testing/python/jit/test_tilelang_jit_callback.py" did not exist on "0d8421f1b8b22d8d5fe47525bd0c7cfda9cbde4a"
Unverified Commit fe6cdc9d authored by Zhengju Tang's avatar Zhengju Tang Committed by GitHub
Browse files

[BugFix] Do not modify strict layout in common or relax level of layout...


[BugFix] Do not modify strict layout in common or relax level of layout inference. More conditions on layout checking (#653)

* [BugFix] Do not modify strict layout in common or relax level of layout inference. More conditions on layout checking

* Lint

* test fix

* Update CI workflow to install dependencies without user site packages

- Modified the installation commands in the CI workflow to include the `--no-user` flag for both `requirements-dev.txt` and `requirements-test.txt`, ensuring that packages are installed in the virtual environment rather than the user site directory.

* Update CI workflow to install pip without user site packages

- Added the `--no-user` flag to the pip installation command in the CI workflow for both development and testing dependencies, ensuring that packages are installed within the virtual environment.

* Update requirements-test.txt

* Reduce CI problem size

* Refactor example_mla_decode.py for consistent formatting and remove unused imports in test_example_mla_decode.py

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 8361eb5c
......@@ -276,16 +276,14 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
return out
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=132, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
def main(
batch=1,
heads=128,
kv_heads=1,
kv_ctx=8192,
dim=512,
pe_dim=64,
):
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
......@@ -302,4 +300,13 @@ def main():
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=132, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
import tilelang.testing
import example_mla_decode
from unittest import mock
import sys
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mla_decode():
with mock.patch.object(sys, 'argv', ["example_mla_decode.py"]):
example_mla_decode.main()
example_mla_decode.main()
if __name__ == "__main__":
......
......@@ -302,9 +302,9 @@ def ref_program(Q, K, V, is_causal, groups=1):
return output
def main(BATCH: int = 8,
def main(BATCH: int = 1,
H: int = 32,
N_CTX: int = 1024,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
......
......@@ -170,10 +170,10 @@ def ref_program(Q, K, V, is_causal):
def main(
batch: int = 8,
batch: int = 1,
heads: int = 32,
seq_q: int = 4096,
seq_kv: int = 4096,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
is_causal: bool = False,
tune: bool = False,
......
......@@ -29,4 +29,4 @@ attrs
decorator
flash-attn<=2.2.0
scipy
tornado
\ No newline at end of file
tornado
......@@ -294,8 +294,12 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
T.thread_bounds));
}
// Layout infer conflict for local.fragment can noy be handled here
// Layout infer conflict for local.fragment can not be handled here
// because the source_buffer is not always available
// (zhengju) do not modify a strict layout even if it conflicts with the
// dst layout. This will not influence the result, because a strict
// layout usually has rep = 1. Since the real layout map is controlled
// by layout_inference.cc, we should add this check there.
if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
source_buffer.scope() == "local.fragment") {
if (T.layout_map.count(buffer)) {
......
......@@ -153,10 +153,17 @@ public:
}
}
// If already in map, ensure they are structurally equal
ICHECK(StructuralEqual()(layout, layout_map[buffer]))
<< "Get different layout for " << buffer
<< "\n current layout: " << layout->DebugOutput()
<< "\n previous layout: " << layout_map[buffer]->DebugOutput();
// (zhengju) We cannot modify the strict layout map when the current
// level is not strict. This check should only be done under certain
// conditions, since the strict layout map is not updated by the
// code above when the current level is not strict.
if (level == InferLevel::kStrict ||
!strict_layout_map.count(buffer)) {
ICHECK(StructuralEqual()(layout, layout_map[buffer]))
<< "Get different layout for " << buffer
<< "\n current layout: " << layout->DebugOutput()
<< "\n previous layout: " << layout_map[buffer]->DebugOutput();
}
} else {
// Otherwise, update map
layout_map.Set(buffer, layout);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment