Unverified commit 2abef3b7, authored by PanZezhong1725 and committed by GitHub

Merge pull request #118 from InfiniTensor/issue/114

issue/114: add fused QKVParallelLinear and GateUpParallelLinear layers
parents 81081f3c 4ea1f3b6
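
This PR fuses the Q/K/V and gate/up projections into single column-parallel GEMMs whose outputs are split back into per-projection views. A minimal usage sketch (the sizes below and the input tensor are illustrative assumptions, not taken from any real config):

    // Fused QKV projection: one GEMM, then three views over the last dimension.
    // `input` is assumed to be an infinicore::Tensor of shape [batch, seq_len, 4096].
    auto qkv = std::make_shared<infinilm::layers::QKVParallelLinear>(
        /*hidden_size=*/4096, /*head_dim=*/128,
        /*num_q_head=*/32, /*num_kv_head=*/8, /*bias=*/false);
    auto [q, k, v] = qkv->forward_split(input);

    // Fused gate/up projection for the MLP, split into two equal halves.
    auto gate_up = std::make_shared<infinilm::layers::GateUpParallelLinear>(
        /*hidden_size=*/4096, /*intermediate_size=*/11008);
    auto [gate, up] = gate_up->forward_split(input);
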
#include "fused_linear.hpp"
#include <spdlog/spdlog.h>
namespace infinilm::layers {
// ---------------------------------------------------------
// QKV Parallel Linear
// ---------------------------------------------------------
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                     size_t head_dim,
                                     size_t num_q_head,
                                     size_t num_kv_head,
                                     bool bias,
                                     const infinicore::DataType &dtype,
                                     const infinicore::Device &device,
                                     engine::distributed::RankInfo rank_info)
    : QKVParallelLinear(hidden_size,
                        head_dim, head_dim, head_dim,
                        num_q_head, num_kv_head, num_kv_head,
                        bias, bias, bias,
                        dtype, device, rank_info) {}
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                     size_t q_dim, size_t k_dim, size_t v_dim,
                                     size_t num_q_head, size_t num_k_head, size_t num_v_head,
                                     bool q_bias, bool k_bias, bool v_bias,
                                     const infinicore::DataType &dtype,
                                     const infinicore::Device &device,
                                     engine::distributed::RankInfo rank_info)
    : infinicore::nn::ColumnParallelLinear(
          hidden_size,
          num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim,
          (q_bias || k_bias || v_bias),
          dtype,
          device,
          rank_info.tp_rank,
          rank_info.tp_size),
      q_dim_(q_dim),
      k_dim_(k_dim),
      v_dim_(v_dim),
      num_q_head_(num_q_head),
      num_k_head_(num_k_head),
      num_v_head_(num_v_head),
      q_bias_(q_bias),
      k_bias_(k_bias),
      v_bias_(v_bias) {
    if (num_q_head % tp_size_ != 0 || num_k_head % tp_size_ != 0 || num_v_head % tp_size_ != 0) {
        throw std::runtime_error("QKVParallelLinear: num_[q|k|v]_head must be divisible by tp_size");
    }
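    // The base class allocates a single fused bias tensor when any section
    // requests one, so the per-section bias flags cannot currently differ.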
    if ((q_bias_ != k_bias_) || (k_bias_ != v_bias_)) {
        throw std::runtime_error("q_bias, k_bias, v_bias must all match");
    }
    q_out_size_ = num_q_head_ * q_dim_ / tp_size_;
    k_out_size_ = num_k_head_ * k_dim_ / tp_size_;
    v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
}
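// Run the fused projection once, then return Q, K and V as narrowed slices
// of the combined output along its last dimension.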
std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
QKVParallelLinear::forward_split(infinicore::Tensor &input) {
    auto output = this->forward(input);
    auto q_out = output->narrow({{2, 0, q_out_size_}});
    auto k_out = output->narrow({{2, q_out_size_, k_out_size_}});
    auto v_out = output->narrow({{2, q_out_size_ + k_out_size_, v_out_size_}});
    return std::make_tuple(q_out, k_out, v_out);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_weight() const {
    return infinicore::nn::Parameter(
        weight_->narrow({{0, 0, q_out_size_}}),
        0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_k_weight() const {
    return infinicore::nn::Parameter(
        weight_->narrow({{0, q_out_size_, k_out_size_}}),
        0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_v_weight() const {
    return infinicore::nn::Parameter(
        weight_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}}),
        0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_bias() const {
    if (!q_bias_) {
        return infinicore::nn::Parameter();
    }
    return infinicore::nn::Parameter(
        bias_->narrow({{0, 0, q_out_size_}}),
        0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_k_bias() const {
    if (!k_bias_) {
        return infinicore::nn::Parameter();
    }
    return infinicore::nn::Parameter(
        bias_->narrow({{0, q_out_size_, k_out_size_}}),
        0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_v_bias() const {
    if (!v_bias_) {
        return infinicore::nn::Parameter();
    }
    return infinicore::nn::Parameter(
        bias_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}}),
        0, tp_rank_, tp_size_);
}
bool QKVParallelLinear::has_q_bias() const { return q_bias_; }
bool QKVParallelLinear::has_k_bias() const { return k_bias_; }
bool QKVParallelLinear::has_v_bias() const { return v_bias_; }
// ---------------------------------------------------------
// Gate-Up Parallel Linear
// ---------------------------------------------------------
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias,
                                           const infinicore::DataType &dtype, const infinicore::Device &device,
                                           engine::distributed::RankInfo rank_info)
    : GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, dtype, device, rank_info) {}
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
                                           const infinicore::DataType &dtype, const infinicore::Device &device,
                                           engine::distributed::RankInfo rank_info)
    : infinicore::nn::ColumnParallelLinear(hidden_size, intermediate_size * 2, gate_bias || up_bias,
                                           dtype, device, rank_info.tp_rank, rank_info.tp_size),
      gate_bias_(gate_bias),
      up_bias_(up_bias) {
    if (gate_bias_ != up_bias_) {
        throw std::runtime_error("Not supported yet: gate_bias and up_bias must either both be set or both be unset");
    }
}
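// One fused GEMM for both projections; the two equal halves of the last
// dimension are the gate and up outputs respectively.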
std::tuple<infinicore::Tensor, infinicore::Tensor> GateUpParallelLinear::forward_split(infinicore::Tensor &input) {
    auto output = this->forward(input);
    auto cols = output->shape()[2];
    auto gate_output = output->narrow({{2, 0, cols / 2}});
    auto up_output = output->narrow({{2, cols / 2, cols / 2}});
    return std::make_tuple(gate_output, up_output);
}
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight() const {
    return infinicore::nn::Parameter(weight_->narrow({{0, 0, weight_->size(0) / 2}}), 0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_gate_bias() const {
    if (!gate_bias_) {
        return infinicore::nn::Parameter();
    }
    return infinicore::nn::Parameter(bias_->narrow({{0, 0, bias_->size(0) / 2}}), 0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight() const {
    return infinicore::nn::Parameter(weight_->narrow({{0, weight_->size(0) / 2, weight_->size(0) / 2}}), 0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_up_bias() const {
    if (!up_bias_) {
        return infinicore::nn::Parameter();
    }
    return infinicore::nn::Parameter(bias_->narrow({{0, bias_->size(0) / 2, bias_->size(0) / 2}}),
                                     0, tp_rank_, tp_size_);
}
bool GateUpParallelLinear::has_gate_bias() const { return gate_bias_; }
bool GateUpParallelLinear::has_up_bias() const { return up_bias_; }
} // namespace infinilm::layers
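
For concreteness, the per-rank slice sizes computed above work out as follows under illustrative settings (32 query heads, 8 key and 8 value heads, head_dim 128, tp_size 4; example numbers, not taken from this PR):

    q_out_size_ = 32 * 128 / 4 = 1024
    k_out_size_ =  8 * 128 / 4 =  256
    v_out_size_ =  8 * 128 / 4 =  256
    // fused per-rank output width: 1024 + 256 + 256 = 1536
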
#pragma once
#include "infinicore/nn/linear.hpp"
#include "../engine/distributed/communication_group.hpp"
namespace infinilm::layers {
class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear {
public:
    explicit QKVParallelLinear(size_t hidden_size,
                               size_t q_dim, size_t k_dim, size_t v_dim,
                               size_t num_q_head, size_t num_k_head, size_t num_v_head,
                               bool q_bias, bool k_bias, bool v_bias,
                               const infinicore::DataType &dtype = infinicore::DataType::F32,
                               const infinicore::Device &device = infinicore::Device(),
                               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
    // A more common case where all heads have the same dimension
    explicit QKVParallelLinear(size_t hidden_size,
                               size_t head_dim,
                               size_t num_q_head, size_t num_kv_head,
                               bool bias = false,
                               const infinicore::DataType &dtype = infinicore::DataType::F32,
                               const infinicore::Device &device = infinicore::Device(),
                               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
    forward_split(infinicore::Tensor &input);

    infinicore::nn::Parameter get_q_weight() const;
    infinicore::nn::Parameter get_k_weight() const;
    infinicore::nn::Parameter get_v_weight() const;
    infinicore::nn::Parameter get_q_bias() const;
    infinicore::nn::Parameter get_k_bias() const;
    infinicore::nn::Parameter get_v_bias() const;

    bool has_q_bias() const;
    bool has_k_bias() const;
    bool has_v_bias() const;

private:
    size_t q_dim_;
    size_t k_dim_;
    size_t v_dim_;
    size_t num_q_head_;
    size_t num_k_head_;
    size_t num_v_head_;
    bool q_bias_;
    bool k_bias_;
    bool v_bias_;
    size_t q_out_size_; // num_q_head * q_dim / tp_size
    size_t k_out_size_; // num_k_head * k_dim / tp_size
    size_t v_out_size_; // num_v_head * v_dim / tp_size
};
class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
public:
    GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias = false,
                         const infinicore::DataType &dtype = infinicore::DataType::F32,
                         const infinicore::Device &device = infinicore::Device(),
                         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
    GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
                         const infinicore::DataType &dtype = infinicore::DataType::F32,
                         const infinicore::Device &device = infinicore::Device(),
                         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    std::tuple<infinicore::Tensor, infinicore::Tensor> forward_split(infinicore::Tensor &input);

    infinicore::nn::Parameter get_gate_weight() const;
    infinicore::nn::Parameter get_gate_bias() const;
    infinicore::nn::Parameter get_up_weight() const;
    infinicore::nn::Parameter get_up_bias() const;

    bool has_gate_bias() const;
    bool has_up_bias() const;

private:
    bool gate_bias_;
    bool up_bias_;
};
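// The INFINILM_*_LINEAR_INIT macros below construct the fused layer and then
// register the per-section weight/bias views under the original per-projection
// parameter names (e.g. "q_proj.weight"), so parameters can still be addressed
// by their unfused names.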
#define INFINILM_QKV_LINEAR_INIT(name, q_name, k_name, v_name, ...) \
name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight()); \
this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight()); \
this->register_parameter(std::string(v_name) + ".weight", name##_->get_v_weight()); \
if (name##_->has_q_bias()) \
this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
if (name##_->has_k_bias()) \
this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
if (name##_->has_v_bias()) \
this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());
#define INFINILM_GATE_UP_LINEAR_INIT(name, gate_name, up_name, ...) \
name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight()); \
this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight()); \
if (name##_->has_gate_bias()) \
this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
if (name##_->has_up_bias()) \
this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
} // namespace infinilm::layers
@@ -42,12 +42,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
     }
     // Initialize projection layers
-    INFINICORE_NN_MODULE_INIT(q_proj, hidden_size_, hidden_size_, use_bias_,
-                              dtype, device, tp_rank, tp_size);
-    INFINICORE_NN_MODULE_INIT(k_proj, hidden_size_, kv_dim_, use_bias_,
-                              dtype, device, tp_rank, tp_size);
-    INFINICORE_NN_MODULE_INIT(v_proj, hidden_size_, kv_dim_, use_bias_,
-                              dtype, device, tp_rank, tp_size);
+    INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
+                             dtype, device, rank_info);
     // Output projection uses attention_output_bias (can be different from qkv)
     INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_,
                               dtype, device, tp_rank, tp_size, rank_info.comm);
@@ -66,11 +62,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
     size_t seq_len = shape[1];
     // 1. Project Q, K, V
-    auto q = q_proj_->forward(hidden_states_mutable); // [batch, seq_len, hidden_size]
-    auto k = k_proj_->forward(hidden_states_mutable); // [batch, seq_len, kv_dim]
-    auto v = v_proj_->forward(hidden_states_mutable); // [batch, seq_len, kv_dim]
+    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
     // 2. Reshape for multi-head attention
......
 #pragma once
-#include "cache/kv_cache.hpp"
 #include "infinicore/device.hpp"
+#include "../../cache/kv_cache.hpp"
+#include "../../engine/distributed/distributed.hpp"
+#include "../../layers/fused_linear.hpp"
 #include "llama_config.hpp"
 #include "infinicore/nn/linear.hpp"
 #include "infinicore/nn/module.hpp"
 #include "infinicore/nn/rope.hpp"
@@ -11,8 +14,6 @@
 #include <memory>
 #include <utility>
-#include "../../engine/distributed/distributed.hpp"
 namespace infinilm::models::llama {
/**
@@ -70,9 +71,7 @@
 protected:
     // Projection layers
-    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, q_proj);
-    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, k_proj);
-    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, v_proj);
+    INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
     INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);
     engine::distributed::RankInfo rank_info_;
......
@@ -16,10 +16,8 @@ LlamaMLP::LlamaMLP(const LlamaConfig &config,
     int tp_size = rank_info.tp_size;
     // Initialize projection layers
-    INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size_, intermediate_size_, use_bias_,
-                              dtype, device, tp_rank, tp_size);
-    INFINICORE_NN_MODULE_INIT(up_proj, hidden_size_, intermediate_size_, use_bias_,
-                              dtype, device, tp_rank, tp_size);
+    INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, use_bias_,
+                                 dtype, device, rank_info_);
     INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, use_bias_,
                               dtype, device, tp_rank, tp_size, rank_info.comm);
 }
@@ -27,9 +25,7 @@ LlamaMLP::LlamaMLP(const LlamaConfig &config,
 infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
     // 1. Project to gate and up
     auto hidden_states_mutable = hidden_states;
-    auto gate = gate_proj_->forward(hidden_states_mutable);
-    auto up = up_proj_->forward(hidden_states_mutable);
+    auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable);
     // 2. Apply SwiGLU: silu(gate) * up
     // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
......
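
The (up, gate) argument order noted above is easy to get backwards, so here is the elementwise semantics as a reference sketch (illustrative only, not the actual infinicore kernel; the function name swiglu_ref is hypothetical):

    #include <cmath>
    #include <vector>

    // Reference semantics of SwiGLU as used above: silu(gate) * up, elementwise.
    // Argument order mirrors the kernel note: (up, gate).
    std::vector<float> swiglu_ref(const std::vector<float> &up,
                                  const std::vector<float> &gate) {
        std::vector<float> out(gate.size());
        for (size_t i = 0; i < gate.size(); ++i) {
            float sig = 1.0f / (1.0f + std::exp(-gate[i])); // sigmoid(gate)
            out[i] = gate[i] * sig * up[i];                 // silu(gate) * up
        }
        return out;
    }
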
 #pragma once
+#include "../../layers/fused_linear.hpp"
 #include "llama_config.hpp"
 #include "infinicore/device.hpp"
 #include "infinicore/nn/linear.hpp"
 #include "infinicore/nn/module.hpp"
@@ -48,14 +51,10 @@ public:
     size_t intermediate_size() const { return intermediate_size_; }
 protected:
     // Projection layers
-    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, gate_proj);
-    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, up_proj);
+    INFINICORE_NN_MODULE(layers::GateUpParallelLinear, gate_up_proj);
     INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, down_proj);
     engine::distributed::RankInfo rank_info_;
 private:
     size_t hidden_size_;
     size_t intermediate_size_;
     bool use_bias_;
......
 #pragma once
-#include "cache/kv_cache.hpp"
 #include "infinicore/device.hpp"
 #include "llama_config.hpp"
 #include "llama_decoder_layer.hpp"
+#include "../../cache/kv_cache.hpp"
-#include "infinicore/nn/module.hpp"
 #include "infinicore/nn/embedding.hpp"
+#include "infinicore/nn/module.hpp"
 #include "infinicore/nn/rmsnorm.hpp"
......
@@ -201,6 +201,7 @@ class GenerationMixin:
         # -------------------------------------------------------------------------- #
         # prepare model inputs
         # -------------------------------------------------------------------------- #
+        start_time = time.time()
         model_inputs = self.prepare_inputs_for_generation(**model_kwargs)
         model_kwargs["position_ids"] = model_inputs["position_ids"]
@@ -208,8 +209,6 @@ class GenerationMixin:
         # -------------------------------------------------------------------------- #
         # run the computation once
         # -------------------------------------------------------------------------- #
-        start_time = time.time()
         logits = self(**model_inputs)
         infinicore.sync_device()
@@ -242,10 +241,6 @@ class GenerationMixin:
             )
             infinicore.sync_stream()  # must synchronize before the computation finishes
-            end_time = time.time()
-            time_list.append((end_time - start_time) * 1000)
             # ----------------------------------------------------------------- #
             # get the next token's id and decode it into text
             # ----------------------------------------------------------------- #
@@ -256,6 +251,9 @@ class GenerationMixin:
             output_tokens_list.append(token_id)
             output_content += output_str
+            end_time = time.time()
+            time_list.append((end_time - start_time) * 1000)
             print(output_str, end="", flush=True)
             if stop_on_eos and token_id in eos_token_id_list:
                 break
......
@@ -41,10 +41,10 @@ target("_infinilm")
     local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
-    add_includedirs("csrc", { public = false })
-    add_includedirs("csrc/pybind11", { public = false })
-    add_includedirs("include", { public = false })
+    -- add_includedirs("csrc", { public = false })
+    -- add_includedirs("csrc/pybind11", { public = false })
     add_includedirs(INFINI_ROOT.."/include", { public = true })
+    add_includedirs("include", { public = false })
+    -- spdlog is already included globally via add_includedirs at the top
     add_linkdirs(INFINI_ROOT.."/lib")
......