reduction_utils.cuh 1.93 KB
Newer Older
1
2
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19
#pragma once

20
21
#include "cuda_compat.h"

Woosuk Kwon's avatar
Woosuk Kwon committed
22
namespace vllm {
Woosuk Kwon's avatar
Woosuk Kwon committed
23
24
25
26

/*
 * Accumulate `val` across a warp.
 *
 * On every step the calling thread adds the value returned by
 * VLLM_SHFL_XOR_SYNC(val, offset) to its own running total, with `offset`
 * taking the values WARP_SIZE/2, WARP_SIZE/4, ..., 1.
 *
 * NOTE(review): VLLM_SHFL_XOR_SYNC and WARP_SIZE come from "cuda_compat.h",
 * which is not visible here. Presumably the macro is an XOR lane shuffle, in
 * which case every participating lane returns the warp-wide sum -- confirm.
 *
 * T must support `+=`.
 */
template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
    // Fold in the partner value selected by the current offset.
    val += VLLM_SHFL_XOR_SYNC(val, offset);
  }
  return val;
}

32
33
34
35
36
37
38
39
// Returns warp_size - 1. blockReduceSum ANDs threadIdx.x with this value to
// obtain the lane index, which is a valid modulo only when warp_size is a
// power of two.
__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
  return -1 + warp_size;
}

/*
 * Returns floor(log2(warp_size)): the right-shift that turns threadIdx.x into
 * a warp index when warp_size is a power of two.
 *
 * The previous closed form `5 + (warp_size >> 6)` matched log2 only for a few
 * widths (32 -> 5, 64 -> 6); this version is exact for every power of two and
 * yields the same values for 32 and 64. Values <= 1 map to a shift of 0.
 *
 * Written as a single-return recursion so it remains a valid C++11 constexpr
 * function; it is intended to be evaluated at compile time.
 */
__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
  return warp_size <= 1 ? 0 : 1 + _calculateWidShift(warp_size >> 1);
}

Woosuk Kwon's avatar
Woosuk Kwon committed
40
41
42
/*
 * Calculate the sum of all elements in a block.
 *
 * Two-stage reduction:
 *   1. each warp reduces its own values with warpReduceSum;
 *   2. lane 0 of every warp stores its partial in a shared scratch array, and
 *      the threads with threadIdx.x < (number of warps) reduce those partials
 *      with a second warpReduceSum.
 *
 * Requirements visible from the code:
 *   - only threadIdx.x / blockDim.x are consulted, so a 1-D block is assumed;
 *   - the scratch array has WARP_SIZE slots, so the block may contain at most
 *     WARP_SIZE warps;
 *   - the function contains __syncthreads(), so every thread of the block
 *     must call it.
 *
 * Return value: threads with threadIdx.x >= (number of warps) feed (T)0 into
 * the second stage, so only the first warp combines the per-warp partials.
 * NOTE(review): callers should presumably consume the result from thread 0 /
 * the first warp only -- confirm against call sites.
 */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
  // One slot per warp. `static` means the buffer is shared by every call of
  // this instantiation within the block.
  static __shared__ T shared[WARP_SIZE];
  constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
  constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
  const int lane = threadIdx.x & LANE_MASK;  // index within the warp
  const int wid = threadIdx.x >> WID_SHIFT;  // index of the warp in the block

  val = warpReduceSum<T>(val);

  // Guard against a write-after-read race on `shared` when this function is
  // called several times in a row: a fast warp must not overwrite its slot
  // while slower threads are still reading the partials of the previous call.
  __syncthreads();

  if (lane == 0)
    shared[wid] = val;

  // Make every warp's partial visible before the second stage reads it.
  __syncthreads();

  // Ceil-divide so that a trailing partial warp (blockDim.x not a multiple of
  // WARP_SIZE) is still counted. Integer arithmetic gives the same predicate
  // as the former float division without the int -> float conversions.
  const unsigned int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
  val = (threadIdx.x < num_warps) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);
  return val;
}

Woosuk Kwon's avatar
Woosuk Kwon committed
63
} // namespace vllm