/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cuda_compat.h"

namespace vllm {
namespace detail {

// Binary addition; the reduction operator handed to blockReduce by
// blockReduceSum.
template <typename T>
__inline__ __device__ T _sum(T lhs, T rhs) {
  return lhs + rhs;
}

// Binary maximum; the reduction operator handed to blockReduce by
// blockReduceMax. Calls the unqualified `max`, so the overload that runs is
// whichever one is found for T at instantiation time.
template <typename T>
__inline__ __device__ T _max(T lhs, T rhs) {
  return max(lhs, rhs);
}

}  // namespace detail

// Signature of a binary reduction operator: combines two partial results of
// type T into one. Passed as a plain function pointer to warpReduce and
// blockReduce.
template <typename T>
using ReduceFnType = T (*)(T /* lhs */, T /* rhs */);

// Returns the smallest power of two that is >= num, i.e. num itself when it
// is already a power of two; 0 and 1 are returned unchanged. Used in constant
// expressions to pick the lane count of a warpReduce.
//
// Implemented as a plain constexpr loop so it depends neither on CHAR_BIT
// (this header does not include <climits>) nor on the GCC/Clang-specific
// __builtin_clz. The 64-bit accumulator cannot overflow before reaching any
// unsigned int input, so the loop always terminates; results above 2^30 do
// not fit in the int return type.
static constexpr int _nextPow2(unsigned int num) {
  if (num <= 1) return num;
  unsigned long long pow2 = 1;
  while (pow2 < num) pow2 <<= 1;
  return static_cast<int>(pow2);
}

/**
 * Warp-level reduction of `val` with `fn`.
 *
 * Runs a butterfly of XOR shuffles. In each step a lane combines its value
 * with the lane whose index differs in exactly one bit, starting at bit
 * log2(numLanes) - 1 and moving down to bit 0. Lanes are therefore reduced in
 * independent, aligned groups of `numLanes` consecutive lanes. On return every
 * lane of a group holds that group's result; the value is identical across
 * the group as long as `fn` is commutative.
 *
 * @tparam numLanes  Group width; must be a power of two no larger than
 *                   WARP_SIZE (checked with static_assert).
 * @param val  This lane's contribution.
 * @param fn   Binary reduction operator, applied as fn(own, partner).
 *
 * NOTE(review): every lane of the warp is expected to reach this call.
 * VLLM_SHFL_XOR_SYNC comes from cuda_compat.h and presumably uses a full-warp
 * mask — confirm there before calling from divergent code or partial warps.
 */
template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
                "numLanes is not a positive power of 2!");
  static_assert(numLanes <= WARP_SIZE);

#pragma unroll
  for (int offset = numLanes / 2; offset >= 1; offset /= 2) {
    const T other = VLLM_SHFL_XOR_SYNC(val, offset);
    val = fn(val, other);
  }
  return val;
}

/**
 * Reduces `val` across all threads of a thread block with `fn`.
 *
 * Two stages:
 * 1. Each warp reduces its own values with warpReduce, and lane 0 of every
 *    warp publishes that partial result to a shared-memory array.
 * 2. After a barrier, the first lanes of warp 0 load those partials and
 *    reduce them with a second, narrower warpReduce.
 * When maxBlockSize fits in one warp, a single warpReduce is used instead.
 *
 * @tparam maxBlockSize  Upper bound on blockDim.x (at most 1024). It sizes the
 *                       shared scratch array (one slot per warp) and the width
 *                       of the final warp reduction, so blockDim.x must not
 *                       exceed it or the shared write goes out of bounds.
 * @param val      This thread's contribution.
 * @param fn       Reduction operator; expected to be associative and
 *                 commutative (as detail::_sum / detail::_max are).
 * @param padding  Value given to lanes of the final warp reduction that do
 *                 not correspond to a warp of this block. It must not change
 *                 the result. Use fn's identity element (the default 0 is
 *                 right for a sum), or, for an idempotent fn such as max/min,
 *                 any element of the reduced set, e.g. the caller's own input.
 *                 Unused when maxBlockSize <= WARP_SIZE.
 * @return The block-wide result. Only thread 0 is guaranteed to hold it;
 *         other threads may hold partial or padded values.
 *
 * Preconditions (not checked):
 * - Every thread of the block must call this function, since it contains
 *   __syncthreads().
 * - Only threadIdx.x / blockDim.x are read, so a 1-D block is assumed.
 * - NOTE(review): the warp shuffles presumably need full warps, i.e.
 *   blockDim.x a multiple of WARP_SIZE (or == _nextPow2(maxBlockSize) in the
 *   single-warp case) — confirm against cuda_compat.h.
 *
 * NOTE: the scratch array is a single static __shared__ allocation per
 * (T, maxBlockSize) instantiation. It is reused by every call, including both
 * blockReduceSum and blockReduceMax for the same T. Separate back-to-back
 * calls with a __syncthreads(); otherwise a later call's writes can race with
 * warp 0 still reading the previous call's partials.
 */
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn,
                                   T padding = (T)(0.0f)) {
  static_assert(maxBlockSize <= 1024);
  if constexpr (maxBlockSize > WARP_SIZE) {
    val = warpReduce<T>(val, fn);

    // Calculates max number of lanes that need to participate in the last
    // warpReduce
    constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
    static __shared__ T shared[maxActiveLanes];
    int lane = threadIdx.x % WARP_SIZE;
    int wid = threadIdx.x / WARP_SIZE;
    if (lane == 0) shared[wid] = val;

    __syncthreads();

    // Number of warps in this block (rounded up), i.e. how many slots of
    // `shared` were written above.
    const unsigned int numWarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
    // Lanes with no partial behind them receive `padding`, chosen by the
    // caller so that it cannot alter the result. A fixed 0 would be wrong for
    // max whenever every input is negative.
    val = (threadIdx.x < numWarps) ? shared[lane] : padding;
    val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
  } else {
    // The whole block fits in one warp, so one warp reduction suffices.
    val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
  }
  return val;
}

/**
 * Block-wide maximum of `val`. Same preconditions and result placement as
 * blockReduce: thread 0 is guaranteed to hold the maximum.
 */
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceMax(T val) {
  // 0 is not an identity for max (it would win whenever all inputs are
  // negative). max is idempotent, so padding with this thread's own input,
  // which is already part of the reduced set, leaves the result unchanged.
  return blockReduce<T, maxBlockSize>(val, detail::_max<T>, val);
}

/**
 * Block-wide sum of `val`: blockReduce with addition (detail::_sum) as the
 * operator. Same preconditions as blockReduce; thread 0 is guaranteed to hold
 * the total.
 */
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
  ReduceFnType<T> add = detail::_sum<T>;
  return blockReduce<T, maxBlockSize>(val, add);
}
}  // namespace vllm