#pragma once

/**
 * __device__ layernorm utilities.
 */

#include "quantization/vectorization.cuh"
#include "quantization/utils.cuh"
#include "quant_conversions.cuh"

#include "../../cub_helpers.h"
#include "../../cuda_compat.h"

namespace vllm {

// Computes 1.0/rms for the row handled by this block (row index blockIdx.x,
// hidden_size elements): rsqrtf(sum(x^2) / hidden_size + epsilon), where x is
// the input element, plus the matching residual element when has_residual is
// set. The result is written to *rms by every thread of the block.
//
// has_residual must be true, if residual is not a nullptr.
//
// Every thread of the block must call this function: it performs a block-wide
// reduction followed by a __syncthreads(). BlockReduce is instantiated for
// 1024 threads and is given blockDim.x as the number of valid inputs.
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                            int32_t const hidden_size, float const epsilon,
                            scalar_t const* __restrict__ residual = nullptr) {
  // 64-bit row offset so that blockIdx.x * hidden_size cannot overflow 32 bits.
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
  // sum of squares
  float ss = 0.0f;

  // Each thread accumulates the elements threadIdx.x, threadIdx.x + blockDim.x,
  // ... of the row.
  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
    if constexpr (has_residual) {
      x += static_cast<float>(residual[token_offset + i]);
    }

    ss += x * x;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);

  // Thread 0 computes the result and publishes it to the rest of the block
  // through shared memory.
  __shared__ float s_rms;
  if (threadIdx.x == 0) {
    s_rms = rsqrtf(ss / hidden_size + epsilon);
  }
  __syncthreads();

  *rms = s_rms;
}

// Max-reduction over a window of at most WARP_SIZE floats held in `val`
// (callers in this file pass a shared-memory array).
//
//   val            - buffer being reduced in place.
//   tid            - index in `val` owned by the calling lane.
//   thread_in_warp - position of the calling lane inside the window.
//   reduced_elems  - number of valid elements in the window.
//
// For each offset k in {64 (only when WARP_SIZE == 64), 32, 16, 8, 4, 2, 1} a
// lane with thread_in_warp + k < reduced_elems folds val[tid + k] into
// val[tid]. The lane with thread_in_warp == 0 therefore ends up with the
// maximum of the window. Returns the calling lane's final val[tid].
//
// The final val[tid] read is unconditional and volatile, so every lane that
// calls this function must pass a `tid` that lies inside `val`, even a lane
// for which all of the conditions above are false.
//
// NOTE(review): there is no __syncwarp() between the steps; the code relies on
// the volatile accesses and on the lanes of a warp advancing together. Confirm
// this is acceptable on targets with independent thread scheduling.
//
// Declared inline because this is a non-template function defined in a header.
__device__ inline float warpReduceMaxSpecialized(volatile float* val,
                                                 int64_t tid,
                                                 int64_t thread_in_warp,
                                                 int64_t reduced_elems) {
  static_assert(WARP_SIZE == 32 || WARP_SIZE == 64);
  if constexpr (WARP_SIZE == 64) {
    if (thread_in_warp + 64 < reduced_elems)
      val[tid] = fmaxf(val[tid], val[tid + 64]);
  }
  if (thread_in_warp + 32 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 32]);
  if (thread_in_warp + 16 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 16]);
  if (thread_in_warp + 8 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 8]);
  if (thread_in_warp + 4 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 4]);
  if (thread_in_warp + 2 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 2]);
  if (thread_in_warp + 1 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 1]);
  return val[tid];
}

// Computes dynamic quantization scales for the row handled by this block (row
// index blockIdx.x) from the absolute maximum of the normalized, weighted
// values static_cast<scalar_t>(x * rms) * weight[i], where x is the input
// element plus the residual element when has_residual is set.
//
// In both modes:
//   scale = max(min(absmax, *scale_ub) / qmax, min_scaling_factor)
// with the min() against *scale_ub applied only when scale_ub is non-null.
//
// group_size == 0 (per-token):
//   One scale for the whole row, written to all_token_scales[blockIdx.x] and,
//   from every thread, to *token_scale.
//
// group_size > 0 (per-group):
//   One scale per run of group_size consecutive elements. The scales are
//   written only to all_token_scales; *token_scale is not touched.
//   Layout: [blockIdx.x * num_groups + group], or, when is_scale_transposed,
//   [group * scale_rows + blockIdx.x] where scale_rows is gridDim.x rounded up
//   to a multiple of outer_scale_stride.
//   The block's threads are split evenly across the groups
//   (threads_per_group = blockDim.x / num_groups), and the per-thread maxima
//   are reduced through a 1024-entry shared array.
//   NOTE(review): this path assumes blockDim.x <= 1024 and
//   blockDim.x >= hidden_size / group_size (otherwise threads_per_group is 0)
//   - confirm at the launch site.
//
// Every thread of the block must call this function (it uses __syncthreads()
// and, in the per-token path, a block-wide reduction).
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
          bool is_scale_transposed = false>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
    int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
  float block_absmax_val_maybe = 0.0f;
  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
  __syncthreads();
  if (group_size > 0) {
    __shared__ float s_max_vals[1024];
    int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
    int64_t num_groups = hidden_size / group_size;
    int64_t const threads_per_group = blockDim.x / num_groups;
    int64_t const thread_in_group = threadIdx.x % threads_per_group;
    int64_t const group_offset = threadIdx.x / threads_per_group * group_size;
    int64_t const thread_offset = group_offset + thread_in_group;
    int64_t const thread_end =
        min(group_offset + group_size, static_cast<int64_t>(hidden_size));
    // Per-thread abs-max over this thread's share of its group.
    for (auto i = thread_offset; i < thread_end; i += threads_per_group) {
      float x = static_cast<float>(input[token_offset + i]);
      if constexpr (has_residual) {
        x += static_cast<float>(residual[token_offset + i]);
      }
      x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
      block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
    }
    s_max_vals[threadIdx.x] = block_absmax_val_maybe;
    __syncthreads();

    // Each warp reduces the s_max_vals window of one group at a time; groups
    // are assigned to warps round-robin.
    int64_t const warp_size = WARP_SIZE;
    int64_t const num_warps = blockDim.x / warp_size;
    int64_t const warp_id = threadIdx.x / warp_size;
    int64_t const thread_in_warp = threadIdx.x % warp_size;
    int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
    for (auto i = 0; i < groups_per_warp; ++i) {
      int64_t const group_id = i * num_warps + warp_id;
      if (group_id < num_groups) {
        int64_t warp_start = group_id * threads_per_group;
        int64_t const start = warp_start + thread_in_warp;
        int64_t const warp_end = min(warp_start + threads_per_group,
                                     static_cast<int64_t>(hidden_size));
        // Fold windows wider than a warp down to the first warp_size entries.
        for (auto j = start; j + warp_size < warp_end; j += warp_size) {
          s_max_vals[start] =
              fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
        }
        int64_t const reduce_count = min(warp_end - warp_start, warp_size);
        // Lanes at or past reduce_count contribute nothing to the reduction,
        // but the helper still reads s_max_vals[start] on return, and for
        // those lanes `start` can lie beyond the end of s_max_vals. Keep them
        // out of the call.
        if (thread_in_warp < reduce_count) {
          warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
                                   reduce_count);
        }
      }
    }
    __syncthreads();

    // The first thread of each group holds that group's maximum.
    if (thread_in_group == 0 && thread_offset < thread_end) {
      block_absmax_val_maybe = s_max_vals[threadIdx.x];
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      // Global output store
      if constexpr (is_scale_transposed) {
        int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
                                   outer_scale_stride * outer_scale_stride;
        all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
                         blockIdx.x] = scale;
      } else {
        all_token_scales[blockIdx.x * num_groups +
                         threadIdx.x / threads_per_group] = scale;
      }
    }
    __syncthreads();
  } else {
    int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

    for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
      float x = static_cast<float>(input[token_offset + i]);
      if constexpr (has_residual) {
        x += static_cast<float>(residual[token_offset + i]);
      }

      x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
      block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
    }
    using BlockReduce = cub::BlockReduce<float, 1024>;
    __shared__ typename BlockReduce::TempStorage reduceStore;
    block_absmax_val_maybe =
        BlockReduce(reduceStore)
            .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

    __shared__ float s_token_scale;
    if (threadIdx.x == 0) {
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      s_token_scale = scale;                 // Shared memory store
      all_token_scales[blockIdx.x] = scale;  // Global output store
    }
    __syncthreads();

    *token_scale = s_token_scale;
  }
}

// Normalizes and quantizes the row handled by this block (row index
// blockIdx.x, hidden_size elements).
//
// For each element: x = input (+ residual when has_residual, in which case the
// sum is also written back to residual), then
//   x = static_cast<scalar_t>(x * rms) * weight[i]
// and output = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, s).
//
// The scale s is:
//   group_size == 0: *scale, passed through unchanged.
//   group_size > 0 : scale[scale_idx], replaced by its reciprocal when
//                    is_scale_inverted is set. scale_idx is
//                    blockIdx.x * (hidden_size / group_size) + i / group_size,
//                    or, when is_scale_transposed,
//                    (i / group_size) * scale_rows + blockIdx.x with
//                    scale_rows = gridDim.x rounded up to a multiple of
//                    outer_scale_stride.
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false, bool is_scale_transposed = false>
__device__ void norm_and_quant(
    scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
    scalar_t const* __restrict__ weight, float const rms, float* const scale,
    int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr,
    int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
    if constexpr (has_residual) {
      x += static_cast<float>(residual[token_offset + i]);
      residual[token_offset + i] = static_cast<scalar_t>(x);
    }
    // Norm
    x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
    // Quant
    // In the group-wise case the reciprocal of the stored scale is taken here
    // when is_scale_inverted is set; the per-token *scale is passed through
    // unchanged.
    int64_t scale_idx = 0;
    if (group_size > 0) {
      if constexpr (is_scale_transposed) {
        int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
                                   outer_scale_stride * outer_scale_stride;
        scale_idx = (i / group_size) * scale_rows + blockIdx.x;
      } else {
        scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size;
      }
    }
    auto scale_val =
        (group_size > 0
             ? (is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx])
             : *scale);
    output[token_offset + i] =
        ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale_val);
  }
}

namespace vectorized {

// Compute 1.0/rms(input)
// hidden_size must be a multiple of 4
//
// Vectorized counterpart of vllm::compute_rms: the row handled by this block
// (row index blockIdx.x) is read four elements at a time through vec4_t, the
// residual row is added when has_residual is set, and
// rsqrtf(sum(x^2) / hidden_size + epsilon) is written to *rms by every thread.
//
// Every thread of the block must call this function: it performs a block-wide
// reduction followed by a __syncthreads(). BlockReduce is instantiated for
// 1024 threads and is given blockDim.x as the number of valid inputs.
// NOTE(review): the vec4_t reinterpret_casts presumably require suitably
// aligned input/residual rows - confirm against the vec4_t definition.
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                            int32_t const hidden_size, float const epsilon,
                            scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
  vec4_t<scalar_t> const* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual =
        reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
  }

  // sum of squares
  float ss = 0.0f;

  const int VEC_SIZE = 4;
  // Number of 4-element vectors in the row.
  int32_t const num_vec_elems = hidden_size >> 2;

#pragma unroll 4
  for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4_t<scalar_t> in = vec_input[i];

    vec4_t<float> x;
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      x.val[j] = static_cast<float>(in.val[j]);
    }

    if constexpr (has_residual) {
      vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        x.val[j] += static_cast<float>(r.val[j]);
      }
    }

#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      ss += x.val[j] * x.val[j];
    }
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);

  // Thread 0 computes the result and publishes it to the rest of the block
  // through shared memory.
  __shared__ float s_rms;
  if (threadIdx.x == 0) {
    s_rms = rsqrtf(ss / hidden_size + epsilon);
  }
  __syncthreads();

  *rms = s_rms;
}

// Vectorized version of vllm::compute_dynamic_per_token_scales
// hidden_size must be a multiple of 4
//
// Computes dynamic quantization scales for the row handled by this block (row
// index blockIdx.x) from the absolute maximum of
// static_cast<scalar_t>(x * rms) * w, reading input, weight and residual four
// elements at a time through vec4_t. group_size is a template parameter here.
//
// In both modes:
//   scale = max(min(absmax, *scale_ub) / qmax, min_scaling_factor)
// with the min() against *scale_ub applied only when scale_ub is non-null.
//
// group_size == 0 (per-token):
//   One scale for the whole row, written to all_token_scales[blockIdx.x] and,
//   from every thread, to *token_scale.
//
// group_size > 0 (per-group):
//   One scale per run of group_size consecutive elements (group_size >> 2
//   vectors). The scales are written only to all_token_scales; *token_scale is
//   not touched.
//   Layout: [blockIdx.x * num_groups + group], or, when is_scale_transposed,
//   [group * scale_rows + blockIdx.x] where scale_rows is gridDim.x rounded up
//   to a multiple of outer_scale_stride.
//   NOTE(review): this path assumes blockDim.x <= 1024 (size of s_max_vals)
//   and blockDim.x >= hidden_size / group_size (otherwise threads_per_group is
//   0) - confirm at the launch site.
//
// Every thread of the block must call this function (it uses __syncthreads()
// and, in the per-token path, a block-wide reduction).
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
          bool is_scale_transposed = false, int32_t group_size = 0>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
    int64_t outer_scale_stride = 1) {
  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};

  const int VEC_SIZE = 4;
  float block_absmax_val_maybe = 0.0f;

  // Vectorized input/weight/residual to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input = nullptr;
  vec4_t<scalar_t> const* vec_weight = nullptr;
  vec4_t<scalar_t> const* vec_residual = nullptr;

  if constexpr (group_size > 0) {
    __shared__ float s_max_vals[1024];

    int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
    int64_t const num_groups = hidden_size / group_size;
    int64_t const threads_per_group = blockDim.x / num_groups;
    int64_t const thread_in_group = threadIdx.x % threads_per_group;
    // Offsets below are in units of 4-element vectors.
    int64_t const group_offset =
        threadIdx.x / threads_per_group * (group_size >> 2);
    int64_t const thread_offset = group_offset + thread_in_group;
    int64_t const thread_end = min(group_offset + (group_size >> 2),
                                   static_cast<int64_t>(hidden_size >> 2));
    vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
    vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
    if constexpr (has_residual) {
      vec_residual =
          reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
    }
    int32_t const num_vec_elems = thread_end;

    // Per-thread abs-max over this thread's share of its group.
#pragma unroll 4
    for (auto i = thread_offset; i < num_vec_elems; i += threads_per_group) {
      vec4_t<scalar_t> in = vec_input[i];
      vec4_t<scalar_t> const w = vec_weight[i];

      vec4_t<float> x;
#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        x.val[j] = static_cast<float>(in.val[j]);
      }

      if constexpr (has_residual) {
        vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
        for (int j = 0; j < VEC_SIZE; ++j) {
          x.val[j] += static_cast<float>(r.val[j]);
        }
      }

#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        block_absmax_val_maybe =
            fmaxf(block_absmax_val_maybe,
                  fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
      }
    }

    s_max_vals[threadIdx.x] = block_absmax_val_maybe;
    __syncthreads();

    // Each warp reduces the s_max_vals window of one group at a time; groups
    // are assigned to warps round-robin.
    int64_t const warp_size = WARP_SIZE;
    int64_t const num_warps = blockDim.x / warp_size;
    int64_t const warp_id = threadIdx.x / warp_size;
    int64_t const thread_in_warp = threadIdx.x % warp_size;
    int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
    for (auto i = 0; i < groups_per_warp; ++i) {
      int64_t const group_id = i * num_warps + warp_id;
      if (group_id < num_groups) {
        int64_t warp_start = group_id * threads_per_group;
        int64_t const start = warp_start + thread_in_warp;
        int64_t const warp_end = min(warp_start + threads_per_group,
                                     static_cast<int64_t>(hidden_size));
        // Fold windows wider than a warp down to the first warp_size entries.
        for (auto j = start; j + warp_size < warp_end; j += warp_size) {
          s_max_vals[start] =
              fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
        }
        int64_t const reduce_count = min(warp_end - warp_start, warp_size);
        // Lanes at or past reduce_count contribute nothing to the reduction,
        // but the helper still reads s_max_vals[start] on return, and for
        // those lanes `start` can lie beyond the end of s_max_vals. Keep them
        // out of the call.
        if (thread_in_warp < reduce_count) {
          warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
                                   reduce_count);
        }
      }
    }
    __syncthreads();

    // The first thread of each group holds that group's maximum.
    if (thread_in_group == 0 && thread_offset < thread_end) {
      block_absmax_val_maybe = s_max_vals[threadIdx.x];
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      // Global output store
      if constexpr (is_scale_transposed) {
        int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
                                   outer_scale_stride * outer_scale_stride;
        all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
                         blockIdx.x] = scale;
      } else {
        all_token_scales[blockIdx.x * num_groups +
                         threadIdx.x / threads_per_group] = scale;
      }
    }
    __syncthreads();

  } else {
    int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
    vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
    vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
    if constexpr (has_residual) {
      vec_residual =
          reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
    }

    int32_t const num_vec_elems = (hidden_size >> 2);

#pragma unroll 4
    for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
      vec4_t<scalar_t> in = vec_input[i];
      vec4_t<scalar_t> const w = vec_weight[i];

      vec4_t<float> x;
#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        x.val[j] = static_cast<float>(in.val[j]);
      }

      if constexpr (has_residual) {
        vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
        for (int j = 0; j < VEC_SIZE; ++j) {
          x.val[j] += static_cast<float>(r.val[j]);
        }
      }

#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        block_absmax_val_maybe =
            fmaxf(block_absmax_val_maybe,
                  fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
      }
    }

    using BlockReduce = cub::BlockReduce<float, 1024>;
    __shared__ typename BlockReduce::TempStorage reduceStore;
    block_absmax_val_maybe =
        BlockReduce(reduceStore)
            .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

    __shared__ float s_token_scale;
    if (threadIdx.x == 0) {
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      s_token_scale = scale;                 // shared memory store
      all_token_scales[blockIdx.x] = scale;  // global output store
    }
    __syncthreads();

    *token_scale = s_token_scale;
  }
}

// hidden_size must be a multiple of 4
//
// Vectorized norm + quant of the row handled by this block (row index
// blockIdx.x). Input, weight and residual are read through vec4_t and the
// output is written through q8x4_t, four elements per loop iteration.
//
// For each element: x = input (+ residual when has_residual, in which case the
// sum is also written back to residual), and
//   output = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
//       static_cast<scalar_t>(x * rms) * w, s).
//
// The scale s is:
//   group_size == 0: *scale, passed through unchanged.
//   group_size > 0 : scale[scale_idx], replaced by its reciprocal when
//                    is_scale_inverted is set. With e = i * VEC_SIZE (index of
//                    the first element of vector i), scale_idx is
//                    blockIdx.x * num_groups + e / group_size, or, when
//                    is_scale_transposed,
//                    (e / group_size) * scale_rows + blockIdx.x with
//                    scale_rows = gridDim.x rounded up to a multiple of
//                    outer_scale_stride.
// One scale is looked up per vector, i.e. all four elements of a vector share
// it.
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false, bool is_scale_transposed = false,
          int32_t group_size = 0>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                               scalar_t const* __restrict__ input,
                               scalar_t const* __restrict__ weight,
                               float const rms, float* const scale,
                               int32_t const hidden_size,
                               scalar_t* __restrict__ residual = nullptr,
                               int64_t outer_scale_stride = 1) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  // Vectorized input/output/weight/residual to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
  vec4_t<scalar_t> const* vec_weight =
      reinterpret_cast<vec4_t<scalar_t> const*>(weight);
  q8x4_t<scalar_out_t>* vec_output =
      reinterpret_cast<q8x4_t<scalar_out_t>*>(&output[token_offset]);
  vec4_t<scalar_t>* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual = reinterpret_cast<vec4_t<scalar_t>*>(&residual[token_offset]);
  }

  const int VEC_SIZE = 4;
  // Number of 4-element vectors in the row.
  int32_t const num_vec_elems = hidden_size >> 2;

// TODO(luka/varun) extract into type-agnostic vectorized quant function to
//  replace scaled_fp8_conversion_vec
#pragma unroll 4
  for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4_t<scalar_t> const in = vec_input[i];
    vec4_t<scalar_t> const w = vec_weight[i];

    vec4_t<float> x;
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      x.val[j] = static_cast<float>(in.val[j]);
    }

    if constexpr (has_residual) {
      vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        x.val[j] += static_cast<float>(r.val[j]);
      }
// Update residual
#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        r.val[j] = static_cast<scalar_t>(x.val[j]);
      }
      vec_residual[i] = r;
    }

    q8x4_t<scalar_out_t> out;

    float scale_val;

    if constexpr (group_size > 0) {
      int64_t const num_groups = hidden_size / group_size;
      int64_t scale_idx = 0;
      if constexpr (is_scale_transposed) {
        int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
                                   outer_scale_stride * outer_scale_stride;
        scale_idx = (i * VEC_SIZE / group_size) * scale_rows + blockIdx.x;
      } else {
        scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size;
      }
      // Group-wise scales are inverted here when is_scale_inverted is set.
      scale_val =
          is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx];
    } else {
      scale_val = *scale;
    }
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      out.val[j] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
          static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale_val);
    }
    vec_output[i] = out;
  }
}

}  // namespace vectorized

}  // namespace vllm