/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#define FINAL_MASK 0xffffffff

namespace vllm {

/*
 * Warp-wide sum.
 *
 * Adds `val` across the 32 lanes of the calling warp with an XOR butterfly
 * (lane distances 16, 8, 4, 2, 1). After the last step every participating
 * lane holds the same total, which is returned.
 *
 * NOTE(review): the shuffles name all 32 lanes in their mask, so the whole
 * warp is expected to reach this call together - confirm for launches whose
 * block size is not a multiple of 32.
 */
template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
        // Fetch the running sum of the lane whose id differs in bit `offset`.
        const T peer = __shfl_xor_sync(FINAL_MASK, val, offset, 32);
        val += peer;
    }
    return val;
}

/*
 * Warp-wide sum of NUM independent values, performed in place.
 *
 * val: per-thread array of NUM elements. On return, val[i] holds the sum of
 *      val[i] over the 32 lanes of the warp, for every i in [0, NUM).
 *
 * The return value is always zero converted to T; the results are delivered
 * only through `val`.
 */
template<typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T *val) {
#pragma unroll
    for (int idx = 0; idx < NUM; ++idx) {
#pragma unroll
        for (int offset = 16; offset >= 1; offset /= 2) {
            // XOR butterfly step for element idx.
            const T peer = __shfl_xor_sync(FINAL_MASK, val[idx], offset, 32);
            val[idx] += peer;
        }
    }
    return static_cast<T>(0.0f);
}

/*
 * Block-wide sum whose result is valid in the first warp only.
 *
 * Each warp reduces its values with warpReduceSum, lane 0 of every warp
 * stores the partial in shared memory, and after a block barrier the
 * partials are reduced once more.
 *
 * Only threads with threadIdx.x < ceil(blockDim.x / 32) load a partial, so
 * the threads of warp 0 return the block total while threads of the other
 * warps return zero.
 *
 * Indexing uses threadIdx.x only, the scratch buffer has room for 32 warp
 * partials, and the function contains a __syncthreads(), so it must be
 * reached by every thread of the block.
 *
 * NOTE(review): the scratch buffer is static, so back-to-back calls with the
 * same T reuse it; verify callers place a barrier between such calls.
 */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
    static __shared__ T warp_sums[32];

    const int lane_id = threadIdx.x % 32;
    const int warp_id = threadIdx.x / 32;

    // Stage 1: reduce inside each warp and publish one partial per warp.
    const T warp_total = warpReduceSum<T>(val);
    if (lane_id == 0) {
        warp_sums[warp_id] = warp_total;
    }

    __syncthreads();

    // Stage 2: number of warps, rounded up so a trailing partial warp
    // (blockDim.x not divisible by 32) is still counted.
    const unsigned int warp_count = (blockDim.x + 31u) / 32u;

    T partial = (T)(0.0f);
    if (threadIdx.x < warp_count) {
        partial = warp_sums[lane_id];
    }
    return warpReduceSum<T>(partial);
}

/*
 * Block-wide sum returned to every thread of the block.
 *
 * Same two-stage scheme as a per-warp reduction followed by a reduction of
 * the per-warp partials, but here the second stage is selected by lane id
 * rather than by thread id: in every warp, lanes below
 * ceil(blockDim.x / 32) load a partial and the rest contribute zero, so all
 * warps compute and return the same block total.
 *
 * Indexing uses threadIdx.x only, the scratch buffer has room for 32 warp
 * partials, and the function contains a __syncthreads(), so it must be
 * reached by every thread of the block.
 *
 * NOTE(review): the scratch buffer is static, so back-to-back calls with the
 * same T reuse it; verify callers place a barrier between such calls.
 */
template<typename T>
__inline__ __device__ T blockAllReduceSum(T val) {
    static __shared__ T warp_sums[32];

    const int lane_id = threadIdx.x % 32;
    const int warp_id = threadIdx.x / 32;

    // Stage 1: reduce inside each warp and publish one partial per warp.
    const T warp_total = warpReduceSum<T>(val);
    if (lane_id == 0) {
        warp_sums[warp_id] = warp_total;
    }

    __syncthreads();

    // Stage 2: warp count rounded up to cover a trailing partial warp.
    const int warp_count = static_cast<int>((blockDim.x + 31u) / 32u);

    T partial = (T)(0.0f);
    if (lane_id < warp_count) {
        partial = warp_sums[lane_id];
    }
    return warpReduceSum<T>(partial);
}

/*
 * Block-wide sum of NUM independent values, performed in place.
 *
 * val: per-thread array of NUM elements. Each element is first reduced
 *      inside the warp, lane 0 of every warp stores the NUM partials in
 *      shared memory, and after a block barrier the partials are reduced
 *      again.
 *
 * Only threads with threadIdx.x < ceil(blockDim.x / 32) load partials, so on
 * return the `val` arrays of warp 0 hold the block totals and those of the
 * other warps hold zeros. The return value is always zero converted to T.
 *
 * The scratch rows are declared with 33 columns although only indices 0..31
 * are addressed here. The function contains a __syncthreads(), so it must be
 * reached by every thread of the block.
 */
template<typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T *val) {
    static __shared__ T warp_sums[NUM][33];

    const int lane_id = threadIdx.x % 32;
    const int warp_id = threadIdx.x / 32;

    // Stage 1: per-warp totals for all NUM elements.
    warpReduceSumV2<T, NUM>(val);

    if (lane_id == 0) {
#pragma unroll
        for (int k = 0; k < NUM; ++k) {
            warp_sums[k][warp_id] = val[k];
        }
    }

    __syncthreads();

    // Stage 2: threads that own a warp partial reload it, the rest feed zero.
    const bool has_partial = threadIdx.x < (blockDim.x + 31u) / 32u;
#pragma unroll
    for (int k = 0; k < NUM; ++k) {
        if (has_partial) {
            val[k] = warp_sums[k][lane_id];
        } else {
            val[k] = (T)(0.0f);
        }
    }
    warpReduceSumV2<T, NUM>(val);

    return static_cast<T>(0.0f);
}

/*
 * Warp-wide maximum.
 *
 * Combines `val` across the 32 lanes of the calling warp with an XOR
 * butterfly (lane distances 16, 8, 4, 2, 1), keeping the larger value at
 * each step. Every participating lane returns the same maximum.
 *
 * NOTE(review): the shuffles name all 32 lanes in their mask, so the whole
 * warp is expected to reach this call together - confirm for launches whose
 * block size is not a multiple of 32.
 */
template<typename T>
__inline__ __device__ T warpReduceMax(T val) {
#pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
        const T peer = __shfl_xor_sync(FINAL_MASK, val, offset, 32);
        val = max(val, peer);
    }
    return val;
}
/*
 * Block-wide maximum whose result is valid in the first warp only.
 *
 * Each warp reduces its values with warpReduceMax, lane 0 of every warp
 * stores the partial in shared memory, and after a block barrier the first
 * warp reduces the partials.
 *
 * Threads of warp 0 return the block maximum. Threads of every other warp
 * return the first warp's partial maximum, which is NOT the block maximum
 * and must not be relied on.
 *
 * Lanes that have no warp partial to load are padded with shared[0] instead
 * of a numeric sentinel. Max is idempotent, so repeating a real partial can
 * never change the result. A floating-point sentinel such as -1e20f would
 * clamp the result when every input lies below it (for example all -inf),
 * would push integral T through a float conversion in the conditional
 * expression, and is not representable in narrow types.
 *
 * Indexing uses threadIdx.x only, the scratch buffer has room for 32 warp
 * partials, and the function contains a __syncthreads(), so it must be
 * reached by every thread of the block.
 *
 * NOTE(review): the scratch buffer is static, so back-to-back calls with the
 * same T reuse it; verify callers place a barrier between such calls.
 */
template<typename T>
__inline__ __device__ T blockReduceMax(T val) {
    static __shared__ T shared[32];
    const int lane = threadIdx.x & 0x1f; // index inside the warp
    const int wid  = threadIdx.x >> 5;   // index of the warp inside the block

    val = warpReduceMax(val); // maximum of each warp
    if (lane == 0)            // one partial per warp, slot chosen by warp index
        shared[wid] = val;
    __syncthreads();

    // Warp count rounded up so a trailing partial warp (blockDim.x not
    // divisible by 32) is still included.
    const unsigned int num_warps = (blockDim.x + 31u) >> 5;

    // shared[0] is always written (warp 0 always exists), so it is a safe,
    // type-correct filler for lanes without their own partial.
    val = (threadIdx.x < num_warps) ? shared[lane] : shared[0];
    val = warpReduceMax(val);
    return val;
}

/*
 * Block-wide maximum returned to every thread of the block.
 *
 * Each warp reduces its values with warpReduceMax, lane 0 of every warp
 * stores the partial in shared memory, and after a block barrier every warp
 * reduces the full set of partials, so all threads return the same block
 * maximum.
 *
 * Lanes that have no warp partial to load are padded with shared[0] instead
 * of a numeric sentinel. Max is idempotent, so repeating a real partial can
 * never change the result. A floating-point sentinel such as -1e20f would
 * clamp the result when every input lies below it (for example all -inf),
 * would push integral T through a float conversion in the conditional
 * expression, and is not representable in narrow types.
 *
 * Indexing uses threadIdx.x only, the scratch buffer has room for 32 warp
 * partials, and the function contains a __syncthreads(), so it must be
 * reached by every thread of the block.
 *
 * NOTE(review): the scratch buffer is static, so back-to-back calls with the
 * same T reuse it; verify callers place a barrier between such calls.
 */
template<typename T>
__inline__ __device__ T blockAllReduceMax(T val) {
    static __shared__ T shared[32];
    const int lane = threadIdx.x & 0x1f; // index inside the warp
    const int wid  = threadIdx.x >> 5;   // index of the warp inside the block

    val = warpReduceMax(val); // maximum of each warp

    if (lane == 0) // one partial per warp, slot chosen by warp index
        shared[wid] = val;

    __syncthreads();

    // Warp count rounded up so a trailing partial warp (blockDim.x not
    // divisible by 32) is still included.
    const int num_warps = static_cast<int>((blockDim.x + 31u) >> 5);

    // shared[0] is always written (warp 0 always exists), so it is a safe,
    // type-correct filler for lanes without their own partial.
    val = (lane < num_warps) ? shared[lane] : shared[0];
    val = warpReduceMax(val);

    return val;
}

} // namespace vllm