Commit 3d206235 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix the test data distribution of cumsum (#432)

* [Refactor] Adjust layout inference calculations in Gemm and ParallelOp

* Updated block size calculation in Gemm to account for the range of thread bounds, improving accuracy in layout inference.
* Simplified layout conflict error messages in ParallelOp for better clarity, improving the debugging experience.
* Removed redundant buffer checks in ParallelOp layout inference logic, streamlining the code.

* [Refactor] Clean up layout inference logic in Gemm and ParallelOp

* Removed unnecessary warning log in Gemm related to WGMMA conditions, streamlining the layout inference process.
* Commented out redundant checks in ParallelOp's layout inference, improving code clarity while maintaining functionality.
* Enhanced error messages in ParallelOp to provide clearer context for layout conflicts, aiding in debugging efforts.

* lint fix

* [Enhancement] Improve cumulative sum functionality and annotations handling

* Updated the `cumsum` function to include detailed documentation and error handling for dimension bounds.
* Modified the `run_cumsum` test to utilize a random tensor supply type for profiling, enhancing test robustness.
* Added annotations to the fused loop in `loop_fusion_utils.h`, ensuring proper metadata is preserved during loop fusion.

* lint fix
parent bb1a5fd8
@@ -217,7 +217,7 @@ private:
// Create the fused loop
For fused_for = For(fused_var, 0, fused_extent, ForKind::kParallel, body);
fused_for.CopyOnWrite()->annotations = op->annotations;
return fused_for;
}
};
...
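For context on the change above, here is a rough standalone illustration using TVM's Python TIR API (this is not TileLang's internal pass; the loop bounds and the pragma key below are made up): a freshly constructed `For` node carries no annotations, so a fusion pass that rebuilds the loop must copy the original annotations onto the fused node explicitly, which is what the `CopyOnWrite()` assignment does on the C++ side.

```python
import tvm
from tvm import tir

i = tir.Var("i", "int32")
body = tir.Evaluate(0)

# A loop that carries an annotation (the pragma key is just an example).
annotated = tir.For(i, 0, 16, tir.ForKind.PARALLEL, body,
                    annotations={"pragma_unroll_explicit": 1})

# Rebuilding the loop without forwarding the annotations silently loses them ...
fused_plain = tir.For(i, 0, 16, tir.ForKind.PARALLEL, body)
print(fused_plain.annotations)   # empty map

# ... whereas forwarding them preserves the metadata, analogous to the fix above.
fused_kept = tir.For(i, 0, 16, tir.ForKind.PARALLEL, body,
                     annotations=annotated.annotations)
print(fused_kept.annotations)    # still carries the pragma
```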
@@ -50,7 +50,7 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16", sc
elif scope == "fragment":
program = cumsum_fragment_test(M, N, block_M, block_N, dim, reverse, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
-profiler = jit_kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.One)
+profiler = jit_kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Randn)
def ref_program(A):
ref_b = torch.empty_like(A)
@@ -61,8 +61,8 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16", sc
block_N:(j + 1) * block_N].cumsum(dim=dim)
if reverse:
ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) *
-block_N] = ref_b[i * block_M:(i + 1) * block_M,
-j * block_N:(j + 1) * block_N].flip(dims=[dim])
+block_N] = A[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) *
+block_N].flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim])
return ref_b
profiler.assert_allclose(ref_program)
...
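The coincidence that hid this bug is easy to see in isolation. Below is a small standalone PyTorch sketch (tensor sizes are arbitrary) showing why the previous all-ones supply (`TensorSupplyType.One`) could never catch the wrong reverse reference: for constant input, flipping the forward cumsum happens to equal the true reverse cumsum, so the buggy reference and the kernel agreed by accident; with `Randn` data the two diverge.

```python
import torch

def reverse_cumsum(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # True reverse (suffix) cumulative sum: flip, accumulate, flip back.
    return x.flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim])

def old_reference(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # The previous reference: flip applied to the *forward* cumsum (incorrect in general).
    return x.cumsum(dim=dim).flip(dims=[dim])

ones = torch.ones(8)
randn = torch.randn(8)

print(torch.allclose(old_reference(ones), reverse_cumsum(ones)))    # True  -> bug masked by all-ones data
print(torch.allclose(old_reference(randn), reverse_cumsum(randn)))  # False (almost surely) -> exposed by random data
```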
@@ -124,6 +124,24 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
"""Perform cumulative sum on input buffer, store the result to output buffer.
Args:
src (tir.Buffer): The input buffer
dst (tir.Buffer, optional): The output buffer. Defaults to None.
dim (int, optional): The dimension to perform cumulative sum on. Defaults to 0.
reverse (bool, optional): Whether to perform reverse cumulative sum. Defaults to False.
Returns:
tir.Call: Handle to the cumulative sum operation
"""
shape = src.shape
if dim >= len(shape) or dim <= -len(shape):
raise ValueError(f"Dimension {dim} is out of bounds for buffer with shape {shape}")
if dim < 0:
dim = len(shape) + dim
if dst is None:
dst = src
if src.scope() == "local.fragment":
...
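As a reading aid, here is a plain-PyTorch sketch of the semantics the docstring above describes (not the TileLang intrinsic; the function name and the use of torch are illustrative only): negative `dim` values are wrapped, out-of-range values raise, and when `dst` is omitted the result is written back into `src`.

```python
from typing import Optional
import torch

def cumsum_reference(src: torch.Tensor,
                     dst: Optional[torch.Tensor] = None,
                     dim: int = 0,
                     reverse: bool = False) -> torch.Tensor:
    """Illustrative mirror of the documented cumsum semantics."""
    ndim = src.dim()
    # Same bounds handling as above: wrap negative dims, reject out-of-range ones.
    if not -ndim < dim < ndim:
        raise ValueError(f"Dimension {dim} is out of bounds for buffer with shape {tuple(src.shape)}")
    if dim < 0:
        dim += ndim
    if dst is None:
        dst = src  # dst defaults to src, i.e. the result overwrites the input
    out = src.flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim]) if reverse else src.cumsum(dim=dim)
    dst.copy_(out)
    return dst
```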