Commit 6998a8f1 authored by wooway777
Browse files

issue/21 - Improved linear and view implementations

parent 8853663e
......@@ -22,7 +22,7 @@ inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_
}
// Helper function to compute hash for tensor descriptors
inline size_t computeTensorDescHash(std::shared_ptr<Tensor> tensor) {
inline size_t computeTensorDescHash(std::shared_ptr<Tensor> &tensor) {
size_t seed = 0;
hash_combine(seed, tensor->dtype());
for (auto dim : tensor->shape()) {
......
......@@ -190,7 +190,8 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b,
float alpha, float beta,
std::shared_ptr<Tensor> residual) {
std::shared_ptr<Tensor> residual,
std::shared_ptr<Tensor> bias) {
if (residual) {
if (residual->data() == c->data()) {
if (beta == 0.0) {
......@@ -208,4 +209,13 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
} else {
gemm(c, a, b, alpha, beta);
}
// Optional bias term: when a bias tensor is supplied, view it with c's shape
// (zero stride on every leading dimension) and pass it to add() together with c.
if (bias) {
// Index of c's last dimension, which is also the count of leading dimensions.
// NOTE(review): if c->ndim() were 0 this becomes -1 — presumably callers never
// pass a rank-0 output; confirm.
int ndim_diff = c->ndim() - 1;
// bias must be 1-D and its length must equal c's last dimension.
ASSERT_EQ(bias->ndim(), 1);
ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
// Stride 0 on each leading dimension re-reads the same bias elements for
// every index there; the last dimension keeps bias's own stride.
std::vector<ptrdiff_t> strides(ndim_diff, 0);
strides.push_back(bias->strides()[0]);
// Presumably computes c = c + bias in place (c is both first and second
// argument) — confirm against add()'s contract.
add(c, c, bias->view_as(c->shape(), strides));
}
}
......@@ -47,7 +47,8 @@ struct InferenceContext {
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b,
float alpha, float beta,
std::shared_ptr<Tensor> residual);
std::shared_ptr<Tensor> residual,
std::shared_ptr<Tensor> bias);
};
namespace {
......@@ -103,6 +104,6 @@ inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> pr
// Free-function convenience wrapper: forwards every argument unchanged to
// linear() on the object returned by getInferenceContext().
// NOTE(review): this span shows both the pre-change lines (residual only) and
// the post-change lines (residual + bias) of the diff, one after the other.
inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, float alpha, float beta,
std::shared_ptr<Tensor> residual) {
getInferenceContext().linear(c, a, b, alpha, beta, residual);
std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
getInferenceContext().linear(c, a, b, alpha, beta, residual, bias);
}
......@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
auto result_cpu = std::vector<int64_t>(nreq);
auto qkv_rope = qkv_buf->viewReshaped({ntok, nh + nkvh * 2, dh});
auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
// Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok);
......@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto q_rearrange = rearrange_q_buf->viewReshaped({nkvh, ngroup, max_seq_len, dh});
auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh});
auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto attn_val_gemm = attn_val_buf->viewReshaped({nkvh, ngroup, max_seq_len, dh});
auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh});
// MLP buffers
auto gate_buf = gate_up_buf->slice(1, 0, di);
......@@ -197,10 +197,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// rms norm
rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
// qkv_proj
if (has_qkv_bias) {
rearrange(qkv_buf, rsrc.b_attn_qkv[layer]->view({ntok, (nh + nkvh * 2) * dh}, {0, 1}));
}
linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, has_qkv_bias ? qkv_buf : nullptr);
linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
// rope
rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
......@@ -210,7 +207,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto past_len = req_pos[req];
auto seq_len = req_lens[req];
auto total_len = past_len + seq_len;
auto o = o_buf->viewReshaped({ntok, nh, dh})->slice({{0, token_offset, seq_len}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
auto o = o_buf->view({ntok, nh, dh})->slice({{0, token_offset, seq_len}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
......@@ -221,14 +218,14 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk
rearrange(q_rearrange, q);
auto qk_gemm = qk_buf->viewReshaped({nkvh, ngroup * seq_len, total_len});
auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr);
linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr);
// softmax
auto qk_softmax = qk_buf->viewReshaped({nh, seq_len, total_len});
auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr);
linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
// rearrange attn val
rearrange(o, attn_val_gemm);
......@@ -236,7 +233,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
// o_proj
linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual
linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
// All_reduce if distributed
if (rsrc.comm != nullptr) {
......@@ -247,9 +244,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
// 2. FFN
rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr);
linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr, nullptr);
swiglu(gate_buf, up_buf, gate_buf);
linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual
linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
// All_reduce if distributed
if (rsrc.comm != nullptr) {
......@@ -270,15 +267,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rsrc.w_out_norm,
meta.epsilon);
}
linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr);
linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
std::random_device _rd;
std::mt19937 gen(_rd());
token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req];
float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
randomSample(result_buf->view({}, {}),
prob_buf->view({dvoc}, {1}),
randomSample(result_buf->memShare({}, result_buf->dtype()),
prob_buf->view_as({dvoc}, {1}),
random_val, topp[req], topk[req], temperature[req]);
token_offset += seq_len;
}
......
......@@ -128,9 +128,8 @@ public:
void debug() const;
std::string info() const;
std::shared_ptr<Tensor> view() const;
std::shared_ptr<Tensor> view(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) const;
std::shared_ptr<Tensor> viewReshaped(const std::vector<size_t> &new_shape) const;
std::shared_ptr<Tensor> view(const std::vector<size_t> &new_shape) const;
std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const;
~Tensor();
};
......
......@@ -258,23 +258,7 @@ std::string Tensor::info() const {
return this->_desc->info();
}
std::shared_ptr<Tensor> Tensor::view() const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), this->shape(), this->strides());
tensor->_offset = this->_offset;
return tensor;
}
std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
tensor->_offset = this->_offset;
return tensor;
}
std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> &new_shape) const {
std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const {
// Calculate total elements in current and new shape
size_t current_elements = std::accumulate(
_desc->shape().begin(), _desc->shape().end(),
......@@ -340,6 +324,14 @@ std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> &new_shap
return result;
}
std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
tensor->_offset = this->_offset;
return tensor;
}
void Tensor::debug(const std::string &filename) const {
RUN_INFINI(infinirtDeviceSynchronize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment