Commit a179dcc3 authored by PanZezhong1725
Browse files

Rename weight-registration method `resigter` to `register_weight`

parent 3998658b
......@@ -42,7 +42,7 @@ Loader::Loader(infiniDevice_t dev, const std::vector<int> &dev_ids) : _device(de
RUN_INFINI(infinirtStreamCreate(&_streams[rank]));
}
}
void Loader::resigter(const std::string &name, std::shared_ptr<Tensor> tensor, int rank, DistributionType dist_type) {
// Registers a weight tensor under `name` for the given device rank.
// Wraps the tensor in a Weight (together with the rank, the total number of
// devices, and the distribution type) and stores it in the per-rank map.
// NOTE(review): assumes _weights_maps[rank] is a valid map slot for this
// rank — verify ranks are bounded by the dev_ids passed to the constructor.
void Loader::register_weight(const std::string &name, std::shared_ptr<Tensor> tensor, int rank, DistributionType dist_type) {
_weights_maps[rank][name] = std::make_shared<Weight>(tensor, rank, _dev_ids.size(), dist_type);
}
void Loader::load(const std::string &name, const void *host_data) {
......
......@@ -48,7 +48,7 @@ public:
/// @param tensor
/// @param rank the rank of the weight tensor (default 0)
/// @param dist_type either FULL, or distributed by ROW or COLUMN (default FULL)
void resigter(const std::string &name, std::shared_ptr<Tensor> tensor, int rank = 0, DistributionType dist_type = DistributionType::FULL);
void register_weight(const std::string &name, std::shared_ptr<Tensor> tensor, int rank = 0, DistributionType dist_type = DistributionType::FULL);
void load(const std::string &name, const void *host_data);
void finalize();
std::shared_ptr<Tensor> get(const std::string &name, int rank = 0);
......
......@@ -66,15 +66,15 @@ JiugeAWQWeights::JiugeAWQWeights(
_device_weights[i] = weight;
auto w_in_embd = Tensor::weight(nullptr, dt_logits, {dvoc, d});
this->resigter("model.embed_tokens.weight", w_in_embd, i);
this->register_weight("model.embed_tokens.weight", w_in_embd, i);
weight->w_in_embd = w_in_embd;
auto w_out_norm = Tensor::weight(nullptr, dt_norm_w, {d});
this->resigter("model.norm.weight", w_out_norm, i);
this->register_weight("model.norm.weight", w_out_norm, i);
weight->w_out_norm = w_out_norm;
auto w_out_embd = Tensor::weight(nullptr, dt_logits, {dvoc, d})->permute({1, 0});
this->resigter("lm_head.weight", w_out_embd, i);
this->register_weight("lm_head.weight", w_out_embd, i);
weight->w_out_embd = w_out_embd;
weight->sin_table = getSinTable(dctx, dh, meta->theta);
......@@ -84,7 +84,7 @@ JiugeAWQWeights::JiugeAWQWeights(
#define RIGISTER_LAYER_WEIGHT(W_NAME, W_VAR, W_SHAPE, W_DTYPE, W_DIST_TYPE) \
auto W_VAR = Tensor::weight(nullptr, W_DTYPE, W_SHAPE); \
this->resigter(W_NAME, W_VAR, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
this->register_weight(W_NAME, W_VAR, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
weight->W_VAR.push_back(W_VAR);
RIGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".input_layernorm.weight", w_attn_norm, {d}, dt_norm_w, FULL);
......@@ -92,11 +92,11 @@ JiugeAWQWeights::JiugeAWQWeights(
#define REGISTER_LAYER_QUANT_WEIGHT(W_NAME, W_VAR, W_IN, W_OUT, W_DIST_TYPE) \
auto W_VAR = std::make_shared<QuantInt4Weight>(); \
W_VAR->w = Tensor::weight(nullptr, INFINI_DTYPE_I32, {W_IN, (W_OUT)*nbit / 32}); \
this->resigter(W_NAME + ".qweight", W_VAR->w, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
this->register_weight(W_NAME + ".qweight", W_VAR->w, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
W_VAR->s = Tensor::weight(nullptr, INFINI_DTYPE_F16, {(W_IN) / quant_group_size, (W_OUT)}); \
this->resigter(W_NAME + ".scales", W_VAR->s, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
this->register_weight(W_NAME + ".scales", W_VAR->s, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
W_VAR->z = Tensor::weight(nullptr, INFINI_DTYPE_I32, {(W_IN) / quant_group_size, (W_OUT)*nbit / 32}); \
this->resigter(W_NAME + ".qzeros", W_VAR->z, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
this->register_weight(W_NAME + ".qzeros", W_VAR->z, i, infinicore::weights::DistributionType::W_DIST_TYPE); \
weight->W_VAR.push_back(W_VAR);
REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.q_proj", w_attn_q, d, nh * dh, COLUMN);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment