#pragma once

namespace sccl {
#define DIVUP(x, y) (((x) + (y) - 1) / (y))

#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))

#define ALIGN_POWER(x, y) ((x) > (y) ? ROUNDUP(x, y) : ((y) / ((y) / (x))))

#define ALIGN_SIZE(size, align) size = ((size + (align) - 1) / (align)) * (align);

#if !__CUDA_ARCH__
#ifndef __host__
#define __host__
#endif
#ifndef __device__
#define __device__
#endif
#endif

template <typename X, typename Y, typename Z = decltype(X() + Y())>
/**
 * @brief 计算向上取整的除法结果
 * @tparam X 被除数的类型
 * @tparam Y 除数的类型
 * @tparam Z 返回值的类型
 * @param x 被除数
 * @param y 除数
 * @return 返回 (x + y - 1) / y 的结果
 * @note 该函数为constexpr，可在编译时计算
 * @note 支持host和device端调用
 */
__host__ __device__ constexpr Z divUp(X x, Y y) {
    return (x + y - 1) / y;
}

template <typename X, typename Y, typename Z = decltype(X() + Y())>
/**
 * @brief 将数值x向上对齐到y的倍数
 *
 * @tparam X 输入数值类型
 * @tparam Y 对齐基数类型
 * @tparam Z 返回数值类型
 * @param x 需要对齐的数值
 * @param y 对齐基数
 * @return constexpr Z 返回向上对齐后的数值
 *
 * @note 该函数支持主机端(__host__)和设备端(__device__)调用
 * @note 使用公式 (x + y - 1) - (x + y - 1) % y 实现向上对齐
 */
__host__ __device__ constexpr Z roundUp(X x, Y y) {
    return (x + y - 1) - (x + y - 1) % y;
}

// assumes second argument is a power of 2
template <typename X, typename Z = decltype(X() + int())>
/**
 * @brief 将给定值向上对齐到指定边界
 *
 * @tparam X 输入值类型
 * @tparam Z 返回值类型
 * @param x 需要对齐的值
 * @param a 对齐边界(必须是2的幂次)
 * @return constexpr Z 对齐后的值
 *
 * @note 该函数支持主机和设备端调用
 */
__host__ __device__ constexpr Z alignUp(X x, int a) {
    return (x + a - 1) & Z(-a);
}

} // namespace sccl
