// norm.cpp: CPU normalization kernels (L2 norm, RMSNorm, fused residual-add + RMSNorm).
#include "common.h"
#include "vec.h"

namespace {

// NB: avoid using `at::vec::map<>` on bfloat16 or half; the kernels below
// explicitly widen to float vectors (at::vec::convert_to_float), compute in
// float, and narrow back only when storing.
// Llama4TextL2Norm
template <typename scalar_t>
void l2norm_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    int64_t batch_size,
    int64_t hidden_size,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int kVecSize = bVec::size();
  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      scalar_t* __restrict__ out_ptr = output + i * hidden_size;
      const scalar_t* __restrict__ input_ptr = input + i * hidden_size;

      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);

      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        sum_val += x_val * x_val;
      }

      sum_val += vec_reduce_sum(sum_fvec);
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);

#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        x_fvec0 = x_fvec0 * scale_fvec;
        x_fvec1 = x_fvec1 * scale_fvec;

        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(out_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
      }
    }
  });
}
68
69
70
71
72
73
74
template <typename scalar_t>
void rmsnorm_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ weight,
    int64_t batch_size,
    int64_t hidden_size,
75
    int64_t input_strideN,
76
77
78
79
80
81
82
83
84
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int kVecSize = bVec::size();
  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      scalar_t* __restrict__ out_ptr = output + i * hidden_size;
85
      const scalar_t* __restrict__ input_ptr = input + i * input_strideN;
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143

      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);

      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        sum_val += x_val * x_val;
      }

      sum_val += vec_reduce_sum(sum_fvec);
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);

#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        bVec w_bvec = bVec::loadu(weight + d);
        fVec w_fvec0, w_fvec1;
        std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);

        x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
        x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;

        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(out_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float w_val = static_cast<float>(weight[d]);
        out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var * w_val);
      }
    }
  });
}

// Fused residual add + RMSNorm, performed in place. For every row r:
//   s              = input[r, :] + residual[r, :]        (float32)
//   residual[r, :] = s                                   (rounded to scalar_t)
//   input[r, :]    = s / sqrt(mean(s^2) + eps) * weight[:]
// The float32 sum `s` (before rounding) is kept in a per-thread scratch row of
// `buffer`, so the normalization pass uses the unrounded values.
//
// Layout:
//   - row r of `input` starts at input + r * input_strideN (row-strided view
//     allowed, elements inside a row are contiguous).
//   - row r of `residual` starts at residual + r * hidden_size (dense rows).
//   - `weight` holds hidden_size contiguous elements shared by every row.
//   - `buffer` is indexed as buffer[tid * hidden_size + d] with
//     tid = at::get_thread_num(). NOTE(review): the caller must allocate at
//     least (max tid + 1) * hidden_size floats.
template <typename scalar_t>
void fused_add_rmsnorm_kernel_impl(
    scalar_t* __restrict__ input,
    scalar_t* __restrict__ residual,
    const scalar_t* __restrict__ weight,
    float* __restrict__ buffer,
    int64_t batch_size,
    int64_t hidden_size,
    int64_t input_strideN,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  // One bVec of scalar_t widens into two fVec of float (convert_to_float).
  constexpr int kVecSize = bVec::size();
  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    // Each thread owns one scratch row of `buffer`.
    int tid = at::get_thread_num();
    float* __restrict__ buffer_ptr = buffer + tid * hidden_size;

    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      scalar_t* __restrict__ input_ptr = input + i * input_strideN;
      scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;

      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);

      // Pass 1: s = x + r; write s back to residual, stash float s in the
      // scratch row, and accumulate sum(s^2).
      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        bVec r_bvec = bVec::loadu(residual_ptr + d);
        fVec r_fvec0, r_fvec1;
        std::tie(r_fvec0, r_fvec1) = at::vec::convert_to_float(r_bvec);

        x_fvec0 += r_fvec0;
        x_fvec1 += r_fvec1;

        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(residual_ptr + d);

        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;

        x_fvec0.store(buffer_ptr + d);
        x_fvec1.store(buffer_ptr + d + fVec::size());
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float r_val = static_cast<float>(residual_ptr[d]);

        x_val += r_val;
        residual_ptr[d] = static_cast<scalar_t>(x_val);

        sum_val += x_val * x_val;
        buffer_ptr[d] = x_val;
      }

      sum_val += vec_reduce_sum(sum_fvec);
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);

      // Pass 2: input = s * rsqrt(mean(s^2) + eps) * weight, from the float
      // scratch row.
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        fVec x_fvec0 = fVec::loadu(buffer_ptr + d);
        fVec x_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size());

        bVec w_bvec = bVec::loadu(weight + d);
        fVec w_fvec0, w_fvec1;
        std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);

        x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
        x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
        bVec x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        x_bvec.store(input_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = buffer_ptr[d] * rsqrt_var * static_cast<float>(weight[d]);
        // Explicit narrowing, consistent with every other scalar store here.
        input_ptr[d] = static_cast<scalar_t>(x_val);
      }
    }
  });
}

}  // anonymous namespace

// input : {batch_size, hidden_size}
// returns: a new tensor with the same shape and dtype as input, holding
//          Llama4TextL2Norm(input)
at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
  RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));

  // Validate a 2-D {batch_size, hidden_size} input (CHECK_INPUT is the
  // project's standard input check).
  CHECK_INPUT(input);
  CHECK_DIM(2, input);

  const int64_t num_rows = input.size(0);
  const int64_t row_len = input.size(1);
  at::Tensor result = at::empty_like(input);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
    scalar_t* dst = result.data_ptr<scalar_t>();
    const scalar_t* src = input.data_ptr<scalar_t>();
    l2norm_kernel_impl<scalar_t>(dst, src, num_rows, row_len, eps);
  });
  return result;
}

// input : {batch_size, hidden_size}
243
244
245
246
// weight: {hidden_size}
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
  RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));

247
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
248
249
250
251
252
253
254
  CHECK_INPUT(weight);
  CHECK_DIM(2, input);
  CHECK_DIM(1, weight);
  CHECK_EQ(input.size(1), weight.size(0));
  int64_t batch_size = input.size(0);
  int64_t hidden_size = input.size(1);
  at::Tensor output = at::empty_like(input);
255
  int64_t input_strideN = input.stride(0);
256
257
258
259
260
261
262
263

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
    rmsnorm_kernel_impl<scalar_t>(
        output.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        batch_size,
        hidden_size,
264
        input_strideN,
265
266
267
268
269
270
271
272
273
274
        eps);
  });
  return output;
}

// In-place fused residual add + RMSNorm.
// input   : {batch_size, hidden_size}; last dim contiguous, rows may be strided
//           (input.stride(0) is forwarded). Overwritten with
//           RMSNorm(input + residual) * weight.
// residual: {batch_size, hidden_size}; overwritten with input + residual.
// weight  : {hidden_size}
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
  RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_INPUT(residual);
  CHECK_INPUT(weight);
  CHECK_DIM(2, input);
  CHECK_DIM(2, residual);
  CHECK_DIM(1, weight);
  CHECK_EQ(input.size(0), residual.size(0));
  CHECK_EQ(input.size(1), residual.size(1));
  CHECK_EQ(input.size(1), weight.size(0));
  int64_t batch_size = input.size(0);
  int64_t hidden_size = input.size(1);
  int64_t input_strideN = input.stride(0);

  // allocate temp buffer to store x in float32 per thread
  // (one row of hidden_size floats per thread, indexed by at::get_thread_num()
  // inside the kernel)
  // TODO: implement a singleton for context
  int64_t num_threads = at::get_num_threads();
  at::Tensor buffer = at::empty({num_threads, hidden_size}, input.options().dtype(at::kFloat));

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_add_rmsnorm_kernel", [&] {
    fused_add_rmsnorm_kernel_impl<scalar_t>(
        input.data_ptr<scalar_t>(),
        residual.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        buffer.data_ptr<float>(),
        batch_size,
        hidden_size,
        input_strideN,
        eps);
  });
}