"examples/diffusion/python_stable_diffusion_21/gradio_app.py" did not exist on "e44cecbc67d53dd62ef575eebd81d61dc866b8b3"
Commit f2e99180 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Phaseout LLVM Dependency by Making it Optional (#247)

* remove llvm build

* [Refactor] Update kernel compilation and profiling in examples

- Replaced `tilelang.lower` with `tilelang.compile` in multiple example scripts to streamline kernel compilation.
- Updated profiling calls to utilize the new `get_profiler` method, enhancing performance measurement consistency.
- Adjusted assertions and benchmarking methods to align with the new profiling structure across various examples, ensuring correctness and clarity in performance evaluations.

* lint fix

* License Update

* [Refactor] Improve code formatting and documentation in CUDA header and HIP runtime files

- Adjusted formatting in `cuda.h` for better readability, including alignment of comments and struct fields.
- Cleaned up whitespace and improved comment clarity in `rt_mod_hip.cc` to enhance code maintainability.

* [Refactor] Enhance formatting and clarity in CUDA header and HIP runtime files

- Improved comment alignment and readability in `cuda.h`.
- Cleaned up whitespace and formatting in `rt_mod_hip.cc` to enhance maintainability.

* lint fix

* lint fix

* lint fix

* lint fix

* fix

* License update

* [Enhancement] Update JITKernel to use artifact for kernel source

- Assigned the generated artifact to `self.artifact` for better management.
- Updated kernel source references to use `artifact.kernel_source` for consistency in execution backend handling.

* lint fix

* Add @tilelang.testing.requires_llvm decorator to vectorization tests

* Enhance setup.py and env.py for library management

- Added functionality to remove original files after copying in CMakeBuild.
- Updated TVM_LIBRARY_PATH in env.py to include the PyPI build library path for better integration.

* Refactor TVM_LIBRARY_PATH assignment for improved readability in env.py

* Refactor CMakeBuild file handling in setup.py

- Added a check to ensure the target library directory exists before copying .so files.
- Improved the logic for creating the target directory and copying files to enhance robustness.

* bugfix

* Rename BuildTLDebug to BuildTileLangCUDAWithoutCompile and update registration. Add @tilelang.testing.requires_llvm decorator to multiple tests for LLVM requirement.

* lint fix

* Enhance TileLang code generation by adding support for device code generation without compilation. Updated `host_codegen` and `device_codegen` functions to include new transformations and registration for `tilelang_hip_without_compile`. Refactored JIT kernel adapters to accommodate host and device modules, improving overall integration and flexibility.

* lint fix

* Add support for C target in device code generation

- Updated `device_codegen_without_compile` to include handling for the C target by registering the `tilelang_cpp` function.

* [Enhancement] Implement auto-clear cache feature based on environment variable

* Added TILELANG_CLEAR_CACHE environment variable to control cache clearing.
* Updated CI workflow to set TILELANG_CLEAR_CACHE during testing.
* Modified cache initialization to clear cache if TILELANG_CLEAR_CACHE is set to true.

* [Refactor] Update kernel invocation and import paths in tests and cache

* Changed kernel invocation in `test_tilelang_kernel_dequantize_gemm.py` to return the result.
* Updated import statements in `test_tilelang_kernel_int4_gemm_mma.py` to use `bitblas` instead of `tilelang`.
* Refactored paths for artifact and parameters in `kernel_cache.py` for better maintainability.

* [Refactor] Clean up whitespace and improve code formatting in kernel_cache.py

* Removed unnecessary blank lines and adjusted spacing for better readability in the KernelCache class.
* Enhanced overall code formatting to align with project standards.

* [Enhancement] Add bfloat16 test case and improve kernel caching logic

* Introduced a new test case for bfloat16 matrix multiplication in `test_tilelang_kernel_gemm_mma_intrinsic.py`.
* Updated `KernelCache` to handle multiple kernel source files and improve error handling during saving and loading.
* Refactored `JITKernel` to support instantiation from a database, enhancing flexibility in kernel management.
* Adjusted `CtypesKernelAdapter` and `CythonKernelAdapter` to utilize the new kernel loading mechanism from the database.
* Improved code formatting and readability across several files.

* lint fix

* Update bfloat16 matrix multiplication test case to use larger dimensions for improved coverage
parent 43bd9d3e
...@@ -216,10 +216,15 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): ...@@ -216,10 +216,15 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
def test_assert_tl_matmul(): def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "bfloat16", "float32", "float32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32") assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_assert_tl_matmul_bfloat16():
assert_tl_matmul_correctness(256, 256, 256, "bfloat16", "float32", "float32")
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9) @tilelang.testing.requires_cuda_compute_version(8, 9)
def test_assert_tl_matmul_fp8(): def test_assert_tl_matmul_fp8():
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch import torch
import torch.backends import torch.backends
import tilelang.testing import tilelang.testing
......
...@@ -353,6 +353,7 @@ def tl_matmul_weight_only_transform( ...@@ -353,6 +353,7 @@ def tl_matmul_weight_only_transform(
def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
import bitblas
matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype)
kernel = tilelang.compile(matmul, out_idx=[2]) kernel = tilelang.compile(matmul, out_idx=[2])
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
...@@ -367,7 +368,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt ...@@ -367,7 +368,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4)
compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4)
ladder_permutate_config = tilelang.ops.LadderPermutateConfig( ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
M=N, M=N,
N=(K // 2), N=(K // 2),
datatype="int8", datatype="int8",
...@@ -376,7 +377,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt ...@@ -376,7 +377,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
transpose_matrix=True, transpose_matrix=True,
) )
ladder_permutate = tilelang.ops.LadderPermutate(ladder_permutate_config) ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
LB = ladder_permutate(compressed_B.cpu()).cuda() LB = ladder_permutate(compressed_B.cpu()).cuda()
C = kernel(compressed_A, LB) C = kernel(compressed_A, LB)
...@@ -398,4 +399,5 @@ def test_assert_tl_matmul_weight_only_transform(): ...@@ -398,4 +399,5 @@ def test_assert_tl_matmul_weight_only_transform():
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() # tilelang.testing.main()
test_assert_tl_matmul_weight_only_transform()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tilelang import tilelang
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
import tilelang.language as T import tilelang.language as T
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang import tilelang
import tilelang.testing import tilelang.testing
from tilelang import language as T from tilelang import language as T
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
import tilelang.language as T import tilelang.language as T
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
import tilelang as tl import tilelang as tl
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
import tilelang.language as T import tilelang.language as T
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
import tilelang.language as T import tilelang.language as T
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
import tilelang as tl import tilelang as tl
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pytest import pytest
import tilelang import tilelang
...@@ -268,6 +265,7 @@ def test_subroutine_call_to_externally_visible_subroutine(): ...@@ -268,6 +265,7 @@ def test_subroutine_call_to_externally_visible_subroutine():
f"but instead has an operation of type {subroutine_call_op}") f"but instead has an operation of type {subroutine_call_op}")
@tilelang.testing.requires_llvm
def test_function_call_with_wrong_argument_count(): def test_function_call_with_wrong_argument_count():
"""Argument counts must be checked before accessing the type codes""" """Argument counts must be checked before accessing the type codes"""
...@@ -286,6 +284,7 @@ def test_function_call_with_wrong_argument_count(): ...@@ -286,6 +284,7 @@ def test_function_call_with_wrong_argument_count():
built() built()
@tilelang.testing.requires_llvm
def test_function_call_with_wrong_type_code(): def test_function_call_with_wrong_type_code():
"""Type codes must be checked before accessing the arguments""" """Type codes must be checked before accessing the arguments"""
...@@ -299,6 +298,7 @@ def test_function_call_with_wrong_type_code(): ...@@ -299,6 +298,7 @@ def test_function_call_with_wrong_type_code():
built(0) built(0)
@tilelang.testing.requires_llvm
def test_function_call_with_null_data_pointer(): def test_function_call_with_null_data_pointer():
"""The data pointer must be checked before accessing the array""" """The data pointer must be checked before accessing the array"""
...@@ -318,6 +318,7 @@ def test_function_call_with_null_data_pointer(): ...@@ -318,6 +318,7 @@ def test_function_call_with_null_data_pointer():
built(A, B) built(A, B)
@tilelang.testing.requires_llvm
def test_function_call_with_wrong_dimensionality(): def test_function_call_with_wrong_dimensionality():
"""The dimensionality must be checked before validating the shape""" """The dimensionality must be checked before validating the shape"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
import tilelang.language as T import tilelang.language as T
...@@ -88,7 +85,7 @@ def test_matmul(): ...@@ -88,7 +85,7 @@ def test_matmul():
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Get CUDA Source # Get CUDA Source
# print(rt_mod.imported_modules[0].get_source()) print(kernel.get_kernel_source())
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa # ruff: noqa
import tilelang import tilelang
from tilelang import tvm as tvm from tilelang import tvm as tvm
...@@ -13,6 +11,7 @@ simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu") ...@@ -13,6 +11,7 @@ simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu")
sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve") sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve")
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_loop(extent, target): def test_vectorize_loop(extent, target):
...@@ -36,6 +35,7 @@ def test_vectorize_loop(extent, target): ...@@ -36,6 +35,7 @@ def test_vectorize_loop(extent, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
def test_vectorize_vector(): def test_vectorize_vector():
n = te.var("n") n = te.var("n")
ib = tvm.tir.ir_builder.create() ib = tvm.tir.ir_builder.create()
...@@ -56,6 +56,7 @@ def test_vectorize_vector(): ...@@ -56,6 +56,7 @@ def test_vectorize_vector():
assert isinstance(stmt.body.value, tvm.tir.Broadcast) assert isinstance(stmt.body.value, tvm.tir.Broadcast)
@tilelang.testing.requires_llvm
def test_vectorize_vector_scalable_error(): def test_vectorize_vector_scalable_error():
@I.ir_module @I.ir_module
...@@ -72,6 +73,7 @@ def test_vectorize_vector_scalable_error(): ...@@ -72,6 +73,7 @@ def test_vectorize_vector_scalable_error():
tilelang.transform.VectorizeLoop()(Module) tilelang.transform.VectorizeLoop()(Module)
@tilelang.testing.requires_llvm
def test_vectorize_vector_scalable_error2(): def test_vectorize_vector_scalable_error2():
@I.ir_module @I.ir_module
...@@ -87,6 +89,7 @@ def test_vectorize_vector_scalable_error2(): ...@@ -87,6 +89,7 @@ def test_vectorize_vector_scalable_error2():
tilelang.transform.VectorizeLoop()(Module) tilelang.transform.VectorizeLoop()(Module)
@tilelang.testing.requires_llvm
def test_vectorize_vector_scalable_error3(): def test_vectorize_vector_scalable_error3():
@I.ir_module @I.ir_module
...@@ -105,6 +108,7 @@ def test_vectorize_vector_scalable_error3(): ...@@ -105,6 +108,7 @@ def test_vectorize_vector_scalable_error3():
tilelang.transform.VectorizeLoop()(Module) tilelang.transform.VectorizeLoop()(Module)
@tilelang.testing.requires_llvm
def test_vectorize_vector_scalable_error4(): def test_vectorize_vector_scalable_error4():
@I.ir_module @I.ir_module
...@@ -123,6 +127,7 @@ def test_vectorize_vector_scalable_error4(): ...@@ -123,6 +127,7 @@ def test_vectorize_vector_scalable_error4():
tilelang.transform.VectorizeLoop()(Module) tilelang.transform.VectorizeLoop()(Module)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_with_if(extent, target): def test_vectorize_with_if(extent, target):
...@@ -156,6 +161,7 @@ def test_vectorize_with_if(extent, target): ...@@ -156,6 +161,7 @@ def test_vectorize_with_if(extent, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
def test_vectorize_with_if_cond_int64(): def test_vectorize_with_if_cond_int64():
m = te.size_var("m", dtype="int64") m = te.size_var("m", dtype="int64")
A = te.placeholder((m,), name="A", dtype="float32") A = te.placeholder((m,), name="A", dtype="float32")
...@@ -166,6 +172,7 @@ def test_vectorize_with_if_cond_int64(): ...@@ -166,6 +172,7 @@ def test_vectorize_with_if_cond_int64():
f = tvm.build(s, [A, B], "llvm") f = tvm.build(s, [A, B], "llvm")
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_let(extent, target): def test_vectorize_let(extent, target):
...@@ -191,6 +198,7 @@ def test_vectorize_let(extent, target): ...@@ -191,6 +198,7 @@ def test_vectorize_let(extent, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)])
def test_vectorize_with_le_cond(extent, target): def test_vectorize_with_le_cond(extent, target):
n = te.var("n") n = te.var("n")
...@@ -210,6 +218,7 @@ def test_vectorize_with_le_cond(extent, target): ...@@ -210,6 +218,7 @@ def test_vectorize_with_le_cond(extent, target):
assert isinstance(stmt, tvm.tir.For) assert isinstance(stmt, tvm.tir.For)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)])
def test_vectorize_with_ge_cond(extent, target): def test_vectorize_with_ge_cond(extent, target):
n = te.var("n") n = te.var("n")
...@@ -229,6 +238,7 @@ def test_vectorize_with_ge_cond(extent, target): ...@@ -229,6 +238,7 @@ def test_vectorize_with_ge_cond(extent, target):
assert isinstance(stmt, tvm.tir.For) assert isinstance(stmt, tvm.tir.For)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_if_then_else_scalarize(extent, target): def test_vectorize_if_then_else_scalarize(extent, target):
...@@ -253,6 +263,7 @@ def test_vectorize_if_then_else_scalarize(extent, target): ...@@ -253,6 +263,7 @@ def test_vectorize_if_then_else_scalarize(extent, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_if_then_else_vector(extent, target): def test_vectorize_if_then_else_vector(extent, target):
...@@ -280,6 +291,7 @@ def test_vectorize_if_then_else_vector(extent, target): ...@@ -280,6 +291,7 @@ def test_vectorize_if_then_else_vector(extent, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
def test_vectorize_while_fail(): def test_vectorize_while_fail():
"""A while loop inside a vectorized loop should fail.""" """A while loop inside a vectorized loop should fail."""
...@@ -327,6 +339,7 @@ def test_vectorize_while_fail(): ...@@ -327,6 +339,7 @@ def test_vectorize_while_fail():
assert expected in error_msg assert expected in error_msg
@tilelang.testing.requires_llvm
def test_vectorize_dtype_mismatch(): def test_vectorize_dtype_mismatch():
n = tvm.tir.IntImm("int64", 4) n = tvm.tir.IntImm("int64", 4)
A = te.compute((n,), lambda i: tvm.tir.IntImm("int64", 2**31 - 1) + i, name="A") A = te.compute((n,), lambda i: tvm.tir.IntImm("int64", 2**31 - 1) + i, name="A")
...@@ -335,6 +348,7 @@ def test_vectorize_dtype_mismatch(): ...@@ -335,6 +348,7 @@ def test_vectorize_dtype_mismatch():
tvm.lower(s, [A], "llvm", simple_mode=True) tvm.lower(s, [A], "llvm", simple_mode=True)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize( @pytest.mark.parametrize(
"extent, vec_str, target", "extent, vec_str, target",
[(16, "float32x16", simple_target), (T.vscale() * 8, "float32xvscalex8", sve_target)], [(16, "float32x16", simple_target), (T.vscale() * 8, "float32xvscalex8", sve_target)],
...@@ -361,6 +375,7 @@ def test_vectorize_with_reinterpret(extent, vec_str, target): ...@@ -361,6 +375,7 @@ def test_vectorize_with_reinterpret(extent, vec_str, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"op", "op",
...@@ -404,6 +419,7 @@ def test_vectorize_binary(op, extent, target): ...@@ -404,6 +419,7 @@ def test_vectorize_binary(op, extent, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
@pytest.mark.parametrize("op", (T.And, T.Or)) @pytest.mark.parametrize("op", (T.And, T.Or))
def test_vectorize_logical(op, extent, target): def test_vectorize_logical(op, extent, target):
...@@ -428,6 +444,7 @@ def test_vectorize_logical(op, extent, target): ...@@ -428,6 +444,7 @@ def test_vectorize_logical(op, extent, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_select(extent, target): def test_vectorize_select(extent, target):
...@@ -455,6 +472,7 @@ def test_vectorize_select(extent, target): ...@@ -455,6 +472,7 @@ def test_vectorize_select(extent, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
@pytest.mark.parametrize( @pytest.mark.parametrize(
"extent, vec_str, target", "extent, vec_str, target",
[(4, "int32x4", simple_target), (T.vscale() * 4, "int32xvscalex4", sve_target)], [(4, "int32x4", simple_target), (T.vscale() * 4, "int32xvscalex4", sve_target)],
...@@ -481,6 +499,7 @@ def test_vectorize_cast(extent, vec_str, target): ...@@ -481,6 +499,7 @@ def test_vectorize_cast(extent, vec_str, target):
tvm.ir.assert_structural_equal(mod, After) tvm.ir.assert_structural_equal(mod, After)
@tilelang.testing.requires_llvm
def test_illegal_extent(): def test_illegal_extent():
@I.ir_module(check_well_formed=False) @I.ir_module(check_well_formed=False)
...@@ -497,6 +516,7 @@ def test_illegal_extent(): ...@@ -497,6 +516,7 @@ def test_illegal_extent():
tilelang.transform.VectorizeLoop()(Mod) tilelang.transform.VectorizeLoop()(Mod)
@tilelang.testing.requires_llvm
def test_illegal_vscale_in_non_sve_compilation(): def test_illegal_vscale_in_non_sve_compilation():
@I.ir_module @I.ir_module
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang as tl import tilelang as tl
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment