// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/ck.hpp" namespace ck { template __device__ T warp_shuffle_up(const T& v_local, uint32_t lane_delta) { #if 0 return __shfl_up(v_local, lane_delta); #elif 1 static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); const uint32_t wrap_around_lane_delta = warpSize - lane_delta; const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute( (__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast(v_local)); return bit_cast(v_remote_tmp); #endif } template __device__ T warp_shuffle_down(const T& v_local, uint32_t lane_delta) { #if 0 return __shfl_down(v_local, lane_delta); #elif 1 static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute( (__lane_id() << 2) + (lane_delta << 2), bit_cast(v_local)); return bit_cast(v_remote_tmp); #endif } } // namespace ck