Unverified Commit e60985dc authored by thatPepe's avatar thatPepe Committed by GitHub

Merge pull request #1040 from InfiniTensor/Issue/1030

Issue/1030: Add w4a16 inference support on Nvidia
parents 58771213 63233f9b
......@@ -8,12 +8,23 @@ class AWQ : public BaseQuantization {
// information and support multiple quantization schemes.
public:
explicit AWQ(const nlohmann::json &quant_config)
: BaseQuantization(quant_config) {};
: BaseQuantization(quant_config){};
infinicore::quantization::QuantScheme
get_quant_scheme() const override {
return infinicore::quantization::QuantScheme::AWQ_W4A16;
};
int get_packing_num() const {
// For AWQ, multiple low-bit weights are packed into a single int32 container.
return 32 / this->get_or<int>("bits", 4); // bits defaults to 4, so 32 / 4 = 8 int4 weights per int32
}
int get_group_size() const {
// Read the group size from quant_config_, falling back to the standard AWQ default.
return this->get_or<int>("group_size", 128); // 128 is the standard AWQ group size
}
};
} // namespace infinicore::quantization
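To make the packing arithmetic concrete: with the default bits = 4, one 32-bit container holds 32 / 4 = 8 quantized weights, which is why get_packing_num() returns 8. A minimal sketch of unpacking one container follows; it assumes a plain low-to-high nibble order, whereas real AWQ kernels (e.g., AutoAWQ) store the nibbles in an interleaved order, so this is illustrative only.

#include <array>
#include <cstdint>

// Illustrative only: extract 8 unsigned int4 values from one packed int32,
// assuming plain low-to-high nibble order (AutoAWQ actually interleaves).
std::array<uint8_t, 8> unpack_int4(uint32_t packed) {
    std::array<uint8_t, 8> values{};
    for (int i = 0; i < 8; ++i) {
        values[i] = (packed >> (4 * i)) & 0xF; // each value lies in [0, 15]
    }
    return values;
}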
......@@ -6,10 +6,34 @@ namespace infinicore::quantization {
class BaseQuantization {
// Base class for quantization schemes. Intended to be extended to support various quantization methods.
public:
explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config) {};
explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config){};
virtual ~BaseQuantization() = default;
virtual infinicore::quantization::QuantScheme get_quant_scheme() const = 0;
template <typename T>
T get(const std::string &key) const {
if (!quant_config_.contains(key)) {
throw std::out_of_range("Key '" + key + "' not found in config.");
}
try {
return quant_config_.at(key).get<T>();
} catch (const nlohmann::json::type_error &e) {
throw std::runtime_error("Type conversion failed for key '" + key + "': " + std::string(e.what()));
}
}
template <typename T>
T get_or(const std::string &key, const T &default_value) const {
if (!quant_config_.contains(key) || quant_config_.at(key).is_null()) {
return default_value;
}
try {
return quant_config_.at(key).get<T>();
} catch (const nlohmann::json::type_error &) {
// If type conversion fails, return default value
return default_value;
}
}
protected:
nlohmann::json quant_config_;
......
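A quick usage sketch for the get/get_or accessors above; the JSON literal is a hypothetical quantization config, not taken from any specific model:

#include <nlohmann/json.hpp>

// Hypothetical AWQ config; the keys mirror the ones the classes read.
auto quant_config = nlohmann::json::parse(R"({
    "bits": 4,
    "group_size": 128
})");

infinicore::quantization::AWQ awq(quant_config);
int bits = awq.get<int>("bits");                     // 4; get() throws if the key is missing
int group_size = awq.get_or<int>("group_size", 128); // 128, read from the config
int version = awq.get_or<int>("version", 1);         // key absent, so the default 1 is returned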
......@@ -3,6 +3,7 @@
#include "infinicore/ops.hpp"
#include "infinicore/ops/distributed/allreduce.hpp"
#include "infinicore/ops/linear.hpp"
#include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/linear_w8a8i8.hpp"
#include <optional>
#include <spdlog/spdlog.h>
......@@ -43,6 +44,15 @@ Tensor BaseLinear::compute_linear(Tensor &input) const {
auto output = infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight_packed_tensor, weight_scale_tensor, bias_opt);
return output;
}
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous();
Tensor qweight = static_cast<const Tensor &>(weight_);
Tensor qzeros = static_cast<const Tensor &>(weight_zeros_);
Tensor scales = static_cast<const Tensor &>(weight_scale_);
std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
auto output = infinicore::op::linear_w4a16_awq(input_contiguous, qweight, scales, qzeros, bias_opt);
return output;
}
default: {
// Ensure input is contiguous before creating views (required for matmul)
// This prevents hanging when input tensor has non-contiguous memory layout
......@@ -116,6 +126,20 @@ Linear::Linear(size_t in_features, size_t out_features,
}
break;
}
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_);
int group_size = awq_ptr->get_group_size();
int packing_num = awq_ptr->get_packing_num();
// Packed weight: [in_features, out_features / packing_num], int4 values packed into I32 containers
weight_ = infinicore::nn::Parameter({in_features, out_features / packing_num}, infinicore::DataType::I32, device);
this->register_parameter("qweight", weight_);
// Zero points: [in_features / group_size, out_features / packing_num], packed into I32 like the weights
weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features / packing_num}, infinicore::DataType::I32, device);
this->register_parameter("qzeros", weight_zeros_);
// Scales: [in_features / group_size, out_features], one scale per group, stored in the model dtype
weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features}, dtype_, device);
this->register_parameter("scales", weight_scale_);
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
} else {
bias_ = Parameter();
}
break;
}
default: {
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));
......@@ -190,6 +214,39 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
}
break;
}
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_);
int group_size = awq_ptr->get_group_size();
int packing_num = awq_ptr->get_packing_num();
// Packed weight: [in_features, out_features / packing_num], sharded along dim 1 (the out_features axis)
weight_ = infinicore::nn::Parameter({in_features, out_features / packing_num},
infinicore::DataType::I32,
device, 1, tp_rank_, tp_size_);
this->register_parameter("qweight", weight_);
// Weight scale: [in_features / group_size, out_features]
// One scale per group of group_size weights, stored in the model dtype
weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features},
dtype_,
device, 1, tp_rank_, tp_size_);
this->register_parameter("scales", weight_scale_);
// Weight zeros (zero points): [in_features / group_size, out_features / packing_num]
// AWQ implementations (e.g., AutoAWQ) pack zero points into I32 containers the same way as the weights
weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features / packing_num},
infinicore::DataType::I32,
device, 1, tp_rank_, tp_size_);
this->register_parameter("qzeros", weight_zeros_);
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
} else {
bias_ = Parameter();
}
break;
}
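As a sanity check on these shard shapes (the dimensions below are hypothetical, and the split axis is inferred from the dim argument passed to Parameter above): with in_features = 4096, out_features = 8192, packing_num = 8, group_size = 128 and tp_size = 2, each rank should hold the tensors computed below.

#include <cstddef>
#include <cstdio>

int main() {
    // Hypothetical layer dimensions for a column-parallel AWQ linear.
    std::size_t in_features = 4096, out_features = 8192;
    std::size_t packing_num = 8, group_size = 128, tp_size = 2;

    // Column parallelism shards the out_features axis across tp_size ranks.
    std::size_t out_shard = out_features / tp_size; // 4096 columns per rank

    std::printf("qweight per rank: [%zu, %zu]\n", in_features, out_shard / packing_num);              // [4096, 512]
    std::printf("scales  per rank: [%zu, %zu]\n", in_features / group_size, out_shard);               // [32, 4096]
    std::printf("qzeros  per rank: [%zu, %zu]\n", in_features / group_size, out_shard / packing_num); // [32, 512]
    return 0;
}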
default: {
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
......@@ -261,6 +318,44 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, st
}
break;
}
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
// AWQ W4A16 for RowParallelLinear: the partitioned dimension is in_features (dim 0 of the packed weight)
// - Weight: packed int4 in I32 containers (packing_num int4 values per I32)
// - Group-wise quantization, with group_size read from the config (default 128)
// - Scales and zero points stored per group along the in_features dimension
auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_);
int group_size = awq_ptr->get_group_size();
int packing_num = awq_ptr->get_packing_num();
// Packed weight: [in_features, out_features / packing_num]
weight_ = infinicore::nn::Parameter({in_features, out_features / packing_num},
infinicore::DataType::I32,
device, 0, tp_rank_, tp_size_);
this->register_parameter("qweight", weight_);
// Weight scale: [in_features / group_size, out_features]
weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features},
dtype_,
device, 0, tp_rank_, tp_size_);
this->register_parameter("scales", weight_scale_);
// Weight zeros (zero points): [in_features / group_size, out_features / packing_num]
weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features / packing_num},
infinicore::DataType::I32,
device, 0, tp_rank_, tp_size_);
this->register_parameter("qzeros", weight_zeros_);
// Bias handling in RowParallelLinear:
// - Only rank 0 holds the full bias (after all-reduce on output)
// - Other ranks have empty bias parameter
if (bias && (0 == tp_rank_)) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
} else {
bias_ = Parameter();
}
break;
}
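The rank-0-only bias follows from the row-parallel math: each rank computes a partial product over its slice of in_features, the all-reduce sums the partials, so the bias must enter the sum exactly once. A schematic sketch of that flow, assuming the bias is folded into rank 0's partial before the reduction (the helper names matmul_, add_, and allreduce_sum_ are hypothetical stand-ins for the actual ops):

// Schematic only. Rank r holds an input slice x_r of shape [M, in_features / tp_size]
// and the matching weight rows W_r, so the full output is
//   y = sum_r(x_r @ W_r) + b.
Tensor row_parallel_forward(const Tensor &x_r, const Tensor &W_r,
                            const std::optional<Tensor> &bias_rank0) {
    Tensor partial = matmul_(x_r, W_r);    // per-rank partial: [M, out_features]
    if (bias_rank0.has_value()) {          // only rank 0 holds a non-empty bias
        add_(partial, bias_rank0.value()); // fold the bias into exactly one partial
    }
    allreduce_sum_(partial);               // every rank now holds y = sum of partials + b
    return partial;
}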
default: {
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
......
#include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/dequantize_awq.hpp"
#include "infinicore/ops/gemm.hpp"
#include "infinicore/ops/rearrange.hpp"
namespace infinicore::op {
Tensor linear_w4a16_awq(Tensor input,
......@@ -12,7 +12,8 @@ Tensor linear_w4a16_awq(Tensor input,
// Input is of shape [M, K]; weight_packed is of shape [K, N / 8], contiguous, with 8 int4 values packed per I32
Size ndim = input->ndim();
Size out_features = weight_packed->shape()[0];
Size element_size = weight_packed->element_size();
// Each packed element stores element_size * 2 int4 values (two per byte): 8 per I32
Size out_features = weight_packed->shape()[1] * element_size * 2;
// Assign memory to out variables
auto output_shape = input->shape();
......@@ -33,7 +34,7 @@ void linear_w4a16_awq_(Tensor out,
auto weight_packed_shape = weight_packed->shape();
Size out_features = weight_packed_shape[0];
Size in_features = weight_packed_shape[1];
// Packed layout is [in_features, out_features / 8]: each I32 holds 8 int4 values
Size in_features = weight_packed_shape[0];
Size out_features = weight_packed_shape[1] * 8;
Size ndim = input->ndim();
assert(out->ndim() == ndim);
......@@ -43,7 +44,6 @@ void linear_w4a16_awq_(Tensor out,
for (size_t i = 0; i < ndim - 1; ++i) {
N *= input_shape[i];
}
auto weight = Tensor::empty(
{in_features, out_features}, // dequantized weight layout: [in_features, out_features]
out->dtype(),
......@@ -51,10 +51,14 @@ void linear_w4a16_awq_(Tensor out,
float alpha = 1.0f;
float beta = 0.0f;
op::dequantize_awq_(weight, weight_packed, weight_scale, weight_zeros);
bias = std::make_optional(bias.value()->as_strided({N, out_features}, {0, 1}));
gemm_(out->view({N, out_features}),
input->view({N, in_features}),
weight->permute({1, 0}), alpha, beta);
if (bias.has_value()) {
// Broadcast the bias across the batch rows so the gemm can accumulate onto it
rearrange_(out->view({N, out_features}),
bias.value()->as_strided({N, out_features}, {0, 1}));
beta = 1.0f;
}
gemm_(out->view({N, out_features}),
input->view({N, in_features}),
weight, alpha, beta);
}
} // namespace infinicore::op
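For orientation, here is a scalar reference of what dequantize_awq_ computes per element: w[k][n] = (q - z) * s, with one scale/zero pair per (group_size rows, column) group. This sketch is illustrative only: it assumes plain low-to-high nibble packing and ignores the interleaved nibble order that real AWQ kernels such as AutoAWQ use, and the helper name is hypothetical.

#include <cstdint>
#include <vector>

// Illustrative scalar AWQ dequantization: w[k][n] = (q - z) * s.
// Assumes plain low-to-high nibble packing; real AWQ kernels interleave.
std::vector<float> dequantize_awq_ref(const std::vector<uint32_t> &qweight, // [K][N/8], row-major
                                      const std::vector<float> &scales,     // [K/group_size][N]
                                      const std::vector<uint32_t> &qzeros,  // [K/group_size][N/8]
                                      size_t K, size_t N, size_t group_size) {
    std::vector<float> w(K * N);
    for (size_t k = 0; k < K; ++k) {
        for (size_t n = 0; n < N; ++n) {
            uint32_t qw = qweight[k * (N / 8) + n / 8];
            uint32_t qz = qzeros[(k / group_size) * (N / 8) + n / 8];
            int q = (qw >> (4 * (n % 8))) & 0xF;        // quantized weight nibble
            int z = (qz >> (4 * (n % 8))) & 0xF;        // zero-point nibble for the group
            float s = scales[(k / group_size) * N + n]; // per-group, per-column scale
            w[k * N + n] = static_cast<float>(q - z) * s;
        }
    }
    return w;
}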