Unverified Commit cda5ea15 authored by Tang Xinsheng's avatar Tang Xinsheng Committed by GitHub
Browse files

[AMD] fix bugs in warp shuffle (#790)



* [AMD] fix bugs in warp shuffle

* format

---------
Co-authored-by: default avatartangxinsheng.txs <tangxinsheng.txs@alibaba-inc.com>
parent 013adca0
...@@ -3,10 +3,13 @@ ...@@ -3,10 +3,13 @@
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier, evaluate from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
from tvm import tir from tvm import tir
from typing import Union, Any from typing import Union, Any
from tvm.tir import PrimExpr, Var, Call from tvm.tir import PrimExpr, Var, Call
_IS_HIP_AVAILABLE = check_hip_availability()
def create_list_of_mbarrier(*args: Any) -> Call: def create_list_of_mbarrier(*args: Any) -> Call:
""" """
...@@ -295,7 +298,10 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, ...@@ -295,7 +298,10 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
Returns: Returns:
tir.Call: A handle to the shuffle operation tir.Call: A handle to the shuffle operation
""" """
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
...@@ -305,7 +311,10 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr ...@@ -305,7 +311,10 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr
value: Optional[int, PrimExpr] value: Optional[int, PrimExpr]
The value to shuffle The value to shuffle
""" """
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_down", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
...@@ -315,7 +324,10 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, ...@@ -315,7 +324,10 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
value: Optional[int, PrimExpr] value: Optional[int, PrimExpr]
The value to shuffle The value to shuffle
""" """
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset) if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_up", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
def sync_threads(): def sync_threads():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment