/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <hip/hip_runtime.h>

namespace faiss {
namespace gpu {

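// Macros that simplify the SASS assembly structure (file/line) shown in the
// profiler.
// HC: the original PTX `bfe` inline-assembly versions are kept below, commented
// out; portable shift-and-mask replacements follow.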
//#define GET_BITFIELD_U32(OUT, VAL, POS, LEN) \
//    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(OUT) : "r"(VAL), "r"(POS), "r"(LEN));
//
//#define GET_BITFIELD_U64(OUT, VAL, POS, LEN) \
//    asm("bfe.u64 %0, %1, %2, %3;" : "=l"(OUT) : "l"(VAL), "r"(POS), "r"(LEN));

// Portable replacement for PTX `bfe.u32`: stores into OUT the LEN-bit wide
// unsigned field of the 32-bit value VAL that starts at bit POS.
//   - POS and LEN are reduced to their low 8 bits (range 0..255).
//   - LEN >= 32 selects every bit from POS upwards; LEN == 0 yields 0.
//   - POS >= 32 yields 0.
// The last two rules avoid the undefined behaviour of shifting a 32-bit value
// by 32 or more. The arguments are evaluated exactly once and are never
// modified; the temporaries carry a `gbf32` prefix so they cannot shadow
// identifiers used in the argument expressions.
#define GET_BITFIELD_U32(OUT, VAL, POS, LEN)                                   \
    {                                                                          \
        const unsigned int gbf32Pos_ = (POS) & 0xff;                           \
        const unsigned int gbf32Len_ = (LEN) & 0xff;                           \
        const unsigned int gbf32Mask_ =                                        \
                gbf32Len_ >= 32u ? ~0u : ((1u << gbf32Len_) - 1u);             \
        OUT = gbf32Pos_ >= 32u                                                 \
                ? 0u                                                           \
                : ((static_cast<unsigned int>(VAL) >> gbf32Pos_) &             \
                   gbf32Mask_);                                                \
    }

// Portable replacement for PTX `bfe.u64`: stores into OUT the LEN-bit wide
// unsigned field of the 64-bit value VAL that starts at bit POS.
//   - POS and LEN are reduced to their low 8 bits (range 0..255).
//   - LEN >= 64 selects every bit from POS upwards; LEN == 0 yields 0.
//   - POS >= 64 yields 0.
// The mask is built with 64-bit arithmetic so that fields of 32 bits or more
// are extracted correctly. The whole expansion is a single braced block:
// POS and LEN may be arbitrary expressions (they are not written to), no
// temporaries leak into the caller's scope, and the macro can be used more
// than once per scope.
#define GET_BITFIELD_U64(OUT, VAL, POS, LEN)                                   \
    {                                                                          \
        const unsigned int gbf64Pos_ = (POS) & 0xff;                           \
        const unsigned int gbf64Len_ = (LEN) & 0xff;                           \
        const unsigned long long gbf64Mask_ = gbf64Len_ >= 64u                 \
                ? ~0ull                                                        \
                : ((1ull << gbf64Len_) - 1ull);                                \
        OUT = gbf64Pos_ >= 64u                                                 \
                ? 0ull                                                         \
                : ((static_cast<unsigned long long>(VAL) >> gbf64Pos_) &       \
                   gbf64Mask_);                                                \
    }

/// Extracts an unsigned bit field from a 32-bit value; portable replacement
/// for the PTX instruction `bfe.u32` used by the original CUDA code.
///
/// @param val  value to extract from
/// @param pos  index of the lowest bit of the field; only the low 8 bits are
///             used (range 0..255)
/// @param len  field width in bits; only the low 8 bits are used
/// @return the `len` bits of `val` starting at `pos`, right-aligned and
///         zero-extended. A start position of 32 or more yields 0; a width
///         of 32 or more selects every bit from `pos` upwards.
__device__ __forceinline__ unsigned int getBitfield(
        unsigned int val,
        int pos,
        int len) {
    const unsigned int start = static_cast<unsigned int>(pos) & 0xffu;
    const unsigned int width = static_cast<unsigned int>(len) & 0xffu;

    // Shifting a 32-bit value by >= 32 is undefined in C++, so the
    // out-of-range cases are resolved explicitly.
    if (start >= 32u) {
        return 0u;
    }

    const unsigned int mask = width >= 32u ? ~0u : ((1u << width) - 1u);
    return (val >> start) & mask;
}

/// Extracts an unsigned bit field from a 64-bit value; portable replacement
/// for the PTX instruction `bfe.u64` used by the original CUDA code.
///
/// @param val  value to extract from
/// @param pos  index of the lowest bit of the field; only the low 8 bits are
///             used (range 0..255)
/// @param len  field width in bits; only the low 8 bits are used
/// @return the `len` bits of `val` starting at `pos`, right-aligned and
///         zero-extended. A start position of 64 or more yields 0; a width
///         of 64 or more selects every bit from `pos` upwards.
__device__ __forceinline__ uint64_t
getBitfield(uint64_t val, int pos, int len) {
    const unsigned int start = static_cast<unsigned int>(pos) & 0xffu;
    const unsigned int width = static_cast<unsigned int>(len) & 0xffu;

    // Shifting a 64-bit value by >= 64 is undefined in C++, so the
    // out-of-range cases are resolved explicitly.
    if (start >= 64u) {
        return 0;
    }

    // The mask must be built in 64 bits: a 32-bit `1u << width` cannot
    // represent fields of 32 bits or more.
    const uint64_t mask = width >= 64u
            ? ~static_cast<uint64_t>(0)
            : ((static_cast<uint64_t>(1) << width) - 1);
    return (val >> start) & mask;
}

/// Inserts a bit field into a 32-bit value; portable replacement for the PTX
/// instruction `bfi.b32` used by the original CUDA code.
///
/// @param val       base value that receives the field
/// @param toInsert  source of the field; its low `len` bits are inserted
/// @param pos       index of the lowest destination bit; only the low 8 bits
///                  are used (range 0..255)
/// @param len       field width in bits; only the low 8 bits are used
/// @return `val` with bits [pos, pos + len) replaced by the low `len` bits of
///         `toInsert`; field bits that would land above bit 31 are dropped.
///         `val` is returned unchanged when `len` is 0 or `pos` is 32 or
///         more.
__device__ __forceinline__ unsigned int setBitfield(
        unsigned int val,
        unsigned int toInsert,
        int pos,
        int len) {
    const unsigned int start = static_cast<unsigned int>(pos) & 0xffu;
    const unsigned int width = static_cast<unsigned int>(len) & 0xffu;

    // Shifting a 32-bit value by >= 32 is undefined in C++; a field that
    // starts past the top bit does not touch `val` at all.
    if (start >= 32u) {
        return val;
    }

    const unsigned int fieldMask =
            width >= 32u ? ~0u : ((1u << width) - 1u);
    const unsigned int placedMask = fieldMask << start;
    const unsigned int placedBits = (toInsert & fieldMask) << start;

    return (val & ~placedMask) | placedBits;
}

/// Returns the value of the `__lane_id()` intrinsic converted to `int`.
/// HC: stands in for reading the PTX special register `%laneid`.
__device__ __forceinline__ int getLaneId() {
    return static_cast<int>(__lane_id());
}

// HC
/// Returns a 64-bit mask whose set bits are exactly the bit positions below
/// the value returned by getLaneId().
__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
  // A single bit at the lane's own position, minus one, leaves every lower
  // bit set and nothing else.
  const unsigned long long ownBit = 1ull << getLaneId();
  return ownBit - 1ull;
}
//__device__ __forceinline__ unsigned getLaneMaskLt() {
//    unsigned mask;
//    asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask));
//    return mask;
//}

// HC
// `__HIP_PLATFORM_HCC__` is the legacy spelling of `__HIP_PLATFORM_AMD__`;
// both are accepted so that toolchains defining only the newer macro still
// select the portable implementation instead of the PTX one.
#if defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__)
/// Returns a 64-bit mask whose set bits are the bit positions less than or
/// equal to the value returned by getLaneId().
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
  // All ones shifted right by (63 - lane) keeps the low (lane + 1) bits set.
  // Written without UINT64_MAX / CHAR_BIT so no extra headers are required.
  return ~0ull >> (63 - getLaneId());
}
#else
/// Returns the PTX special register `%lanemask_le`.
__device__ __forceinline__ unsigned getLaneMaskLe() {
  unsigned mask;
  asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
  return mask;
}
#endif

//__device__ __forceinline__ unsigned getLaneMaskLe() {
//    unsigned mask;
//    asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
//    return mask;
//}


// HC
// `__HIP_PLATFORM_HCC__` is the legacy spelling of `__HIP_PLATFORM_AMD__`;
// both are accepted so that toolchains defining only the newer macro still
// select the portable implementation instead of the PTX one.
#if defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__)
/// Returns the bitwise complement of getLaneMaskLe(), i.e. a 64-bit mask
/// whose set bits are the bit positions strictly above the calling lane's id.
/// NOTE(review): bits above the hardware wavefront width are also set when
/// the wavefront is narrower than 64 lanes - confirm callers AND the result
/// with an active-lane mask.
__device__ __forceinline__ unsigned long long int getLaneMaskGt() {
  // The "le" mask always contains the lane's own bit, so it is never zero
  // and a plain complement is sufficient.
  return ~getLaneMaskLe();
}
#else
/// Returns the PTX special register `%lanemask_gt`.
__device__ __forceinline__ unsigned getLaneMaskGt() {
  unsigned mask;
  asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask));
  return mask;
}
#endif

//__device__ __forceinline__ unsigned getLaneMaskGt() {
//    unsigned mask;
//    asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask));
//    return mask;
//}

// HC
#if defined(__HIP_PLATFORM_HCC__)
__device__ __forceinline__ unsigned long long int getLaneMaskGe() {
  const std::uint64_t m = getLaneMaskLt();
  return ~m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskGe() {
  unsigned mask;
  asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
  return mask;
}
#endif
//__device__ __forceinline__ unsigned getLaneMaskGe() {
//    unsigned mask;
//    asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
//    return mask;
//}

/// Issues the PTX instruction `bar.sync name, numThreads`.
///
/// @param name        first operand of `bar.sync` (the barrier identifier)
/// @param numThreads  second operand of `bar.sync` (the thread count)
///
/// The statement is `volatile` with a "memory" clobber, so the compiler may
/// neither remove it nor move memory accesses across it.
/// NOTE(review): unlike the other helpers in this file this is still raw PTX
/// with no HIP alternative - presumably it cannot be assembled for AMD
/// targets; confirm it is never instantiated there.
__device__ __forceinline__ void namedBarrierWait(int name, int numThreads) {
    asm volatile("bar.sync %0, %1;" : : "r"(name), "r"(numThreads) : "memory");
}

/// Issues the PTX instruction `bar.arrive name, numThreads`.
///
/// @param name        first operand of `bar.arrive` (the barrier identifier)
/// @param numThreads  second operand of `bar.arrive` (the thread count)
///
/// The statement is `volatile` with a "memory" clobber, so the compiler may
/// neither remove it nor move memory accesses across it.
/// NOTE(review): unlike the other helpers in this file this is still raw PTX
/// with no HIP alternative - presumably it cannot be assembled for AMD
/// targets; confirm it is never instantiated there.
__device__ __forceinline__ void namedBarrierArrived(int name, int numThreads) {
    asm volatile("bar.arrive %0, %1;"
                 :
                 : "r"(name), "r"(numThreads)
                 : "memory");
}

} // namespace gpu
} // namespace faiss
