layernorm_utils.cuh 20.6 KB
Newer Older
1
2
3
4
5
6
#pragma once

/**
 * __device__ layernorm utilities.
 */

7
#include "libtorch_stable/quantization/vectorization.cuh"
8
#include "quantization/utils.cuh"
9
10
#include "quant_conversions.cuh"

Aidyn-A's avatar
Aidyn-A committed
11
#include "../../cub_helpers.h"
12
#include "../../cuda_compat.h"
13
14
15
16
17
18

namespace vllm {

// Computes rsqrt(mean(x^2) + epsilon) over the row owned by this block (row
// index = blockIdx.x) and hands the result to every thread through *rms.
//
// x is input, or input + residual when has_residual is set. has_residual must
// be true if residual is not a nullptr. Rows of `input` are input_stride
// elements apart; rows of `residual` are hidden_size elements apart.
//
// Every thread of the block must call this function: it performs a block-wide
// reduction and a __syncthreads(). The reduction is instantiated for 1024
// threads, with blockDim.x passed as the number of participating threads.
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                            int32_t const hidden_size,
                            int32_t const input_stride, float const epsilon,
                            scalar_t const* __restrict__ residual = nullptr) {
  int64_t const in_row_start = blockIdx.x * static_cast<int64_t>(input_stride);
  int64_t const row_start = blockIdx.x * static_cast<int64_t>(hidden_size);
  scalar_t const* const in_row = input + in_row_start;

  // Per-thread partial sum of squares; threads walk the row blockDim.x apart.
  float sum_sq = 0.0f;
  for (auto col = threadIdx.x; col < hidden_size; col += blockDim.x) {
    float v = static_cast<float>(in_row[col]);
    if constexpr (has_residual) {
      v += static_cast<float>(residual[row_start + col]);
    }
    sum_sq += v * v;
  }

  // Combine the partial sums across the block.
  using BlockSum = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockSum::TempStorage sum_storage;
  sum_sq = BlockSum(sum_storage).Reduce(sum_sq, CubAddOp{}, blockDim.x);

  // Thread 0 turns the total into 1/rms and publishes it through shared
  // memory; the barrier makes it visible to the rest of the block.
  __shared__ float inv_rms_shared;
  if (threadIdx.x == 0) {
    inv_rms_shared = rsqrtf(sum_sq / hidden_size + epsilon);
  }
  __syncthreads();

  *rms = inv_rms_shared;
}

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
/**
 * In-place max reduction over a slice of a scratch buffer, driven by one warp.
 *
 * Each calling lane owns slot val[tid]. For offsets WARP_SIZE, WARP_SIZE / 2,
 * ..., 1 (in that order) a lane folds val[tid + offset] into val[tid] whenever
 * thread_in_warp + offset < reduced_elems, so lanes never read past the first
 * reduced_elems slots of the slice.
 *
 * @param val            scratch buffer; volatile so each step re-reads memory.
 * @param tid            index of this lane's slot in val.
 * @param thread_in_warp position of this lane inside its warp.
 * @param reduced_elems  number of slots in the slice being reduced.
 * @return val[tid] after the last step.
 *
 * Declared inline because this is a non-template function defined in a header;
 * without it every translation unit that includes the header emits its own
 * external definition.
 *
 * NOTE(review): there is no __syncwarp() between steps, so the result relies
 * on the lanes of the warp advancing through the steps together - confirm this
 * holds on every targeted architecture.
 */
__device__ inline float warpReduceMaxSpecialized(volatile float* val,
                                                 int64_t tid,
                                                 int64_t thread_in_warp,
                                                 int64_t reduced_elems) {
  static_assert(WARP_SIZE == 32 || WARP_SIZE == 64);
  if constexpr (WARP_SIZE == 64) {
    if (thread_in_warp + 64 < reduced_elems)
      val[tid] = fmaxf(val[tid], val[tid + 64]);
  }
  if (thread_in_warp + 32 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 32]);
  if (thread_in_warp + 16 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 16]);
  if (thread_in_warp + 8 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 8]);
  if (thread_in_warp + 4 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 4]);
  if (thread_in_warp + 2 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 2]);
  if (thread_in_warp + 1 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 1]);
  return val[tid];
}

/**
 * Computes the dynamic quantization scale(s) for the row owned by this block
 * (row index = blockIdx.x).
 *
 * The value being measured is the normalized element
 *   static_cast<scalar_t>((input [+ residual]) * rms) * weight
 * and a scale is
 *   max(min(absmax, *scale_ub) / qmax, min_scaling_factor<scalar_out_t>)
 * where the min() with *scale_ub is applied only when scale_ub is non-null.
 *
 * group_size == 0: one scale for the whole row. It is stored in
 *   all_token_scales[blockIdx.x] and handed to every thread via *token_scale.
 * group_size > 0: one scale per run of group_size consecutive elements. The
 *   scales are stored only in all_token_scales; *token_scale is not written.
 *   Layout is [group][scale_rows] when is_scale_transposed (scale_rows is
 *   gridDim.x rounded up to a multiple of outer_scale_stride), otherwise
 *   [blockIdx.x][num_groups].
 *
 * Rows of `input` are input_stride elements apart; rows of `residual` are
 * hidden_size elements apart. `weight` is indexed by column only.
 *
 * Requirements visible from the code (none are checked):
 *  - every thread of the block calls this function (it uses __syncthreads());
 *  - blockDim.x <= 1024 (scratch array and block reduction are sized 1024);
 *  - in the grouped path, hidden_size / group_size must be between 1 and
 *    blockDim.x, otherwise threads_per_group is 0 and is used as a divisor.
 *  - assumes hidden_size is a multiple of group_size - TODO confirm.
 */
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
          bool is_scale_transposed = false>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    int32_t const hidden_size, int32_t const input_stride,
    scalar_t const* __restrict__ residual = nullptr,
    int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
  float block_absmax_val_maybe = 0.0f;
  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
  // Block-wide barrier kept from the original; presumably it orders this
  // function's shared-memory use after earlier block-wide work - TODO confirm.
  __syncthreads();

  int64_t const input_token_offset =
      blockIdx.x * static_cast<int64_t>(input_stride);
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  if (group_size > 0) {
    int64_t const num_groups = hidden_size / group_size;
    // One scratch slot per thread.
    __shared__ float s_max_vals[1024];

    // Threads are split evenly over the groups; a thread's group is
    // threadIdx.x / threads_per_group.
    int64_t const threads_per_group = blockDim.x / num_groups;
    int64_t const thread_in_group = threadIdx.x % threads_per_group;
    int64_t const group_offset = threadIdx.x / threads_per_group * group_size;
    int64_t const thread_offset = group_offset + thread_in_group;
    int64_t const thread_end =
        min(group_offset + group_size, static_cast<int64_t>(hidden_size));

    // Per-thread absmax over this thread's share of its group.
    for (auto i = thread_offset; i < thread_end; i += threads_per_group) {
      float x = static_cast<float>(input[input_token_offset + i]);
      if constexpr (has_residual) {
        x += static_cast<float>(residual[token_offset + i]);
      }
      x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
      block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
    }
    s_max_vals[threadIdx.x] = block_absmax_val_maybe;
    __syncthreads();

    // Each warp reduces the slots of one group at a time; groups are dealt out
    // to warps round-robin. After this loop the first slot of every group
    // holds that group's absmax.
    int64_t const warp_size = WARP_SIZE;
    int64_t const num_warps = blockDim.x / warp_size;
    int64_t const warp_id = threadIdx.x / warp_size;
    int64_t const thread_in_warp = threadIdx.x % warp_size;
    int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
    for (auto i = 0; i < groups_per_warp; ++i) {
      int64_t const group_id = i * num_warps + warp_id;
      if (group_id < num_groups) {
        int64_t warp_start = group_id * threads_per_group;
        int64_t const start = warp_start + thread_in_warp;
        // [warp_start, warp_end) is a range of thread slots in s_max_vals, so
        // it is bounded by the thread count. (It was previously clamped to
        // hidden_size, an element count, which dropped whole groups from the
        // reduction when blockDim.x > hidden_size.)
        int64_t const warp_end = min(warp_start + threads_per_group,
                                     static_cast<int64_t>(blockDim.x));
        // Fold slots beyond the first warp_size of the group onto the first
        // warp_size slots.
        for (auto j = start; j + warp_size < warp_end; j += warp_size) {
          s_max_vals[start] =
              fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
        }
        warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
                                 min(warp_end - warp_start, warp_size));
      }
    }
    __syncthreads();

    // The first thread of every real group publishes that group's scale. The
    // second condition filters out threads whose group index lies past the
    // end of the row.
    if (thread_in_group == 0 && thread_offset < thread_end) {
      block_absmax_val_maybe = s_max_vals[threadIdx.x];
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      // Global output store
      if constexpr (is_scale_transposed) {
        // Row count of the transposed scale matrix: gridDim.x rounded up to a
        // multiple of outer_scale_stride.
        int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
                                   outer_scale_stride * outer_scale_stride;
        all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
                         blockIdx.x] = scale;
      } else {
        all_token_scales[blockIdx.x * num_groups +
                         threadIdx.x / threads_per_group] = scale;
      }
    }
    __syncthreads();
  } else {
    // Per-thread absmax over the whole row, threads blockDim.x apart.
    for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
      float x = static_cast<float>(input[input_token_offset + i]);
      if constexpr (has_residual) {
        x += static_cast<float>(residual[token_offset + i]);
      }

      x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
      block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
    }

    using BlockReduce = cub::BlockReduce<float, 1024>;
    __shared__ typename BlockReduce::TempStorage reduceStore;
    block_absmax_val_maybe =
        BlockReduce(reduceStore)
            .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

    // Thread 0 derives the scale and publishes it; the barrier makes the
    // shared copy visible to every thread before it is read back.
    __shared__ float s_token_scale;
    if (threadIdx.x == 0) {
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      s_token_scale = scale;                 // Shared memory store
      all_token_scales[blockIdx.x] = scale;  // Global output store
    }
    __syncthreads();

    *token_scale = s_token_scale;
  }
}

/**
 * Normalizes and quantizes the row owned by this block (row = blockIdx.x).
 *
 * Per element: x = input [+ residual]; when has_residual is set the sum is
 * also written back into `residual`. The normalized value
 * static_cast<scalar_t>(x * rms) * weight is then quantized with
 * ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn and stored in
 * `output`.
 *
 * Scale selection:
 *  - group_size == 0: *scale is passed to quant_fn unchanged.
 *  - group_size > 0: the scale of the element's group is read from `scale`
 *    ([group][scale_rows] when is_scale_transposed, otherwise
 *    [blockIdx.x][hidden_size / group_size]); when is_scale_inverted is set
 *    its reciprocal is passed instead.
 *
 * Rows of `input` are input_stride elements apart; rows of `output` and
 * `residual` are hidden_size elements apart.
 */
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false, bool is_scale_transposed = false>
__device__ void norm_and_quant(
    scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
    scalar_t const* __restrict__ weight, float const rms, float* const scale,
    int32_t const hidden_size, int32_t const input_stride,
    scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0,
    int64_t outer_scale_stride = 1) {
  int64_t const in_row_start = blockIdx.x * static_cast<int64_t>(input_stride);
  int64_t const row_start = blockIdx.x * static_cast<int64_t>(hidden_size);
  bool const grouped = group_size > 0;

  for (auto col = threadIdx.x; col < hidden_size; col += blockDim.x) {
    float val = static_cast<float>(input[in_row_start + col]);
    if constexpr (has_residual) {
      // Fold in the residual and store the sum back for later use.
      val += static_cast<float>(residual[row_start + col]);
      residual[row_start + col] = static_cast<scalar_t>(val);
    }

    // Norm
    val = static_cast<float>(static_cast<scalar_t>(val * rms) * weight[col]);

    // Quant: pick the scale that applies to this element.
    float scale_val;
    if (grouped) {
      int64_t scale_idx = 0;
      if constexpr (is_scale_transposed) {
        // gridDim.x rounded up to a multiple of outer_scale_stride.
        int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
                                   outer_scale_stride * outer_scale_stride;
        scale_idx = (col / group_size) * scale_rows + blockIdx.x;
      } else {
        scale_idx =
            blockIdx.x * (hidden_size / group_size) + col / group_size;
      }
      // Group-wise scales are inverted here when is_scale_inverted is set.
      if (is_scale_inverted) {
        scale_val = 1.0f / scale[scale_idx];
      } else {
        scale_val = scale[scale_idx];
      }
    } else {
      scale_val = *scale;
    }

    output[row_start + col] =
        ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(val, scale_val);
  }
}

namespace vectorized {

// Vectorized reciprocal-RMS: computes rsqrt(mean(x^2) + epsilon) over the row
// owned by this block (row = blockIdx.x) and hands it to every thread through
// *rms. x is input, or input + residual when has_residual is set.
//
// hidden_size must be a multiple of 4: the row is read as hidden_size / 4
// vec4_t<scalar_t> items. Rows of `input` are input_stride elements apart and
// rows of `residual` are hidden_size elements apart.
// Every thread of the block must call this function (block reduction plus
// __syncthreads()).
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                            int32_t const hidden_size,
                            int32_t const input_stride, float const epsilon,
                            scalar_t const* __restrict__ residual = nullptr) {
  using vec_t = vec4_t<scalar_t>;
  constexpr int kVecWidth = 4;

  int64_t const in_row_start = blockIdx.x * static_cast<int64_t>(input_stride);
  int64_t const row_start = blockIdx.x * static_cast<int64_t>(hidden_size);

  // View the row(s) as packs of four elements to widen the memory accesses.
  vec_t const* const in_vecs =
      reinterpret_cast<vec_t const*>(&input[in_row_start]);
  vec_t const* res_vecs = nullptr;
  if constexpr (has_residual) {
    res_vecs = reinterpret_cast<vec_t const*>(&residual[row_start]);
  }

  int32_t const vecs_per_row = hidden_size >> 2;

  // Per-thread partial sum of squares.
  float sum_sq = 0.0f;

#pragma unroll 4
  for (auto v = threadIdx.x; v < vecs_per_row; v += blockDim.x) {
    vec_t const in_vec = in_vecs[v];

    float vals[kVecWidth];
#pragma unroll
    for (int lane = 0; lane < kVecWidth; ++lane) {
      vals[lane] = static_cast<float>(in_vec.val[lane]);
    }

    if constexpr (has_residual) {
      vec_t const res_vec = res_vecs[v];
#pragma unroll
      for (int lane = 0; lane < kVecWidth; ++lane) {
        vals[lane] += static_cast<float>(res_vec.val[lane]);
      }
    }

#pragma unroll
    for (int lane = 0; lane < kVecWidth; ++lane) {
      sum_sq += vals[lane] * vals[lane];
    }
  }

  // Combine the partial sums across the block.
  using BlockSum = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockSum::TempStorage sum_storage;
  sum_sq = BlockSum(sum_storage).Reduce(sum_sq, CubAddOp{}, blockDim.x);

  // Thread 0 converts the total to 1/rms; the barrier publishes it.
  __shared__ float inv_rms_shared;
  if (threadIdx.x == 0) {
    inv_rms_shared = rsqrtf(sum_sq / hidden_size + epsilon);
  }
  __syncthreads();

  *rms = inv_rms_shared;
}

// Vectorized version of vllm::compute_dynamic_per_token_scales
// hidden_size must be a multiple of 4
//
// Computes the dynamic quantization scale(s) for the row owned by this block
// (row = blockIdx.x). The measured value is
//   static_cast<scalar_t>((input [+ residual]) * rms) * weight
// and a scale is
//   max(min(absmax, *scale_ub) / qmax, min_scaling_factor<scalar_out_t>)
// where the min() with *scale_ub is applied only when scale_ub is non-null.
//
// group_size == 0: one scale for the row, stored in
//   all_token_scales[blockIdx.x] and handed to every thread via *token_scale.
// group_size > 0: one scale per run of group_size consecutive elements,
//   stored only in all_token_scales ([group][scale_rows] when
//   is_scale_transposed, with scale_rows = gridDim.x rounded up to a multiple
//   of outer_scale_stride; otherwise [blockIdx.x][num_groups]).
//   *token_scale is not written in this mode.
//
// Requirements visible from the code (none are checked):
//  - every thread of the block calls this function (it uses __syncthreads());
//  - blockDim.x <= 1024 (scratch array and block reduction are sized 1024);
//  - in the grouped path, hidden_size / group_size must be between 1 and
//    blockDim.x, otherwise threads_per_group is 0 and is used as a divisor.
//  - assumes group_size is a multiple of 4 and divides hidden_size - TODO
//    confirm.
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
          bool is_scale_transposed = false, int32_t group_size = 0>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    int32_t const hidden_size, int32_t const input_stride,
    scalar_t const* __restrict__ residual = nullptr,
    int64_t outer_scale_stride = 1) {
  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};

  const int VEC_SIZE = 4;
  float block_absmax_val_maybe = 0.0f;

  int64_t const input_token_offset =
      blockIdx.x * static_cast<int64_t>(input_stride);
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  // Vectorized input/weight/residual to better utilize memory bandwidth.
  // The setup is the same for the grouped and the per-row path.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
  vec4_t<scalar_t> const* vec_weight =
      reinterpret_cast<vec4_t<scalar_t> const*>(weight);
  vec4_t<scalar_t> const* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual =
        reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
  }

  if constexpr (group_size > 0) {
    // One scratch slot per thread.
    __shared__ float s_max_vals[1024];

    // Threads are split evenly over the groups; offsets below are counted in
    // vec4 items, hence the >> 2.
    int64_t const num_groups = hidden_size / group_size;
    int64_t const threads_per_group = blockDim.x / num_groups;
    int64_t const thread_in_group = threadIdx.x % threads_per_group;
    int64_t const group_offset =
        threadIdx.x / threads_per_group * (group_size >> 2);
    int64_t const thread_offset = group_offset + thread_in_group;
    int64_t const thread_end = min(group_offset + (group_size >> 2),
                                   static_cast<int64_t>(hidden_size >> 2));
    int32_t const num_vec_elems = thread_end;

    // Per-thread absmax over this thread's share of its group.
#pragma unroll 4
    for (auto i = thread_offset; i < num_vec_elems; i += threads_per_group) {
      vec4_t<scalar_t> in = vec_input[i];
      vec4_t<scalar_t> const w = vec_weight[i];

      vec4_t<float> x;
#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        x.val[j] = static_cast<float>(in.val[j]);
      }

      if constexpr (has_residual) {
        vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
        for (int j = 0; j < VEC_SIZE; ++j) {
          x.val[j] += static_cast<float>(r.val[j]);
        }
      }

#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        block_absmax_val_maybe =
            fmaxf(block_absmax_val_maybe,
                  fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
      }
    }

    s_max_vals[threadIdx.x] = block_absmax_val_maybe;
    __syncthreads();

    // Each warp reduces the slots of one group at a time; groups are dealt out
    // to warps round-robin. After this loop the first slot of every group
    // holds that group's absmax.
    int64_t const warp_size = WARP_SIZE;
    int64_t const num_warps = blockDim.x / warp_size;
    int64_t const warp_id = threadIdx.x / warp_size;
    int64_t const thread_in_warp = threadIdx.x % warp_size;
    int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
    for (auto i = 0; i < groups_per_warp; ++i) {
      int64_t const group_id = i * num_warps + warp_id;
      if (group_id < num_groups) {
        int64_t warp_start = group_id * threads_per_group;
        int64_t const start = warp_start + thread_in_warp;
        // [warp_start, warp_end) is a range of thread slots in s_max_vals, so
        // it is bounded by the thread count. (It was previously clamped to
        // hidden_size, an element count, which dropped whole groups from the
        // reduction when blockDim.x > hidden_size.)
        int64_t const warp_end = min(warp_start + threads_per_group,
                                     static_cast<int64_t>(blockDim.x));
        // Fold slots beyond the first warp_size of the group onto the first
        // warp_size slots.
        for (auto j = start; j + warp_size < warp_end; j += warp_size) {
          s_max_vals[start] =
              fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
        }
        warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
                                 min(warp_end - warp_start, warp_size));
      }
    }
    __syncthreads();

    // The first thread of every real group publishes that group's scale. The
    // second condition filters out threads whose group index lies past the
    // end of the row.
    if (thread_in_group == 0 && thread_offset < thread_end) {
      block_absmax_val_maybe = s_max_vals[threadIdx.x];
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      // Global output store
      if constexpr (is_scale_transposed) {
        // Row count of the transposed scale matrix: gridDim.x rounded up to a
        // multiple of outer_scale_stride.
        int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
                                   outer_scale_stride * outer_scale_stride;
        all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
                         blockIdx.x] = scale;
      } else {
        all_token_scales[blockIdx.x * num_groups +
                         threadIdx.x / threads_per_group] = scale;
      }
    }
    __syncthreads();

  } else {
    int32_t const num_vec_elems = (hidden_size >> 2);

    // Per-thread absmax over the whole row, threads blockDim.x apart.
#pragma unroll 4
    for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
      vec4_t<scalar_t> in = vec_input[i];
      vec4_t<scalar_t> const w = vec_weight[i];

      vec4_t<float> x;
#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        x.val[j] = static_cast<float>(in.val[j]);
      }

      if constexpr (has_residual) {
        vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
        for (int j = 0; j < VEC_SIZE; ++j) {
          x.val[j] += static_cast<float>(r.val[j]);
        }
      }

#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        block_absmax_val_maybe =
            fmaxf(block_absmax_val_maybe,
                  fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
      }
    }

    using BlockReduce = cub::BlockReduce<float, 1024>;
    __shared__ typename BlockReduce::TempStorage reduceStore;
    block_absmax_val_maybe =
        BlockReduce(reduceStore)
            .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

    // Thread 0 derives the scale and publishes it; the barrier makes the
    // shared copy visible to every thread before it is read back.
    __shared__ float s_token_scale;
    if (threadIdx.x == 0) {
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      s_token_scale = scale;                 // shared memory store
      all_token_scales[blockIdx.x] = scale;  // global output store
    }
    __syncthreads();

    *token_scale = s_token_scale;
  }
}

// Vectorized normalize-and-quantize for the row owned by this block
// (row = blockIdx.x).
// hidden_size must be a multiple of 4
//
// Per element: x = input [+ residual]; when has_residual is set the sum is
// also written back into `residual`. static_cast<scalar_t>(x * rms) * weight
// is quantized with ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn
// and stored in `output`, four elements at a time.
//
// Scale selection: with group_size == 0, *scale is passed through unchanged.
// With group_size > 0 the scale of the element's group is read from `scale`
// ([group][scale_rows] when is_scale_transposed, otherwise
// [blockIdx.x][num_groups]) and its reciprocal is used when
// is_scale_inverted is set.
//
// Rows of `input` are input_stride elements apart; rows of `output` and
// `residual` are hidden_size elements apart.
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false, bool is_scale_transposed = false,
          int32_t group_size = 0>
__device__ void norm_and_quant(
    scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
    scalar_t const* __restrict__ weight, float const rms, float* const scale,
    int32_t const hidden_size, int32_t const input_stride,
    scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) {
  using in_vec_t = vec4_t<scalar_t>;
  using out_vec_t = q8x4_t<scalar_out_t>;
  const int kVecWidth = 4;

  int64_t const in_row_start = blockIdx.x * static_cast<int64_t>(input_stride);
  int64_t const row_start = blockIdx.x * static_cast<int64_t>(hidden_size);

  // View input/weight/output/residual as packs of four elements to widen the
  // memory accesses.
  in_vec_t const* const in_vecs =
      reinterpret_cast<in_vec_t const*>(&input[in_row_start]);
  in_vec_t const* const weight_vecs =
      reinterpret_cast<in_vec_t const*>(weight);
  out_vec_t* const out_vecs = reinterpret_cast<out_vec_t*>(&output[row_start]);
  in_vec_t* res_vecs = nullptr;
  if constexpr (has_residual) {
    res_vecs = reinterpret_cast<in_vec_t*>(&residual[row_start]);
  }

  int32_t const vecs_per_row = hidden_size >> 2;

// TODO(luka/varun) extract into type-agnostic vectorized quant function to
//  replace scaled_fp8_conversion_vec
#pragma unroll 4
  for (auto v = threadIdx.x; v < vecs_per_row; v += blockDim.x) {
    in_vec_t const in_vec = in_vecs[v];
    in_vec_t const w_vec = weight_vecs[v];

    float vals[kVecWidth];
#pragma unroll
    for (int lane = 0; lane < kVecWidth; ++lane) {
      vals[lane] = static_cast<float>(in_vec.val[lane]);
    }

    if constexpr (has_residual) {
      // Fold in the residual and store the sum back for later use.
      in_vec_t res_vec = res_vecs[v];
#pragma unroll
      for (int lane = 0; lane < kVecWidth; ++lane) {
        vals[lane] += static_cast<float>(res_vec.val[lane]);
        res_vec.val[lane] = static_cast<scalar_t>(vals[lane]);
      }
      res_vecs[v] = res_vec;
    }

    // Pick the scale that applies to this pack.
    float scale_val;
    if constexpr (group_size > 0) {
      int64_t scale_idx = 0;
      if constexpr (is_scale_transposed) {
        // gridDim.x rounded up to a multiple of outer_scale_stride.
        int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
                                   outer_scale_stride * outer_scale_stride;
        scale_idx = (v * kVecWidth / group_size) * scale_rows + blockIdx.x;
      } else {
        int64_t const num_groups = hidden_size / group_size;
        scale_idx = blockIdx.x * num_groups + v * kVecWidth / group_size;
      }
      if (is_scale_inverted) {
        scale_val = 1.0f / scale[scale_idx];
      } else {
        scale_val = scale[scale_idx];
      }
    } else {
      scale_val = *scale;
    }

    // Norm + quant, then one packed store.
    out_vec_t packed;
#pragma unroll
    for (int lane = 0; lane < kVecWidth; ++lane) {
      packed.val[lane] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
          static_cast<scalar_t>(vals[lane] * rms) * w_vec.val[lane], scale_val);
    }
    out_vecs[v] = packed;
  }
}

}  // namespace vectorized

}  // namespace vllm