Commit 3d206235 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix the test data distribution of cumsum (#432)

* [Refactor] Adjust layout inference calculations in Gemm and ParallelOp

* Updated block size calculation in Gemm to account for the range of thread bounds, improving accuracy in layout inference.
* Simplified layout conflict error messages in ParallelOp for better clarity, enhancing debugging experience.
* Removed redundant buffer checks in ParallelOp layout inference logic, streamlining the code.

* [Refactor] Clean up layout inference logic in Gemm and ParallelOp

* Removed unnecessary warning log in Gemm related to WGMMA conditions, streamlining the layout inference process.
* Commented out redundant checks in ParallelOp's layout inference, improving code clarity while maintaining functionality.
* Enhanced error messages in ParallelOp to provide clearer context for layout conflicts, aiding in debugging efforts.

* lint fix

* [Enhancement] Improve cumulative sum functionality and annotations handling

* Updated the `cumsum` function to include detailed documentation and error handling for dimension bounds.
* Modified the `run_cumsum` test to utilize a random tensor supply type for profiling, enhancing test robustness.
* Added annotations to the fused loop in `loop_fusion_utils.h`, ensuring proper metadata is preserved during loop fusion.

* lint fix
parent bb1a5fd8
@@ -217,7 +217,7 @@ private:
     // Create the fused loop
     For fused_for = For(fused_var, 0, fused_extent, ForKind::kParallel, body);
+    fused_for.CopyOnWrite()->annotations = op->annotations;
     return fused_for;
   }
 };
...
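The one-line change above matters because `For(...)` constructs a brand-new node, so any annotations (pragmas, scheduling hints) attached to the original loop would otherwise be dropped during fusion. The following is a minimal, library-free Python sketch of that idea, not the TVM/TIR API; the `Loop` class and `fuse` helper are hypothetical stand-ins.

# Hypothetical sketch: fusing two loops while carrying over the outer loop's metadata.
from dataclasses import dataclass, field


@dataclass
class Loop:
    var: str
    extent: int
    body: object
    annotations: dict = field(default_factory=dict)


def fuse(outer: Loop, inner: Loop) -> Loop:
    # The fused loop iterates over outer.extent * inner.extent; the original
    # indices are recovered as (k // inner.extent, k % inner.extent).
    fused = Loop("k", outer.extent * inner.extent, inner.body)
    # Mirrors `fused_for.CopyOnWrite()->annotations = op->annotations;`
    fused.annotations = dict(outer.annotations)
    return fused


outer = Loop("i", 4, body=None, annotations={"pragma_unroll_explicit": 1})
inner = Loop("j", 8, body=None)
assert fuse(outer, inner).annotations == {"pragma_unroll_explicit": 1}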
@@ -50,7 +50,7 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16", sc
     elif scope == "fragment":
         program = cumsum_fragment_test(M, N, block_M, block_N, dim, reverse, dtype)
     jit_kernel = tl.compile(program, out_idx=-1)
-    profiler = jit_kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.One)
+    profiler = jit_kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Randn)

     def ref_program(A):
         ref_b = torch.empty_like(A)
@@ -61,8 +61,8 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16", sc
                           block_N:(j + 1) * block_N].cumsum(dim=dim)
                 if reverse:
                     ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) *
-                          block_N] = ref_b[i * block_M:(i + 1) * block_M,
-                                           j * block_N:(j + 1) * block_N].flip(dims=[dim])
+                          block_N] = A[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) *
+                                       block_N].flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim])
         return ref_b

     profiler.assert_allclose(ref_program)
...
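Why both changes are needed together: the old reference flipped an already-computed forward cumsum, which happens to equal a true reverse (suffix) cumsum when every input element is identical, so the all-ones supply could not expose the error. The snippet below is an illustrative PyTorch sketch of that coincidence (not code from the repository); the helper names are made up for the example.

# Sketch: flip(cumsum(x)) equals the true reverse cumsum only for constant input.
import torch

def reverse_cumsum(x, dim=0):
    # True reverse (suffix) cumulative sum: flip, accumulate, flip back.
    return x.flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim])

def old_reference(x, dim=0):
    # The previous reference flipped an already-computed forward cumsum.
    return x.cumsum(dim=dim).flip(dims=[dim])

ones = torch.ones(4)
assert torch.equal(reverse_cumsum(ones), old_reference(ones))  # coincidentally equal

randn = torch.randn(4)
# With random data the two generically disagree, so the test now catches the bug.
assert not torch.allclose(reverse_cumsum(randn), old_reference(randn))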
@@ -124,6 +124,24 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
 def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
+    """Perform cumulative sum on input buffer, store the result to output buffer.
+
+    Args:
+        src (tir.Buffer): The input buffer
+        dst (tir.Buffer, optional): The output buffer. Defaults to None.
+        dim (int, optional): The dimension to perform cumulative sum on. Defaults to 0.
+        reverse (bool, optional): Whether to perform reverse cumulative sum. Defaults to False.
+
+    Returns:
+        tir.Call: Handle to the cumulative sum operation
+    """
+    shape = src.shape
+    if dim >= len(shape) or dim <= -len(shape):
+        raise ValueError(f"Dimension {dim} is out of bounds for buffer with shape {shape}")
+    if dim < 0:
+        dim = len(shape) + dim
     if dst is None:
         dst = src
     if src.scope() == "local.fragment":
...
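For reference, the added bounds check and negative-dim normalization behave as in the standalone sketch below. The `normalize_dim` helper is hypothetical and only mirrors the lines added above; it is not part of the library.

# Hypothetical helper mirroring the new dim handling in cumsum().
def normalize_dim(dim: int, shape) -> int:
    if dim >= len(shape) or dim <= -len(shape):
        raise ValueError(f"Dimension {dim} is out of bounds for buffer with shape {shape}")
    return dim if dim >= 0 else len(shape) + dim

assert normalize_dim(1, (128, 256)) == 1
assert normalize_dim(-1, (128, 256)) == 1  # negative dims count from the last axis
# Out-of-range values such as dim=2 raise ValueError (dim=-2 is also rejected by this check).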