Unverified Commit eb89439d authored by qinyiqun, committed by GitHub

Support Quantization (#996)



demo131 - assorted quantization-related changes (QY support and more)

* issue/843: success per_channel_quant_int8

* issue/843: success qy quant

* issue/843: modified quant

* Add w8a8int8 performance tests

* add infinicore op linear_w8a8i8

* w8a8 linear module functional nn

* issue/843: QY-GPU Support Int8 scale_mm (#68)

* issue/843: success qy scaled_mm

* issue/843: modified kernel.cuh as per_channel_dequant_int8.cuh

* fix parallel slicing in w8

* w8: support multiple batch size

* temp: modify QuantConfig handling

* fix format and delete redundancy code

* fix format

* fix format

* fix format

* Refactor: add new API alongside legacy interfaces with deprecation warnings

* Add w4 InfiniCore-related content and move the quantization config into InfiniCore

* Graph support for quantization operators

* solve cub version problem and fix code structure

* fix format

* demo131 - remove commented lines

---------
Co-authored-by: xgqdut2016 <kenan_gewei@163.com>
Co-authored-by: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com>
Co-authored-by: wooway777 <wooway777@gmail.com>
parent abab5652
#include "../../utils.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/dequantize_awq.hpp"
#include <infiniop.h>
namespace infinicore::op::dequantize_awq_impl::infiniop {
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, DequantizeAWQ, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, x, x_packed, x_scale, x_zeros;
};
void *plan(Tensor x, const Tensor &x_packed, const Tensor &x_scale, const Tensor &x_zeros) {
size_t seed = hash_combine(x, x_packed, x_scale, x_zeros);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, DequantizeAWQ,
seed,
x->desc(), x_packed->desc(), x_scale->desc(), x_zeros->desc());
INFINIOP_WORKSPACE_TENSOR(workspace, DequantizeAWQ, descriptor);
return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(x),
graph::GraphTensor(x_packed),
graph::GraphTensor(x_scale),
graph::GraphTensor(x_zeros)};
}
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopDequantizeAWQ(
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->x->data(),
planned->x_packed->data(),
planned->x_scale->data(),
planned->x_zeros->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(DequantizeAWQ, &plan, &run, &cleanup);
} // namespace infinicore::op::dequantize_awq_impl::infiniop
#include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/dequantize_awq.hpp"
#include "infinicore/ops/gemm.hpp"
namespace infinicore::op {
Tensor linear_w4a16_awq(Tensor input,
Tensor weight_packed,
Tensor weight_scale,
Tensor weight_zeros,
std::optional<Tensor> bias) {
// Input is of shape [M, K]; weight_packed is of shape [N, K] with strides [N, 1]
Size ndim = input->ndim();
Size out_features = weight_packed->shape()[0];
// Allocate the output tensor
auto output_shape = input->shape();
output_shape[ndim - 1] = out_features;
auto out = Tensor::empty(output_shape, input->dtype(), input->device());
// Compute in place
linear_w4a16_awq_(out, input, weight_packed, weight_scale, weight_zeros, bias);
return out;
}
void linear_w4a16_awq_(Tensor out,
Tensor input,
Tensor weight_packed,
Tensor weight_scale,
Tensor weight_zeros,
std::optional<Tensor> bias) {
auto weight_packed_shape = weight_packed->shape();
Size out_features = weight_packed_shape[0];
Size in_features = weight_packed_shape[1];
Size ndim = input->ndim();
assert(out->ndim() == ndim);
Size N = 1;
auto input_shape = input->shape();
for (size_t i = 0; i < ndim - 1; ++i) {
N *= input_shape[i];
}
auto weight = Tensor::empty(
{out_features, in_features},
out->dtype(),
weight_packed->device());
float alpha = 1.0f;
float beta = 0.0f;
op::dequantize_awq_(weight, weight_packed, weight_scale, weight_zeros);
if (bias.has_value()) {
bias = std::make_optional(bias.value()->as_strided({N, out_features}, {0, 1}));
}
gemm_(out->view({N, out_features}),
input->view({N, in_features}),
weight->permute({1, 0}), alpha, beta);
}
} // namespace infinicore::op
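For orientation, here is a minimal, hypothetical usage sketch of the w4a16 AWQ path defined above (the function name and call site are assumptions for illustration, not part of the diff): the packed weight is dequantized into a temporary full-precision weight, which then feeds a regular GEMM.

#include <optional>
#include "infinicore/ops/linear_w4a16_awq.hpp"

// Hypothetical call site. Shapes follow the comments above: input [M, K];
// weight_packed/weight_scale/weight_zeros are the AWQ-packed weight tensors.
infinicore::Tensor awq_linear_example(infinicore::Tensor input,
                                      infinicore::Tensor weight_packed,
                                      infinicore::Tensor weight_scale,
                                      infinicore::Tensor weight_zeros) {
    // Internally: dequantize_awq_ materializes the fp weight, then gemm_ runs on it.
    return infinicore::op::linear_w4a16_awq(input, weight_packed, weight_scale,
                                            weight_zeros, /*bias=*/std::nullopt);
}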
#include "infinicore/ops/linear_w8a8i8.hpp"
#include "infinicore/ops/per_channel_quant_i8.hpp"
#include "infinicore/ops/scaled_mm_i8.hpp"
namespace infinicore::op {
Tensor linear_w8a8i8(Tensor input,
Tensor weight_packed,
Tensor weight_scale,
std::optional<Tensor> bias) {
// Input is of shape [M, K]; weight_packed is of shape [N, K] with strides [N, 1]
Size ndim = input->ndim();
Size out_features = weight_packed->shape()[0];
// Allocate the output tensor
auto output_shape = input->shape();
output_shape[ndim - 1] = out_features;
auto out = Tensor::empty(output_shape, input->dtype(), input->device());
// Compute in place
linear_w8a8i8_(out, input, weight_packed, weight_scale, bias);
return out;
}
void linear_w8a8i8_(Tensor out,
Tensor input,
Tensor weight_packed,
Tensor weight_scale,
std::optional<Tensor> bias) {
auto weight_packed_shape = weight_packed->shape();
Size out_features = weight_packed_shape[0];
Size in_features = weight_packed_shape[1];
Size ndim = input->ndim();
assert(out->ndim() == ndim);
Size N = 1;
auto input_shape = input->shape();
for (size_t i = 0; i < ndim - 1; ++i) {
N *= input_shape[i];
}
auto input_packed = Tensor::empty(
{N, input_shape[ndim - 1]},
DataType::I8,
input->device());
auto input_scale = Tensor::empty(
{N, 1},
DataType::F32,
input->device());
op::per_channel_quant_i8_(input->view({N, in_features}), input_packed, input_scale);
if (bias.has_value()) {
bias = std::make_optional(bias.value()->as_strided({N, out_features}, {0, 1}));
}
op::scaled_mm_i8_(
out->view({N, out_features}),
input_packed,
input_scale,
weight_packed->permute({1, 0}),
weight_scale,
bias);
}
} // namespace infinicore::op
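Analogously, a minimal usage sketch of the w8a8 path (an assumed call site, not from the diff): the op quantizes the activation to int8 per channel on the fly and then runs the scaled int8 matmul.

#include <optional>
#include "infinicore/ops/linear_w8a8i8.hpp"

// Hypothetical call site. Shapes: input [M, K]; weight_packed [N, K] int8;
// weight_scale holds per-output-channel fp32 scales, as used by the implementation above.
infinicore::Tensor w8a8_linear_example(infinicore::Tensor input,
                                       infinicore::Tensor weight_packed,
                                       infinicore::Tensor weight_scale) {
    // per_channel_quant_i8_ and scaled_mm_i8_ happen inside linear_w8a8i8.
    return infinicore::op::linear_w8a8i8(input, weight_packed, weight_scale,
                                         /*bias=*/std::nullopt);
}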
#include "infinicore/ops/per_channel_quant_i8.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(PerChannelQuantI8);
PerChannelQuantI8::PerChannelQuantI8(const Tensor &x, Tensor x_packed, Tensor x_scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x, x_packed, x_scale);
INFINICORE_GRAPH_OP_DISPATCH(x->device().getType(), x, x_packed, x_scale);
}
void PerChannelQuantI8::execute(const Tensor &x, Tensor x_packed, Tensor x_scale) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(PerChannelQuantI8, x, x_packed, x_scale);
}
void per_channel_quant_i8_(const Tensor &x, Tensor x_packed, Tensor x_scale) {
PerChannelQuantI8::execute(x, x_packed, x_scale);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/per_channel_quant_i8.hpp"
#include <infiniop.h>
namespace infinicore::op::per_channel_quant_i8_impl::infiniop {
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, PerChannelQuantI8, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, x, x_packed, x_scale;
};
void *plan(const Tensor &x, Tensor x_packed, Tensor x_scale) {
size_t seed = hash_combine(x, x_packed, x_scale);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, PerChannelQuantI8,
seed,
x_packed->desc(), x_scale->desc(), nullptr, x->desc());
INFINIOP_WORKSPACE_TENSOR(workspace, PerChannelQuantI8, descriptor);
return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(x),
graph::GraphTensor(x_packed),
graph::GraphTensor(x_scale)};
}
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopPerChannelQuantI8(
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->x_packed->data(),
planned->x_scale->data(),
nullptr,
planned->x->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(PerChannelQuantI8, &plan, &run, &cleanup);
} // namespace infinicore::op::per_channel_quant_i8_impl::infiniop
#include "infinicore/ops/scaled_mm_i8.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(I8Gemm);
I8Gemm::I8Gemm(Tensor c, const Tensor &a_p, const Tensor &a_s, const Tensor &b_p, const Tensor &b_s, std::optional<Tensor> bias) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a_p, a_s, b_p, b_s);
INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a_p, a_s, b_p, b_s, bias);
}
void I8Gemm::execute(Tensor c, const Tensor &a_p, const Tensor &a_s, const Tensor &b_p, const Tensor &b_s, std::optional<Tensor> bias) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(I8Gemm, c, a_p, a_s, b_p, b_s, bias);
}
void scaled_mm_i8_(Tensor c, const Tensor &a_p, const Tensor &a_s, const Tensor &b_p, const Tensor &b_s, std::optional<Tensor> bias) {
I8Gemm::execute(c, a_p, a_s, b_p, b_s, bias);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/scaled_mm_i8.hpp"
#include <infiniop.h>
namespace infinicore::op::scaled_mm_i8_impl::infiniop {
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, I8Gemm, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, c, a_p, a_s, b_p, b_s;
std::optional<graph::GraphTensor> bias;
};
void *plan(Tensor c, const Tensor &a_p, const Tensor &a_s, const Tensor &b_p, const Tensor &b_s, std::optional<Tensor> bias) {
size_t seed = hash_combine(c, a_p, a_s, b_p, b_s);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, I8Gemm,
seed,
c->desc(), bias.has_value() ? bias.value()->desc() : nullptr,
a_p->desc(), a_s->desc(), b_p->desc(), b_s->desc());
INFINIOP_WORKSPACE_TENSOR(workspace, I8Gemm, descriptor);
return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(c),
graph::GraphTensor(a_p),
graph::GraphTensor(a_s),
graph::GraphTensor(b_p),
graph::GraphTensor(b_s),
bias ? std::optional<graph::GraphTensor>(graph::GraphTensor(*bias)) : std::nullopt};
}
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopI8Gemm(
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->c->data(),
planned->bias.has_value() ? planned->bias.value()->data() : nullptr,
planned->a_p->data(),
planned->a_s->data(),
planned->b_p->data(),
planned->b_s->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(I8Gemm, &plan, &run, &cleanup);
} // namespace infinicore::op::scaled_mm_i8_impl::infiniop
@@ -10,6 +10,7 @@
#include "ops/flash_attention.hpp"
#include "ops/kv_caching.hpp"
#include "ops/linear.hpp"
#include "ops/linear_w8a8i8.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
@@ -46,6 +47,7 @@ inline void bind(py::module &m) {
bind_swiglu(m);
bind_rope(m);
bind_embedding(m);
bind_linear_w8a8i8(m);
}
} // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/linear_w8a8i8.hpp"
namespace py = pybind11;
namespace infinicore::ops {
Tensor py_linear_w8a8i8(Tensor input,
Tensor weight_packed,
Tensor weight_scale,
pybind11::object bias) {
std::optional<Tensor> bias_tensor = std::nullopt;
if (!bias.is_none()) {
bias_tensor = bias.cast<Tensor>();
}
return op::linear_w8a8i8(input, weight_packed, weight_scale, bias_tensor);
}
void py_linear_w8a8i8_(Tensor out,
Tensor input,
Tensor weight_packed,
Tensor weight_scale,
pybind11::object bias) {
std::optional<Tensor> bias_tensor = std::nullopt;
if (!bias.is_none()) {
bias_tensor = bias.cast<Tensor>();
}
op::linear_w8a8i8_(out, input, weight_packed, weight_scale, bias_tensor);
}
inline void bind_linear_w8a8i8(py::module &m) {
m.def("linear_w8a8i8",
&ops::py_linear_w8a8i8,
py::arg("input"),
py::arg("weight_packed"),
py::arg("weight_scale"),
py::arg("bias") = py::none(),
R"doc(linear_w8a8i8.)doc");
m.def("linear_w8a8i8_",
&ops::py_linear_w8a8i8_,
py::arg("out"),
py::arg("input"),
py::arg("weight_packed"),
py::arg("weight_scale"),
py::arg("bias") = py::none(),
R"doc(linear_w8a8i8_.)doc");
}
} // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/per_channel_quant_i8.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_per_channel_quant_i8(py::module &m) {
m.def("per_channel_quant_i8_",
&op::per_channel_quant_i8_,
py::arg("x"),
py::arg("x_packed"),
py::arg("x_scale"),
R"doc(Per-channel quantization of a tensor.)doc");
}
} // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/scaled_mm_i8.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_scaled_mm_i8(py::module &m) {
m.def("scaled_mm_i8",
&op::scaled_mm_i8,
py::arg("a_p"),
py::arg("a_s"),
py::arg("b_p"),
py::arg("b_s"),
py::arg("bias"),
R"doc(Scaled matrix multiplication of two tensors.)doc");
m.def("scaled_mm_i8_",
&op::scaled_mm_i8_,
py::arg("a"),
py::arg("b"),
py::arg("a_scale"),
py::arg("b_scale"),
R"doc(In-place Scaled matrix multiplication of two tensors.)doc");
}
} // namespace infinicore::ops
@@ -95,6 +95,20 @@ void print_data_bf16(const uint16_t *data, const Shape &shape, const Strides &strides, size_t dim) {
}
}
// Function for printing I8 data
void print_data_i8(const int8_t *data, const Shape &shape, const Strides &strides, size_t dim) {
if (dim == shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) {
std::cout << static_cast<int>(data[i * strides[dim]]) << " ";
}
std::cout << std::endl;
} else if (dim < shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) {
print_data_i8(data + i * strides[dim], shape, strides, dim + 1);
}
}
}
// Template function for writing data recursively to binary file (handles non-contiguous tensors)
template <typename T>
void write_binary_data(std::ofstream &out, const T *data, const Shape &shape, const Strides &strides, size_t dim) {
@@ -181,8 +195,8 @@ void TensorImpl::debug(const std::string &filename) const {
cpu_tensor->shape(), cpu_tensor->strides(), 0);
break;
case DataType::I8:
print_data_i8(reinterpret_cast<const int8_t *>(cpu_data),
cpu_tensor->shape(), cpu_tensor->strides(), 0);
break;
case DataType::BF16:
print_data_bf16(reinterpret_cast<const uint16_t *>(cpu_data),
#ifndef __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#define __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#include <cub/block/block_reduce.cuh>
__device__ inline int round_half_away_from_zero(float x) {
float ax = fabsf(x);
float r = floorf(ax + 0.5f);
return (x >= 0.0f) ? (int)r : -(int)r;
}
template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8Kernel(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
int M, int K) {
int row = blockIdx.x;
int tid = row * K;
// ---- 1. reduce max ----
float local_max = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(
x + tid, K);
__shared__ float global_max_f;
if (threadIdx.x == 0) {
global_max_f = local_max;
}
__syncthreads();
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// ---- 2. reduce min ----
float thread_min = __FLT_MAX__;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
thread_min = fminf(thread_min, (float)x[tid + ind]);
}
#if CUDART_VERSION >= 12090
float local_min = BlockReduce(temp_storage).Reduce(thread_min, ::cuda::minimum());
#else
float local_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());
#endif
__shared__ float global_min_f;
if (threadIdx.x == 0) {
global_min_f = local_min;
}
__syncthreads();
float global_max = global_max_f;
float global_min = global_min_f;
float scale = (global_max - global_min) / 255.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
float zero = -global_min * inv_scale - 128.0f;
x_scale[row] = scale;
x_zero[row] = zero;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
float v = (float)x[tid + ind];
float qf = v * inv_scale + zero;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -128) {
q = -128;
}
x_packed[tid + ind] = (int8_t)q;
}
}
template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8SymKernel(
int8_t *x_packed, float *x_scale, const Tdata *x,
int M, int K) {
int row = blockIdx.x;
int tid = row * K;
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// ---- reduce max of |x| ----
float thread_max = -__FLT_MAX__;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
thread_max = fmaxf(thread_max, fabsf((float)x[tid + ind]));
}
#if CUDART_VERSION >= 12090
float local_max = BlockReduce(temp_storage).Reduce(thread_max, ::cuda::maximum());
#else
float local_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
#endif
__shared__ float global_max_f;
if (threadIdx.x == 0) {
global_max_f = local_max;
}
__syncthreads();
float global_max = global_max_f;
float scale = global_max / 127.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
x_scale[row] = scale;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
float v = (float)x[tid + ind];
float qf = v * inv_scale;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -127) {
q = -127;
}
x_packed[tid + ind] = (int8_t)q;
}
}
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return max(a, b);
}
};
template <typename T>
struct MinOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return min(a, b);
}
};
template <template <typename> class ReductionOp, typename T,
int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8Kernel(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
int M, int K) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx * K;
if (otherIdx < M) {
__shared__ float max_total[BLOCK_SIZE_y];
__shared__ float min_total[BLOCK_SIZE_y];
float max_data = -__FLT_MAX__;
float min_data = __FLT_MAX__;
// ---- reduce max/min ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
max_data = fmaxf(max_data, v);
min_data = fminf(min_data, v);
}
max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
min_data = WarpAllReduce<MinOp, float, BLOCK_SIZE_x>(min_data);
if (threadIdx.x == 0) {
max_total[threadIdx.y] = max_data;
min_total[threadIdx.y] = min_data;
}
__syncthreads();
float max_f = max_total[threadIdx.y];
float min_f = min_total[threadIdx.y];
float scale = (max_f - min_f) / 255.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
float zero = -min_f * inv_scale - 128.0f;
x_scale[otherIdx] = scale;
x_zero[otherIdx] = zero;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
float qf = v * inv_scale + zero;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -128) {
q = -128;
}
x_packed[tid + ind] = (int8_t)q;
}
}
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8SymKernel(
int8_t *x_packed, float *x_scale, const Tdata *x,
int M, int K) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx * K;
if (otherIdx < M) {
__shared__ float max_total[BLOCK_SIZE_y];
float max_data = -__FLT_MAX__;
// ---- reduce max of |x| ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = fabsf((float)x[tid + ind]);
max_data = fmaxf(max_data, v);
}
max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
if (threadIdx.x == 0) {
max_total[threadIdx.y] = max_data;
}
__syncthreads();
float max_f = max_total[threadIdx.y];
float scale = max_f / 127.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
x_scale[otherIdx] = scale;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
float qf = v * inv_scale;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -127) {
q = -127;
}
x_packed[tid + ind] = (int8_t)q;
}
}
}
#endif // __PERCHANNEL_QUANTINT8_KERNEL_CUH__
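The kernels above implement two per-row schemes: an asymmetric mapping (scale = (max - min) / 255, zero = -min / scale - 128) and a symmetric one (scale = max|x| / 127). The host-side reference below is only a sketch to make the asymmetric math explicit; the function name and flat row-major layout are assumptions, not part of the diff.

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <vector>

// CPU reference of the asymmetric per-row int8 quantization used by
// blockPerChannelQuantI8Kernel / warpPerChannelQuantI8Kernel above.
void per_row_quant_i8_ref(const std::vector<float> &x, int M, int K,
                          std::vector<int8_t> &q, std::vector<float> &scale,
                          std::vector<float> &zero) {
    q.resize(static_cast<size_t>(M) * K);
    scale.resize(M);
    zero.resize(M);
    for (int r = 0; r < M; ++r) {
        float mx = -FLT_MAX, mn = FLT_MAX;
        for (int k = 0; k < K; ++k) {
            mx = std::max(mx, x[r * K + k]);
            mn = std::min(mn, x[r * K + k]);
        }
        float s = std::max((mx - mn) / 255.0f, 1e-8f); // clamp tiny ranges, as the kernels do
        float z = -mn / s - 128.0f;
        scale[r] = s;
        zero[r] = z;
        for (int k = 0; k < K; ++k) {
            // std::lround rounds halfway cases away from zero, matching round_half_away_from_zero
            long v = std::lround(x[r * K + k] / s + z);
            q[r * K + k] = static_cast<int8_t>(std::clamp<long>(v, -128, 127));
        }
    }
}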
#ifndef __PER_CHANNEL_QUANT_INT8_INFO_H__
#define __PER_CHANNEL_QUANT_INT8_INFO_H__
#include "../../../../utils.h"
#include "../../../operator.h"
#include "../../../tensor.h"
namespace op::per_channel_quant_int8 {
class PerChannelQuantI8Info {
private:
PerChannelQuantI8Info() = default;
public:
infiniDtype_t dtype, packed_type;
size_t M, K;
static utils::Result<PerChannelQuantI8Info> createPerChannelQuantI8Info(
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {
CHECK_OR_RETURN(
x_packed_desc != nullptr && x_scale_desc != nullptr && x_desc != nullptr,
INFINI_STATUS_NULL_POINTER);
const infiniDtype_t dtype = x_desc->dtype();
const infiniDtype_t packed_type = x_packed_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(packed_type, INFINI_DTYPE_I8);
CHECK_OR_RETURN(x_desc->ndim() == 2
&& x_packed_desc->ndim() == 2
&& x_scale_desc->ndim() == 2,
INFINI_STATUS_BAD_TENSOR_SHAPE);
size_t M = x_desc->dim(0);
size_t K = x_desc->dim(1);
CHECK_OR_RETURN(M == x_packed_desc->dim(0)
&& K == x_packed_desc->dim(1)
&& M == x_scale_desc->dim(0)
&& 1 == x_scale_desc->dim(1),
INFINI_STATUS_BAD_TENSOR_SHAPE);
return utils::Result<PerChannelQuantI8Info>(PerChannelQuantI8Info{
dtype,
packed_type,
M,
K,
});
}
};
} // namespace op::per_channel_quant_int8
#endif // __PER_CHANNEL_QUANT_INT8_INFO_H__
#include "../../../../devices/nvidia/nvidia_common.cuh"
#include "per_channel_quant_int8_nvidia.cuh"
#include "../../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../../reduce/cuda/reduce.cuh"
#include <cub/block/block_reduce.cuh>
#include "../cuda/kernel.cuh"
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_CUDA_KERNEL blockPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
blockPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x_zero, x, M, K);
}
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_CUDA_KERNEL blockPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
blockPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x, M, K);
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_CUDA_KERNEL warpPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
warpPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x_zero, x, M, K);
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_CUDA_KERNEL warpPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
warpPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x, M, K);
}
namespace op::per_channel_quant_int8::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle, Descriptor **desc_ptr,
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {
auto info = PerChannelQuantI8Info::createPerChannelQuantI8Info(x_packed_desc, x_scale_desc, x_zero_desc, x_desc);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t per_channel_quant_int8Kernel(const PerChannelQuantI8Info &info, int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, cudaStream_t stream) {
int M = (int)info.M;
int K = (int)info.K;
if (K >= 1024) {
if (x_zero == nullptr) {
blockPerChannelQuantI8Sym<Tdata, BLOCK_SIZE>
<<<M, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x, M, K);
} else {
blockPerChannelQuantI8<Tdata, BLOCK_SIZE>
<<<M, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);
}
} else {
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
int num_block_x = (M + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
if (x_zero == nullptr) {
warpPerChannelQuantI8Sym<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
<<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x, M, K);
} else {
warpPerChannelQuantI8<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
<<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);
}
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *x_packed, void *x_scale, void *x_zero, const void *x,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
#define QUANT(BLOCK_SIZE, TDATA) \
per_channel_quant_int8Kernel<BLOCK_SIZE, TDATA>(_info, (int8_t *)x_packed, (float *)x_scale, (float *)x_zero, (const TDATA *)x, stream)
#define QUANT_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return QUANT(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return QUANT(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return QUANT(BLOCK_SIZE, __nv_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
QUANT_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
QUANT_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
QUANT_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::per_channel_quant_int8::nvidia
#ifndef __PER_CHANNEL_QUANT_INT8_NVIDIA_API_H__
#define __PER_CHANNEL_QUANT_INT8_NVIDIA_API_H__
#include "../per_channel_quant_int8.h"
DESCRIPTOR(nvidia)
#endif // __PER_CHANNEL_QUANT_INT8_NVIDIA_API_H__
#include "../../../operator.h"
#include "../../../handle.h"
#include "infiniop/ops/quant/per_channel_quant_int8.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/per_channel_quant_int8_nvidia.cuh"
#endif
__C infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle,
infiniopPerChannelQuantI8Descriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::per_channel_quant_int8::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor **>(desc_ptr), \
x_packed_desc, \
x_scale_desc, \
x_zero_desc, \
x_desc);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQuantI8Descriptor_t desc, size_t *size) {
switch (desc->device_type) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor *>(desc)->minWorkspaceSize(); \
return INFINI_STATUS_SUCCESS;
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
}
__C infiniStatus_t infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor_t desc,
void *workspace,
size_t workspace_size,
void *x_packed,
void *x_scale,
void *x_zero,
const void *x,
void *stream) {
#define QUANT(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, x_packed, x_scale, x_zero, x, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
QUANT(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
QUANT(INFINI_DEVICE_QY, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef QUANT
}
__C infiniStatus_t infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannelQuantI8Descriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
}
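For reference, the expected call sequence against this C API is sketched below. The device handle, tensor descriptors, device buffers, and stream are assumed to be set up elsewhere (their creation is not shown in this diff); passing nullptr for the zero-point arguments selects the symmetric path, mirroring the kernel dispatch above.

// Illustrative only: symmetric per-channel int8 quantization of x into x_packed / x_scale.
infiniStatus_t quantize_activation(infiniopHandle_t handle,
                                   infiniopTensorDescriptor_t x_packed_desc,
                                   infiniopTensorDescriptor_t x_scale_desc,
                                   infiniopTensorDescriptor_t x_desc,
                                   void *x_packed, void *x_scale, const void *x,
                                   void *workspace, void *stream) {
    infiniopPerChannelQuantI8Descriptor_t desc = nullptr;
    infiniStatus_t status = infiniopCreatePerChannelQuantI8Descriptor(
        handle, &desc, x_packed_desc, x_scale_desc, /*x_zero_desc=*/nullptr, x_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    size_t workspace_size = 0;
    status = infiniopGetPerChannelQuantI8WorkspaceSize(desc, &workspace_size);
    if (status == INFINI_STATUS_SUCCESS) {
        // workspace is assumed to hold at least workspace_size bytes
        status = infiniopPerChannelQuantI8(desc, workspace, workspace_size,
                                           x_packed, x_scale, /*x_zero=*/nullptr,
                                           x, stream);
    }
    infiniopDestroyPerChannelQuantI8Descriptor(desc);
    return status;
}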
#ifndef __QUANT_H__
#define __QUANT_H__
#include "../../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::per_channel_quant_int8::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PerChannelQuantI8Info _info; \
size_t _workspace_size; \
\
Descriptor(Opaque *opaque, PerChannelQuantI8Info info, \
size_t workspace_size, \
infiniDevice_t device_type, int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), _info(info), _workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t minWorkspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, Descriptor **desc_ptr, \
infiniopTensorDescriptor_t x_packed_desc, \
infiniopTensorDescriptor_t x_scale_desc, \
infiniopTensorDescriptor_t x_zero_desc, \
infiniopTensorDescriptor_t x_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *x_packed, void *x_scale, void *x_zero, const void *x, void *stream) const; \
}; \
}
#endif // __QUANT_H__
#ifndef __PER_CHANNEL_DEQUANT_INT8_KERNEL_CUH__
#define __PER_CHANNEL_DEQUANT_INT8_KERNEL_CUH__
/**
* @brief Symmetric dequantization kernel for post-processing quantized matrix multiplication
*
* This kernel performs symmetric dequantization on the packed integer output of
* a quantized matrix multiplication. It converts the int32 accumulator results back
* to floating-point values by applying per-channel scaling factors from both the
* input and weight tensors, then adds the bias term.
*
* The dequantization formula is:
* y = x_scale * w_scale * y_packed + bias
*
* @tparam Tdata Output data type (typically bfloat16 or half)
*
* @param[out] y Output tensor after dequantization
* Shape: [M, N], Data type: Tdata
*
* @param[in] y_packed Packed integer output from quantized matmul
* Shape: [M, N], Data type: int32_t
* Contains integer results of: x_packed[i,:] * w_packed[:,j]
*
* @param[in] bias Bias tensor to add after dequantization
* Shape: [N], Data type: Tdata
* Broadcasted across all rows
*
* @param[in] x_packed Packed quantized input tensor (not directly used here)
* Shape: [M, K], Data type: int8_t
* Included for context of the computation pipeline
*
* @param[in] x_scale Per-channel scaling factors for the input
* Shape: [M], Data type: float
* One scale value per input row
*
* @param[in] w_packed Packed quantized weight tensor (not directly used here)
* Shape: [K, N], Data type: int8_t
* Included for context of the computation pipeline
*
* @param[in] w_scale Per-channel scaling factors for the weights
* Shape: [N], Data type: float
* One scale value per output column
*
* @param[in] M Batch size / number of input rows
*
* @param[in] K Inner dimension of matrix multiplication
*
* @param[in] N Output dimension / number of output columns
*
* @note This kernel assumes symmetric quantization (zero-point = 0)
* @note Each thread processes one element of the output matrix
* @note Grid and block dimensions should be configured to cover [M, N] output space
*/
template <typename Tdata>
__device__ void postSymKernel(Tdata *y, int32_t *y_packed, const Tdata *bias, const int8_t *x_packed, const float *x_scale, const int8_t *w_packed, const float *w_scale, int M, int K, int N) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= M || col >= N) {
return;
}
int idx = row * N + col;
float output1 = x_scale[row] * w_scale[col] * ((float)y_packed[idx]);
float output = output1 + (float)bias[col];
y[idx] = static_cast<Tdata>(output);
}
// y = x_scale * w_scale * y_packed
template <typename Tdata>
__device__ void postSymKernel(Tdata *y, int32_t *y_packed, const int8_t *x_packed, const float *x_scale, const int8_t *w_packed, const float *w_scale, int M, int K, int N) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= M || col >= N) {
return;
}
int idx = row * N + col;
float output = x_scale[row] * w_scale[col] * ((float)y_packed[idx]);
y[idx] = static_cast<Tdata>(output);
}
#endif // __PER_CHANNEL_DEQUANT_INT8_KERNEL_CUH__
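The functions above are __device__ helpers, and the doc comment calls for one thread per output element over an [M, N] grid. A hypothetical __global__ wrapper and launch (names and block size are assumptions, not part of this diff) could look like this:

// Sketch of a kernel entry point plus launch configuration for the biased
// variant of postSymKernel above; the 16x16 block size is illustrative.
template <typename Tdata>
__global__ void postSym(Tdata *y, int32_t *y_packed, const Tdata *bias,
                        const int8_t *x_packed, const float *x_scale,
                        const int8_t *w_packed, const float *w_scale,
                        int M, int K, int N) {
    postSymKernel<Tdata>(y, y_packed, bias, x_packed, x_scale, w_packed, w_scale, M, K, N);
}

template <typename Tdata>
void launchPostSym(Tdata *y, int32_t *y_packed, const Tdata *bias,
                   const int8_t *x_packed, const float *x_scale,
                   const int8_t *w_packed, const float *w_scale,
                   int M, int K, int N, cudaStream_t stream) {
    dim3 block(16, 16);
    // grid.x covers the N columns, grid.y covers the M rows, matching the kernel's indexing
    dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
    postSym<Tdata><<<grid, block, 0, stream>>>(
        y, y_packed, bias, x_packed, x_scale, w_packed, w_scale, M, K, N);
}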
@@ -4,43 +4,48 @@
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::i8gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
size_t _workspace_size; \
I8GemmInfo _info; \
infiniDtype_t _out_dtype; \
\
Descriptor(Opaque *opaque, I8GemmInfo info, \
size_t workspace_size, \
infiniDtype_t out_dtype, \
infiniDevice_t device_type, int device_id) \
: InfiniopDescriptor{device_type, device_id}, _out_dtype(out_dtype), \
_opaque(opaque), _info(info), _workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t minWorkspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t bias_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t a_scale_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t b_scale_desc); \
template <unsigned int BLOCK_SIZE, typename Tdata> \
infiniStatus_t launchKernel(const I8GemmInfo &info, Tdata *y, \
const Tdata *bias, const int8_t *x_packed, \
const float *x_scale, const int8_t *w_packed, \
const float *w_scale, void *stream, void *workspace) const; \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *bias, const void *a, \
const void *a_scale, const void *b, \
const void *b_scale, void *stream) const; \
}; \
}
#endif // __I8GEMM_H__