layernorm_kernels.cu 8.83 KB
Newer Older
1
2
3
4
#include "type_convert.cuh"
#include "dispatch_utils.h"

#include <torch/cuda.h>
5
#include <c10/cuda/CUDAGuard.h>
zhuwenwen's avatar
zhuwenwen committed
6

7
#ifndef USE_ROCM
8
  #include <cub/cub.cuh>
9
#else
10
  #include <hipcub/hipcub.hpp>
11
#endif
12

Woosuk Kwon's avatar
Woosuk Kwon committed
13
namespace vllm {
14
15

// TODO(woosuk): Further optimize this kernel.
/* RMS normalization of one row per thread block:
     out[row, i] = scalar_t(input[row, i] * rsqrt(mean(input[row, :]^2) + eps))
                   * weight[i]

   Layout: blockIdx.x selects the row; the threads of the block stride over
   the hidden dimension in steps of blockDim.x.
   - out is indexed as a dense [row, hidden_size] array.
   - input rows start input_stride elements apart; elements inside a row are
     read with unit stride.
   - The block-wide sum uses a cub::BlockReduce instantiated for 1024 threads
     and is told that only the first blockDim.x partial sums are valid.
     NOTE(review): presumably this requires blockDim.x <= 1024 -- confirm
     against the launch site.
   num_tokens is accepted for interface symmetry but is not read here. */
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,          // [..., hidden_size]
    const scalar_t* __restrict__ input,  // [..., hidden_size]
    const int64_t input_stride,
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row offsets are computed in 64 bits. `blockIdx.x * hidden_size` on its own
  // is a 32-bit unsigned product and wraps once the tensor holds 2^32 or more
  // elements.
  const int64_t in_row = static_cast<int64_t>(blockIdx.x) * input_stride;
  const int64_t out_row = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: each thread accumulates the squares of its share of the row.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[in_row + idx];
    variance += x * x;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  // Only thread 0 consumes the reduced sum; it publishes the reciprocal RMS
  // through shared memory for the rest of the block.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: scale in float, round to scalar_t, then apply the weight in
  // scalar_t arithmetic.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[in_row + idx];
    out[out_row + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

47
48
49
50
/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck.

   In-place fused residual add + RMS norm, one row per thread block
   (blockIdx.x selects the row):
     residual[row, :] <- input[row, :] + residual[row, :]
     input[row, :]    <- residual[row, :] * rsqrt(mean(residual^2) + eps)
                         * weight
   All three pointers are reinterpreted as arrays of
   _f16Vec<scalar_t, width>, and hidden_size / input_stride are divided by
   width with integer division. The caller is therefore expected to pass
   suitably aligned pointers and sizes that are multiples of width; any
   remainder would be silently dropped here.
   num_tokens is accepted for interface symmetry but is not read here. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int64_t input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  const int64_t vec_input_stride = input_stride / width;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  // Row bases in vector units, computed in 64 bits. The residual index used
  // to be `int id = blockIdx.x * vec_hidden_size + idx`, a 32-bit unsigned
  // product narrowed to int, which wraps for very large tensors.
  const int64_t residual_row =
      static_cast<int64_t>(blockIdx.x) * vec_hidden_size;
  const int64_t input_row = static_cast<int64_t>(blockIdx.x) * vec_input_stride;

  // Pass 1: add the residual, store the sum back into residual, and
  // accumulate the sum of squares of the summed values.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = residual_row + idx;
    const int64_t strided_id = input_row + idx;
    _f16Vec<scalar_t, width> temp = input_v[strided_id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  // Thread 0 publishes the reciprocal RMS to the block via shared memory.
  // The mean is taken over hidden_size scalar elements, not vectors.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: re-read the summed values from residual, scale and weight them,
  // and overwrite input with the normalized result.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = residual_row + idx;
    const int64_t strided_id = input_row + idx;
    _f16Vec<scalar_t, width> temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
    input_v[strided_id] = temp;
  }
}

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.

   In-place fused residual add + RMS norm, one row per thread block
   (blockIdx.x selects the row), processed one scalar element at a time:
     residual[row, i] <- input[row, i] + residual[row, i]   (in scalar_t)
     input[row, i]    <- scalar_t(residual[row, i] * rsqrt(mean + eps))
                         * weight[i]
   residual is indexed as a dense [row, hidden_size] array; input rows start
   input_stride elements apart.
   num_tokens is accepted for interface symmetry but is not read here. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int64_t input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row offsets are computed in 64 bits. `blockIdx.x * hidden_size` on its own
  // is a 32-bit unsigned product and wraps once the tensor holds 2^32 or more
  // elements.
  const int64_t input_row = static_cast<int64_t>(blockIdx.x) * input_stride;
  const int64_t residual_row = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: the add happens in scalar_t; the rounded sum is what gets stored
  // in residual and what feeds the sum of squares.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[input_row + idx];
    z += residual[residual_row + idx];
    float x = (float)z;
    variance += x * x;
    residual[residual_row + idx] = z;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  // Thread 0 publishes the reciprocal RMS to the block via shared memory.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: normalize the stored sums and overwrite input with the result.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[residual_row + idx];
    input[input_row + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

143
}  // namespace vllm
144

145
146
147
/* Host entry point for RMS normalization.

   Writes the normalized `input` into `out`, launching one thread block per
   row of `input` on the current CUDA stream of input's device.

   Requirements enforced here: `out` and `weight` are contiguous, and the
   innermost dimension of `input` has unit stride. Rows of `input` may be
   strided (input.stride(-2) is forwarded to the kernel).
   An empty `input` is a no-op. */
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.stride(-1) == 1);
  TORCH_CHECK(weight.is_contiguous());

  // Nothing to do for empty tensors. Returning here also avoids an integer
  // division by zero when hidden_size == 0 and a launch with a zero-sized
  // grid or block, which CUDA rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  int64_t input_stride = input.stride(-2);

  // One block per row; the kernel's block reduction is sized for at most
  // 1024 threads.
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // The kernel takes epsilon as float; make the narrowing explicit.
  const float eps = static_cast<float>(epsilon);
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
        weight.data_ptr<scalar_t>(), eps, num_tokens, hidden_size);
  });
}
167

168
169
170
171
172
173
174
175
/* Launches vllm::fused_add_rms_norm_kernel<scalar_t, VEC_WIDTH>, with
   scalar_t chosen from input's floating-point dtype.

   The expansion refers to names that must exist at the point of use:
   input, residual, weight, epsilon, input_stride, num_tokens, hidden_size,
   grid, block and stream. */
#define LAUNCH_FUSED_ADD_RMS_NORM(VEC_WIDTH)                            \
  VLLM_DISPATCH_FLOATING_TYPES(                                         \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {           \
        vllm::fused_add_rms_norm_kernel<scalar_t, VEC_WIDTH>            \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),    \
                                         input_stride,                  \
                                         residual.data_ptr<scalar_t>(), \
                                         weight.data_ptr<scalar_t>(),   \
                                         epsilon, num_tokens,           \
                                         hidden_size);                  \
      });

/* Host entry point for the fused residual-add + RMS norm.

   Operates in place: after the call `residual` holds input + residual and
   `input` holds the normalized, weighted result. One thread block is launched
   per row of `input` on the current CUDA stream of input's device.

   Requirements enforced here: `residual` and `weight` are contiguous, and the
   innermost dimension of `input` has unit stride (the kernels index elements
   of a row with stride 1). Rows of `input` may be strided.
   An empty `input` is a no-op. */
void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        double epsilon) {
  // Same inner-stride requirement as rms_norm: without it the kernels would
  // silently read and write the wrong elements.
  TORCH_CHECK(input.stride(-1) == 1);
  TORCH_CHECK(residual.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());

  // Nothing to do for empty tensors. Returning here also avoids an integer
  // division by zero when hidden_size == 0 and a launch with a zero-sized
  // grid or block, which CUDA rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }

  // NOTE: the local names below are referenced by LAUNCH_FUSED_ADD_RMS_NORM.
  int hidden_size = input.size(-1);
  int64_t input_stride = input.stride(-2);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  constexpr int vector_width = 8;
  constexpr int req_alignment_bytes =
      vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
                         // falls back to non-vectorized version anyway)
  bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
                          res_ptr % req_alignment_bytes == 0 &&
                          wt_ptr % req_alignment_bytes == 0;
  // The vectorized kernel divides both the row length and the row stride by
  // the vector width, so both must be exact multiples of it.
  bool offsets_are_multiple_of_vector_width =
      hidden_size % vector_width == 0 && input_stride % vector_width == 0;
  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
    // The literal must stay equal to vector_width above.
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    // Width 0 selects the generic element-wise kernel.
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}