Commit 35cf8885 (unverified), authored by LJC00118, committed by GitHub
Browse files

[Enhancement] Remove constraint requiring last dimension stride to be 1 (#1040)



* remove last dimension stride must be 1 constraint

* add vectorize test

* minor fix

* [Lint]: [pre-commit.ci] auto fixes [...]

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent fd1493be
import torch
import tilelang.testing
import tilelang.language as T
# NOTE(review): judging by its name, TL_DISABLE_VECTORIZE_256 presumably stops
# the compiler from emitting 256-bit vector accesses, which would make float4
# the widest float32 vector the test can observe -- confirm against the
# PassConfigKey documentation.
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test(N, M, stride_A, stride_B):
    """Build a kernel that copies ``A`` into ``B`` element by element.

    Both tensors are declared with shape ``(N, M)`` and strides
    ``(1, stride_A)`` / ``(1, stride_B)``: the stride-1 dimension is the
    first one, not the last one.

    Args:
        N: Extent of the first dimension; must be a multiple of 128.
        M: Extent of the second dimension; must be a multiple of 128.
        stride_A: Stride of the second dimension of ``A``.
        stride_B: Stride of the second dimension of ``B``.

    Returns:
        The ``T.prim_func`` describing the copy kernel.

    Raises:
        AssertionError: If ``N`` or ``M`` is not a multiple of 128.
    """
    assert N % 128 == 0
    assert M % 128 == 0

    @T.prim_func
    def main(
            A: T.StridedTensor[(N, M), (1, stride_A), "float32"],  # noqa: F821
            B: T.StridedTensor[(N, M), (1, stride_B), "float32"],  # noqa: F821
    ):
        # One block per group of 128 columns; each thread owns a single column.
        with T.Kernel(M // 128, threads=128) as (block_idx):
            thread_idx = T.get_thread_binding(0)
            column = block_idx * 128 + thread_idx
            # The vectorized loop walks the stride-1 (first) dimension.
            for r in T.vectorized(N):
                B[r, column] = A[r, column]

    return main
def run_vectorize(N, M, stride_A, stride_B):
    """Run the strided copy kernel, then check its output and generated source.

    Two CUDA float32 buffers of shape ``(stride_A, M)`` and ``(stride_B, M)``
    are allocated and viewed as ``(N, M)`` tensors with strides
    ``(1, stride_A)`` and ``(1, stride_B)``. After the kernel runs, the two
    views must compare equal, and the kernel source must mention the vector
    type matching the largest power of two (at most 4) that divides both
    strides.

    Args:
        N: Extent of the first (stride-1) dimension.
        M: Extent of the second dimension.
        stride_A: Second-dimension stride of the input view; must be >= ``N``.
        stride_B: Second-dimension stride of the output view; must be >= ``N``.

    Raises:
        AssertionError: If a stride is smaller than ``N``, if the copied data
            differs from the input, or if the expected vector type is absent
            from the kernel source.
    """
    assert min(stride_A, stride_B) >= N

    jit_kernel = vectorize_test(N, M, stride_A, stride_B)

    # Backing storage is large enough for the highest offset the views reach.
    storage_a = torch.randn(stride_A, M, device="cuda", dtype=torch.float32)
    storage_b = torch.zeros(stride_B, M, device="cuda", dtype=torch.float32)

    view_shape = (N, M)
    a = torch.as_strided(storage_a, size=view_shape, stride=(1, stride_A))
    b = torch.as_strided(storage_b, size=view_shape, stride=(1, stride_B))

    jit_kernel(a, b)
    torch.testing.assert_close(a, b, atol=1e-8, rtol=1e-8)

    code = jit_kernel.get_kernel_source()

    # Largest width in {1, 2, 4} dividing both strides; stop at the first
    # width that fails for either stride.
    expected_width = 1
    for candidate in (2, 4):
        if stride_A % candidate != 0 or stride_B % candidate != 0:
            break
        expected_width = candidate

    # A width of 1 places no requirement on the generated source.
    vector_type = {4: "float4", 2: "float2"}.get(expected_width)
    if vector_type is not None:
        assert vector_type in code
def test_vectorize():
    """Exercise the strided copy kernel with several stride paddings.

    The strides are ``N`` plus a small padding. The padding pairs cover
    strides divisible by 4 for both tensors as well as a pair where one
    stride is only divisible by 2.
    """
    N, M = 512, 256
    # (padding added to stride_A, padding added to stride_B)
    stride_paddings = (
        (0, 0),
        (2, 4),  # stride_A = N + 2 is divisible by 2 but not by 4
        (4, 8),
        (8, 16),
    )
    for pad_a, pad_b in stride_paddings:
        run_vectorize(N, M, N + pad_a, N + pad_b)
# When executed directly as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()
...@@ -178,9 +178,6 @@ class StridedTensorProxy(BaseTensorProxy): ...@@ -178,9 +178,6 @@ class StridedTensorProxy(BaseTensorProxy):
scope=None) -> tir.Buffer: scope=None) -> tir.Buffer:
if len(shape) != len(strides): if len(shape) != len(strides):
raise ValueError("Invalid shape/strides' dimensions") raise ValueError("Invalid shape/strides' dimensions")
if not bool(strides[-1] == 1):
# TODO(chenggang): shall we support non-contiguous even for the last dimension?
raise ValueError("The stride of the last dimension must be 1 (contiguous)")
return super().__call__(shape, dtype=dtype, strides=strides, scope=scope) return super().__call__(shape, dtype=dtype, strides=strides, scope=scope)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment