// amd_warp_shuffle.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"

namespace ck {

template <typename T>
Chao Liu's avatar
Chao Liu committed
11
__device__ T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
Chao Liu's avatar
Chao Liu committed
12
13
{
#if 0
Chao Liu's avatar
Chao Liu committed
14
    return  __shfl_up(v_local, lane_delta);
Chao Liu's avatar
Chao Liu committed
15
#elif 1
Chao Liu's avatar
Chao Liu committed
16
    static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
Chao Liu's avatar
Chao Liu committed
17

Chao Liu's avatar
Chao Liu committed
18
19
20
21
22
23
    const uint32_t wrap_around_lane_delta = warpSize - lane_delta;

    const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
        (__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));

    return bit_cast<T>(v_remote_tmp);
Chao Liu's avatar
Chao Liu committed
24
25
26
27
#endif
}

template <typename T>
Chao Liu's avatar
Chao Liu committed
28
__device__ T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
Chao Liu's avatar
Chao Liu committed
29
30
{
#if 0
Chao Liu's avatar
Chao Liu committed
31
    return  __shfl_down(v_local, lane_delta);
Chao Liu's avatar
Chao Liu committed
32
#elif 1
Chao Liu's avatar
Chao Liu committed
33
34
35
36
37
38
    static_assert(sizeof(T) == sizeof(int32_t), "wrong!");

    const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
        (__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));

    return bit_cast<T>(v_remote_tmp);
Chao Liu's avatar
Chao Liu committed
39
40
41
42
#endif
}

} // namespace ck