Commit cda5ea15 (unverified), authored by Tang Xinsheng, committed by GitHub
Browse files

[AMD] fix bugs in warp shuffle (#790)



* [AMD] fix bugs in warp shuffle

* format

---------
Co-authored-by: tangxinsheng.txs <tangxinsheng.txs@alibaba-inc.com>
parent 013adca0
......@@ -3,10 +3,13 @@
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
from tvm import tir
from typing import Union, Any
from tvm.tir import PrimExpr, Var, Call
_IS_HIP_AVAILABLE = check_hip_availability()
def create_list_of_mbarrier(*args: Any) -> Call:
"""
......@@ -295,7 +298,10 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
Returns:
tir.Call: A handle to the shuffle operation
"""
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
......@@ -305,7 +311,10 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr
value: Optional[int, PrimExpr]
The value to shuffle
"""
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_down", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
......@@ -315,7 +324,10 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
value: Optional[int, PrimExpr]
The value to shuffle
"""
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_up", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
def sync_threads():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment