Unverified Commit e60985dc authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #1040 from InfiniTensor/Issue/1030

Issue/1030: Nvidia 支持w4a16推理
parents 58771213 63233f9b
...@@ -8,12 +8,23 @@ class AWQ : public BaseQuantization { ...@@ -8,12 +8,23 @@ class AWQ : public BaseQuantization {
// information and support multiple quantization schemes. // information and support multiple quantization schemes.
public: public:
explicit AWQ(const nlohmann::json &quant_config) explicit AWQ(const nlohmann::json &quant_config)
: BaseQuantization(quant_config) {}; : BaseQuantization(quant_config){};
infinicore::quantization::QuantScheme infinicore::quantization::QuantScheme
get_quant_scheme() const override { get_quant_scheme() const override {
return infinicore::quantization::QuantScheme::AWQ_W4A16; return infinicore::quantization::QuantScheme::AWQ_W4A16;
}; };
int get_packing_num() const {
// For AWQ, we pack 8 int4 weights into a single int32 value.
return 32 / this->get_or<int>("bits", 4); // Default to 8 if not specified in config
}
int get_group_size() const {
// For simplicity, we return a fixed group size here. In a more complete implementation,
// this could be extracted from quant_config_ to support different group sizes.
return this->get_or<int>("group_size", 128); // Standard AWQ group size
}
}; };
} // namespace infinicore::quantization } // namespace infinicore::quantization
...@@ -6,10 +6,34 @@ namespace infinicore::quantization { ...@@ -6,10 +6,34 @@ namespace infinicore::quantization {
class BaseQuantization { class BaseQuantization {
// Base class for quantization schemes. Intended to be extended to support various quantization methods. // Base class for quantization schemes. Intended to be extended to support various quantization methods.
public: public:
explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config) {}; explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config){};
virtual ~BaseQuantization() = default; virtual ~BaseQuantization() = default;
virtual infinicore::quantization::QuantScheme get_quant_scheme() const = 0; virtual infinicore::quantization::QuantScheme get_quant_scheme() const = 0;
/// Fetch a required value from the quantization config.
/// @tparam T  target type for the stored JSON value.
/// @param key config key to look up.
/// @return the value converted to T.
/// @throws std::out_of_range  if the key is absent.
/// @throws std::runtime_error if the value cannot be converted to T.
template <typename T>
T get(const std::string &key) const {
    // Single lookup via find() instead of contains() followed by at().
    const auto it = quant_config_.find(key);
    if (it == quant_config_.end()) {
        throw std::out_of_range("Key '" + key + "' not found in config.");
    }
    try {
        return it->template get<T>();
    } catch (const nlohmann::json::type_error &e) {
        throw std::runtime_error("Type conversion failed for key '" + key + "': " + std::string(e.what()));
    }
}
/// Fetch an optional value from the quantization config.
/// @tparam T  target type for the stored JSON value.
/// @param key config key to look up.
/// @param default_value returned when the key is absent, the stored
///        value is null, or the value cannot be converted to T.
/// @return the converted value, or default_value.
template <typename T>
T get_or(const std::string &key, const T &default_value) const {
    // Single lookup via find() instead of repeated contains()/at() calls.
    const auto it = quant_config_.find(key);
    if (it == quant_config_.end() || it->is_null()) {
        return default_value;
    }
    try {
        return it->template get<T>();
    } catch (const nlohmann::json::type_error &) {
        // Type mismatch is tolerated for optional keys: the default wins.
        return default_value;
    }
}
protected: protected:
nlohmann::json quant_config_; nlohmann::json quant_config_;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include "infinicore/ops/distributed/allreduce.hpp" #include "infinicore/ops/distributed/allreduce.hpp"
#include "infinicore/ops/linear.hpp" #include "infinicore/ops/linear.hpp"
#include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/linear_w8a8i8.hpp" #include "infinicore/ops/linear_w8a8i8.hpp"
#include <optional> #include <optional>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
...@@ -43,6 +44,15 @@ Tensor BaseLinear::compute_linear(Tensor &input) const { ...@@ -43,6 +44,15 @@ Tensor BaseLinear::compute_linear(Tensor &input) const {
auto output = infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight_packed_tensor, weight_scale_tensor, bias_opt); auto output = infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight_packed_tensor, weight_scale_tensor, bias_opt);
return output; return output;
} }
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous();
Tensor qweight = static_cast<const Tensor &>(weight_);
Tensor qzeros = static_cast<const Tensor &>(weight_zeros_);
Tensor scales = static_cast<const Tensor &>(weight_scale_);
std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
auto output = infinicore::op::linear_w4a16_awq(input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt);
return output;
}
default: { default: {
// Ensure input is contiguous before creating views (required for matmul) // Ensure input is contiguous before creating views (required for matmul)
// This prevents hanging when input tensor has non-contiguous memory layout // This prevents hanging when input tensor has non-contiguous memory layout
...@@ -116,6 +126,20 @@ Linear::Linear(size_t in_features, size_t out_features, ...@@ -116,6 +126,20 @@ Linear::Linear(size_t in_features, size_t out_features,
} }
break; break;
} }
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
weight_ = infinicore::nn::Parameter({out_features, in_features}, infinicore::DataType::I32, device);
this->register_parameter("qweight", weight_);
weight_zeros_ = infinicore::nn::Parameter({out_features, in_features}, infinicore::DataType::I32, device);
this->register_parameter("qzeros", weight_zeros_);
weight_scale_ = infinicore::nn::Parameter({out_features, in_features}, dtype_, device);
this->register_parameter("scales", weight_scale_);
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
} else {
bias_ = Parameter();
}
break;
}
default: { default: {
// Initialize parameters using macro // Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device)); INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));
...@@ -190,6 +214,39 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur ...@@ -190,6 +214,39 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
} }
break; break;
} }
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_);
int group_size = awq_ptr->get_group_size();
int packing_num = awq_ptr->get_packing_num();
weight_ = infinicore::nn::Parameter({in_features, out_features / packing_num},
infinicore::DataType::I32,
device, 1, tp_rank_, tp_size_);
this->register_parameter("qweight", weight_);
// Weight scale: [out_features, in_features / group_size]
// One FP32 scale per group of weights (group_size=128)
weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features},
dtype_,
device, 1, tp_rank_, tp_size_);
this->register_parameter("scales", weight_scale_);
// Weight zeros (zero points): [out_features, in_features / group_size]
// AWQ implementations (e.g., AutoAWQ) typically store zero points as I32
// for symmetric/asymmetric quantization support
weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features / packing_num},
infinicore::DataType::I32,
device, 1, tp_rank_, tp_size_);
this->register_parameter("qzeros", weight_zeros_);
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
} else {
bias_ = Parameter();
}
break;
}
default: { default: {
// Initialize parameters using macro // Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device, INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
...@@ -261,6 +318,44 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, st ...@@ -261,6 +318,44 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, st
} }
break; break;
} }
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
// AWQ W4A16 for RowParallelLinear:切分维度为 in_features(权重矩阵的第1维)
// - Weight: packed int4 in I32 containers (8 int4 per I32)
// - Group-wise quantization with group_size=128
// - Scale and zero points stored per group along in_features dimension
auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_);
int group_size = awq_ptr->get_group_size();
int packing_num = awq_ptr->get_packing_num();
// Packed weight: [out_features, in_features / 8]
weight_ = infinicore::nn::Parameter({in_features, out_features / packing_num},
infinicore::DataType::I32,
device, 0, tp_rank_, tp_size_);
this->register_parameter("qweight", weight_);
// Weight scale: [out_features, in_features / group_size]
weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features},
dtype_,
device, 0, tp_rank_, tp_size_);
this->register_parameter("scales", weight_scale_);
// Weight zeros (zero points): [out_features, in_features / group_size]
weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features / packing_num},
infinicore::DataType::I32,
device, 0, tp_rank_, tp_size_);
this->register_parameter("qzeros", weight_zeros_);
// Bias handling in RowParallelLinear:
// - Only rank 0 holds the full bias (after all-reduce on output)
// - Other ranks have empty bias parameter
if (bias && (0 == tp_rank_)) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
} else {
bias_ = Parameter();
}
break;
}
default: { default: {
// Initialize parameters using macro // Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device, INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
......
#include "infinicore/ops/linear_w4a16_awq.hpp" #include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/dequantize_awq.hpp" #include "infinicore/ops/dequantize_awq.hpp"
#include "infinicore/ops/gemm.hpp" #include "infinicore/ops/gemm.hpp"
#include "infinicore/ops/rearrange.hpp"
namespace infinicore::op { namespace infinicore::op {
Tensor linear_w4a16_awq(Tensor input, Tensor linear_w4a16_awq(Tensor input,
...@@ -12,7 +12,8 @@ Tensor linear_w4a16_awq(Tensor input, ...@@ -12,7 +12,8 @@ Tensor linear_w4a16_awq(Tensor input,
// Input is of shape [M, K], Weight_packed is of shape [N, K],stirdes is [N, 1] // Input is of shape [M, K], Weight_packed is of shape [N, K],stirdes is [N, 1]
Size ndim = input->ndim(); Size ndim = input->ndim();
Size out_features = weight_packed->shape()[0]; Size element_size = weight_packed->element_size();
Size out_features = weight_packed->shape()[1] * element_size * 2;
// Assign memory to out variables // Assign memory to out variables
auto output_shape = input->shape(); auto output_shape = input->shape();
...@@ -33,7 +34,7 @@ void linear_w4a16_awq_(Tensor out, ...@@ -33,7 +34,7 @@ void linear_w4a16_awq_(Tensor out,
auto weight_packed_shape = weight_packed->shape(); auto weight_packed_shape = weight_packed->shape();
Size out_features = weight_packed_shape[0]; Size out_features = weight_packed_shape[0];
Size in_features = weight_packed_shape[1]; Size in_features = weight_packed_shape[1] * 8;
Size ndim = input->ndim(); Size ndim = input->ndim();
assert(out->ndim() == ndim); assert(out->ndim() == ndim);
...@@ -43,7 +44,6 @@ void linear_w4a16_awq_(Tensor out, ...@@ -43,7 +44,6 @@ void linear_w4a16_awq_(Tensor out,
for (size_t i = 0; i < ndim - 1; ++i) { for (size_t i = 0; i < ndim - 1; ++i) {
N *= input_shape[i]; N *= input_shape[i];
} }
auto weight = Tensor::empty( auto weight = Tensor::empty(
{out_features, in_features}, {out_features, in_features},
out->dtype(), out->dtype(),
...@@ -51,10 +51,14 @@ void linear_w4a16_awq_(Tensor out, ...@@ -51,10 +51,14 @@ void linear_w4a16_awq_(Tensor out,
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
op::dequantize_awq_(weight, weight_packed, weight_scale, weight_zeros); op::dequantize_awq_(weight, weight_packed, weight_scale, weight_zeros);
bias = std::make_optional(bias.value()->as_strided({N, out_features}, {0, 1})); if (bias.has_value()) {
gemm_(out->view({N, out_features}), rearrange_(out,
input->view({N, in_features}), bias.value()->as_strided({N, in_features}, {0, 1}));
weight->permute({1, 0}), alpha, beta); beta = 1.0f;
}
gemm_(out,
input->view({N, out_features}),
weight, alpha, beta);
} }
} // namespace infinicore::op } // namespace infinicore::op
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment