#include "cpu_types.hpp"
#include "dnnl_helper.hpp"

namespace {
// Maps a scalar element type to the vector types used by the kernels below.
// The primary template is a placeholder: every alias is `void`, so only the
// explicit specializations that follow name usable vector types.
//
//   load_vec_type         - vector type used to load/store scalar_t data.
//   azp_adj_load_vec_type - vector type used to load the int32 azp adjustment.
//   cvt_vec_type          - vector type in which the arithmetic is performed.
template <typename scalar_t>
struct KernelVecType {
  using load_vec_type = void;
  using azp_adj_load_vec_type = void;
  using cvt_vec_type = void;
};

// Vector types for float data: values are loaded and computed as
// vec_op::FP32Vec16, and the int32 azp adjustment is loaded as
// vec_op::INT32Vec16.
template <>
struct KernelVecType<float> {
  using load_vec_type = vec_op::FP32Vec16;
  using azp_adj_load_vec_type = vec_op::INT32Vec16;
  using cvt_vec_type = vec_op::FP32Vec16;
};

// Vector types for c10::BFloat16 data: values are loaded as vec_op::BF16Vec16
// and converted to vec_op::FP32Vec16 for arithmetic; the int32 azp adjustment
// is loaded as vec_op::INT32Vec16.
template <>
struct KernelVecType<c10::BFloat16> {
  using load_vec_type = vec_op::BF16Vec16;
  using azp_adj_load_vec_type = vec_op::INT32Vec16;
  using cvt_vec_type = vec_op::FP32Vec16;
};

#ifdef __AVX512F__
template <bool AZP, typename scalar_t>
28
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
29
30
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
31
32
33
34
35
36
37
38
39
40
41
42
43
                                   const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t inv_scale(1.0 / *scale);
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

44
45
46
47
48
  cvt_vec_t zp_vec;
  if constexpr (AZP) {
    zp_vec = cvt_vec_t(static_cast<float>(*azp));
  }

49
50
51
52
53
54
  #pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
55
56
57
58
59
60
61
      elems_fp32 = elems_fp32 * inv_scale;

      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + zp_vec;
      }

      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
62
63
64
65
66
67
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j);
    }

    load_vec_t elems(input + i * hidden_size + j);
    cvt_vec_t elems_fp32(elems);
68
    elems_fp32 = elems_fp32 * inv_scale;
69

70
71
    if constexpr (AZP) {
      elems_fp32 = elems_fp32 + zp_vec;
72
    }
73
74
75
76

    elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
    vec_op::INT8Vec16 elems_int8(elems_fp32);
    elems_int8.save(output + i * hidden_size + j, hidden_size - j);
77
78
79
  }
}

template <bool AZP, typename scalar_t>
81
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
82
83
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
84
85
86
87
88
                                    const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

89
90
91
92
93
94
95
  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

96
97
  #pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
98
99
    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
    cvt_vec_t min_value(std::numeric_limits<float>::max());
100
101
102
103
104
    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
105
106
107
108
109
110
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
111
112
113
114
115
116
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);

      if (j + vec_elem_num == hidden_size) {
117
118
119
120
121
122
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
123
      } else {
124
125
126
127
128
129
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32, hidden_size - j);
          min_value = min_value.min(elems_fp32, hidden_size - j);
        } else {
          max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
        }
130
131
132
      }
    }

133
134
135
136
137
138
139
140
141
142
143
144
145
    float scale_val, azp_val;
    if constexpr (AZP) {
      float max_scalar = max_value.reduce_max();
      float min_scalar = min_value.reduce_min();
      scale_val = (max_scalar - min_scalar) / 255.0f;
      azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
      azp[i] = static_cast<int32_t>(azp_val);
      scale[i] = scale_val;
    } else {
      scale_val = max_value.reduce_max() / 127.0f;
      scale[i] = scale_val;
    }

146
    const cvt_vec_t inv_scale(1.0 / scale_val);
147
    const cvt_vec_t azp_vec(azp_val);
148
149
150
151
152
153
154

    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        elems_fp32 = (elems_fp32 * inv_scale);
155
156
157
158
159

        if constexpr (AZP) {
          elems_fp32 = elems_fp32 + azp_vec;
        }
        elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
160
161
162
163
164
165
166
167
        vec_op::INT8Vec16 elems_int8(elems_fp32);
        elems_int8.save(output + i * hidden_size + j);
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = (elems_fp32 * inv_scale);

168
169
      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + azp_vec;
170
      }
171
172
173
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j, hidden_size - j);
174
175
176
177
    }
  }
}

template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
183
184
  CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  #pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t a_scale_vec(a_scale);
    cvt_vec_t b_scale_vec(*b_scale);
    cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;

    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);

      if constexpr (PerChannel) {
        b_scale_vec = cvt_vec_t(b_scale + j);
        scale_vec = b_scale_vec * a_scale_vec;
      }

      elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;

      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
    cvt_vec_t azp_adj_fp32(azp_adj_vec);

    if constexpr (PerChannel) {
      b_scale_vec = cvt_vec_t(b_scale + j);
      scale_vec = b_scale_vec * a_scale_vec;
    }

    elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}

template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
                            const int32_t* azp, const int32_t* azp_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
239
240
241
242
243
244
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  #pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
245
246
247
248
249
250
251
252
253
254
    cvt_vec_t token_scale_vec(a_scale[i]);
    cvt_vec_t token_zp_scale_vec;
    if constexpr (AZP) {
      float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
      if constexpr (!PerChannel) {
        zp_scale_val *= *b_scale;
      }
      token_zp_scale_vec = cvt_vec_t(zp_scale_val);
    }

255
256
257
258
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      elems_fp32 = elems_fp32 * token_scale_vec;

259
260
261
262
263
264
265
266
267
268
269
270
271
      if constexpr (AZP) {
        azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
        cvt_vec_t azp_adj_fp32(azp_adj_vec);
        azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

        if constexpr (PerChannel) {
          cvt_vec_t b_scale_vec(b_scale + j);
          azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
        }

        elems_fp32 = elems_fp32 - azp_adj_fp32;
      }

272
273
274
275
276
277
278
279
280
281
282
283
284
      if constexpr (Bias) {
        load_vec_t bias_vec(bias + j);
        cvt_vec_t bias_vec_fp32(bias_vec);
        elems_fp32 = elems_fp32 + bias_vec_fp32;
      }

      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    elems_fp32 = elems_fp32 * token_scale_vec;

285
286
287
288
289
290
291
292
293
294
295
296
297
    if constexpr (AZP) {
      azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);
      azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

      if constexpr (PerChannel) {
        cvt_vec_t b_scale_vec(b_scale + j);
        azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
      }

      elems_fp32 = elems_fp32 - azp_adj_fp32;
    }

298
299
300
301
302
303
304
    if constexpr (Bias) {
      load_vec_t bias_vec(bias + j);
      cvt_vec_t bias_vec_fp32(bias_vec);
      elems_fp32 = elems_fp32 + bias_vec_fp32;
    }

    load_vec_t elems_out(elems_fp32);
305
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
306
307
308
309
310
  }
}
#else
// Fallback used when AVX512F is not available: always fails with a
// TORCH_CHECK error.
//
// The template header mirrors the AVX512 implementation (bool AZP first) so
// that call sites such as static_scaled_int8_quant_impl<true>(...) resolve
// in both build configurations; with only <typename scalar_t> the explicit
// bool argument would not match a type parameter.
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
                                   const int hidden_size) {
  TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.");
}

// Fallback used when AVX512F is not available: always fails with a
// TORCH_CHECK error.
//
// The template header mirrors the AVX512 implementation (bool AZP first) so
// that call sites such as dynamic_scaled_int8_quant_impl<true>(...) resolve
// in both build configurations.
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.");
}

// Non-AVX512 stand-in for static_quant_epilogue. It performs no work: the
// unconditional TORCH_CHECK below reports that the kernel is unavailable.
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
  // Every argument is intentionally ignored.
  TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.");
}

template <typename scalar_t>
334
335
336
337
338
339
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
                            const int32_t* azp, const int32_t* azp_with_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.")
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
}
#endif
}  // namespace

// Computes C = s_a * s_b * (A @ B) [+ bias] for int8 A and B, writing a
// floating-point C.
//
// a_scales holds either one value (per-tensor) or one value per row of A
// (per-token); b_scales holds either one value or one value per output
// column. Shapes, strides and contiguity are validated with TORCH_CHECK
// before any computation.
//
// NOTE(review): the per-token path post-processes c through
// dynamic_quant_epilogue, which indexes rows as i * c.size(1); the checks
// below only require c.stride(0) % 16 == 0, so this presumably relies on c
// having densely packed rows - confirm against callers.
void int8_scaled_mm(torch::Tensor& c,               // [M, OC], row-major
                    const torch::Tensor& a,         // [M, IC], row-major
                    const torch::Tensor& b,         // [IC, OC], column-major
                    const torch::Tensor& a_scales,  // [1] or [M]
                    const torch::Tensor& b_scales,  // [1] or [OC]
                    const c10::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm only supports INT8 inputs.");
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
    if (a_scales.numel() != 1) {
      // per-token
      // Note: oneDNN doesn't support per-token activation quantization
      // Ideally we want to fuse the GEMM and the scale procedure with oneDNN
      // JIT, the intermediate data is cached in registers or L1. But for now
      // the oneDNN GEMM code generation only supports two quantization
      // patterns: per-tensor or per-output-channel of weight.
      // So we have to apply the per-token scale with a 'epilogue'. In C=s_a *
      // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN
      // GEMM, then the per-token scale (and bias) is applied with the epilogue
      // C=s_a * C_inter + bias.
      torch::Tensor tmp_fp32_out =
          torch::empty_like(c, ::at::ScalarType::Float);
      // Compute C_inter=s_b * (A@B)
      DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
          a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
          tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
          a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
      if (bias.has_value()) {
        // Compute C=s_a * C_inter + bias
        dynamic_quant_epilogue<false, true, true>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
            bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
      } else {
        // Compute C=s_a * C_inter
        // scalar_t is spelled out because it cannot be deduced from the
        // nullptr passed for bias.
        dynamic_quant_epilogue<false, true, false, scalar_t>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
            c.size(0), c.size(1));
      }
    } else {
      // per-tensor
      if (bias.has_value()) {
        // Compute C=s_a * s_b * (A@B) + bias
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
            bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1),
            a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            a_scales.numel(), b_scales.numel());
      } else {
        // Compute C=s_a * s_b * (A@B)
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit<scalar_t, void>(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
            nullptr, a.size(0), b.size(1), a.size(1),
            a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            a_scales.numel(), b_scales.numel());
      }
    }
  });
}

// Computes the int8 GEMM C = s_a * s_b * (A @ B) [+ bias] with an additional
// correction for asymmetric (zero-point) activation quantization.
//
// Two paths are selected by a_scales.numel():
//   != 1 (per-token): C = s_a * C_inter - s_a * s_b * azp * azp_adj [+ bias],
//                     where C_inter = s_b * (A @ B). azp is required.
//   == 1 (per-tensor): C = C_inter - s_a * s_b * azp_adj, where
//                     C_inter = s_a * s_b * (A @ B) [+ bias]. azp is not read.
//
// Shapes, strides, contiguity and dtypes are validated with TORCH_CHECK
// before any computation.
//
// NOTE(review): a per-token call with M == 1 has a_scales.numel() == 1 and
// therefore takes the per-tensor path, which never reads azp - confirm this
// is intended.
void int8_scaled_mm_azp(torch::Tensor& c,        // [M, OC], row-major
                        const torch::Tensor& a,  // [M, IC], row-major
                        const torch::Tensor& b,  // [IC, OC], column-major
                        const torch::Tensor& a_scales,            // [1] or [M]
                        const torch::Tensor& b_scales,            // [1] or [OC]
                        const torch::Tensor& azp_adj,             // [OC]
                        const c10::optional<torch::Tensor>& azp,  // [1] or [M]
                        const c10::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm_azp only supports INT8 inputs.");
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // The per-token path below dereferences azp unconditionally; reject a
  // missing azp up front instead of dereferencing an empty optional.
  TORCH_CHECK(a_scales.numel() == 1 || azp.has_value(),
              "int8_scaled_mm_azp requires azp when a_scales is per-token.");

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] {
    torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
    if (a_scales.numel() != 1) {
      // per-token
      // Note: oneDNN doesn't support per-token activation quantization
      // Compute C_inter=s_b * (A@B)
      DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
          a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
          tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
          a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
      if (bias.has_value()) {
        // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias
        if (b_scales.numel() != 1) {
          // Per-Channel
          dynamic_quant_epilogue<true, true, true>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
              bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
        } else {
          // Per-Tensor
          dynamic_quant_epilogue<true, false, true>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
              bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
        }
      } else {
        // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj
        if (b_scales.numel() != 1) {
          // Per-Channel
          dynamic_quant_epilogue<true, true, false, scalar_t>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
              c.size(0), c.size(1));
        } else {
          // Per-Tensor
          dynamic_quant_epilogue<true, false, false, scalar_t>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
              c.size(0), c.size(1));
        }
      }
    } else {
      // per-tensor
      if (bias.has_value()) {
        // Compute C_inter=s_a * s_b * (A@B) + bias
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
            tmp_fp32_out.data_ptr<float>(), bias->data_ptr<scalar_t>(),
            a.size(0), b.size(1), a.size(1), a_scales.data_ptr<float>(),
            b_scales.data_ptr<float>(), a_scales.numel(), b_scales.numel());
      } else {
        // Compute C_inter=s_a * s_b * (A@B)
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, void>(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
            tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
            a.size(1), a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            a_scales.numel(), b_scales.numel());
      }

      // Compute C=C_inter - s_a * s_b * azp_adj
      if (b_scales.numel() != 1) {
        // Per-Channel
        static_quant_epilogue<true>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            *a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
      } else {
        // Per-Tensor
        static_quant_epilogue<false>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            *a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
      }
    }
  });
}

// static-per-tensor quantization.
//
// Quantizes `input` to int8 in `out` using the single value in `scale` and,
// when `azp` is provided, the single zero point in `azp`. Rows are formed by
// the last dimension of `input`; all leading dimensions are flattened into
// the token count. Both tensors must be contiguous.
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              const torch::Tensor& input,  // [..., hidden_size]
                              const torch::Tensor& scale,
                              c10::optional<torch::Tensor> const& azp) {
  CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
  TORCH_CHECK(!azp.has_value() || azp->numel() == 1);

  // Nothing to quantize; this also avoids dividing by a zero hidden_size
  // below.
  if (input.numel() == 0) {
    return;
  }

  const int hidden_size = input.size(-1);
  const int num_tokens = input.numel() / hidden_size;
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
        if (azp.has_value()) {
          static_scaled_int8_quant_impl<true>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
              hidden_size);
        } else {
          static_scaled_int8_quant_impl<false>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
        }
      });
}

// dynamic-per-token quantization.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    const torch::Tensor& input,  // [..., hidden_size]
579
580
    torch::Tensor& scale,        // [..., 1]
    c10::optional<torch::Tensor> const& azp) {
581
582
583
584
585
586
587
588
  CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
589
590
591
592
593
594
595
596
597
598
        if (azp.has_value()) {
          dynamic_scaled_int8_quant_impl<true>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
              hidden_size);
        } else {
          dynamic_scaled_int8_quant_impl<false>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
        }
599
600
      });
}