Commit fa0fca58 authored by Thien Tran's avatar Thien Tran Committed by LeiWang1999
Browse files

[Bugfix] Check CUDA target before checking for TMA #482

parent 089cc0a7
...@@ -8,9 +8,12 @@ from typing import Optional ...@@ -8,9 +8,12 @@ from typing import Optional
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None, def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool: target: Optional[Target] = None) -> bool:
# avoid circular import
from tilelang.jit.adapter.utils import is_cuda_target
if pass_ctx is None: if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context() pass_ctx = tilelang.transform.get_pass_context()
if not have_tma(target): if not is_cuda_target(target) or not have_tma(target):
return False return False
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
...@@ -19,7 +22,10 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None, ...@@ -19,7 +22,10 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
def allow_fence_proxy(target: Optional[Target] = None) -> bool: def allow_fence_proxy(target: Optional[Target] = None) -> bool:
return have_tma(target) # avoid circular import
from tilelang.jit.adapter.utils import is_cuda_target
return is_cuda_target(target) and have_tma(target)
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool: def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment