"src/diffusers/models/unets/unet_3d_condition.py" did not exist on "915a56361157b02f1429f5cbd60094a83b1b0ff4"
align.h 1.86 KB
Newer Older
lishen's avatar
lishen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#pragma once

#define DIVUP(x, y) (((x) + (y) - 1) / (y))

#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))

#define ALIGN_POWER(x, y) ((x) > (y) ? ROUNDUP(x, y) : ((y) / ((y) / (x))))

#define ALIGN_SIZE(size, align) size = ((size + (align) - 1) / (align)) * (align);

#if !__CUDA_ARCH__
#ifndef __host__
#define __host__
#endif
#ifndef __device__
#define __device__
#endif
#endif

template <typename X, typename Y, typename Z = decltype(X() + Y())>
/**
 * @brief 计算向上取整的除法结果
 * @tparam X 被除数的类型
 * @tparam Y 除数的类型
 * @tparam Z 返回值的类型
 * @param x 被除数
 * @param y 除数
 * @return 返回 (x + y - 1) / y 的结果
 * @note 该函数为constexpr,可在编译时计算
 * @note 支持host和device端调用
 */
__host__ __device__ constexpr Z divUp(X x, Y y) {
    return (x + y - 1) / y;
}

template <typename X, typename Y, typename Z = decltype(X() + Y())>
/**
 * @brief 将数值x向上对齐到y的倍数
 *
 * @tparam X 输入数值类型
 * @tparam Y 对齐基数类型
 * @tparam Z 返回数值类型
 * @param x 需要对齐的数值
 * @param y 对齐基数
 * @return constexpr Z 返回向上对齐后的数值
 *
 * @note 该函数支持主机端(__host__)和设备端(__device__)调用
 * @note 使用公式 (x + y - 1) - (x + y - 1) % y 实现向上对齐
 */
__host__ __device__ constexpr Z roundUp(X x, Y y) {
    return (x + y - 1) - (x + y - 1) % y;
}

// assumes second argument is a power of 2
template <typename X, typename Z = decltype(X() + int())>
/**
 * @brief 将给定值向上对齐到指定边界
 *
 * @tparam X 输入值类型
 * @tparam Z 返回值类型
 * @param x 需要对齐的值
 * @param a 对齐边界(必须是2的幂次)
 * @return constexpr Z 对齐后的值
 *
 * @note 该函数支持主机和设备端调用
 */
__host__ __device__ constexpr Z alignUp(X x, int a) {
    return (x + a - 1) & Z(-a);
}