Commit 6998a8f1 authored by wooway777
Browse files

issue/21 - Improved linear and view implementations

parent 8853663e
...@@ -22,7 +22,7 @@ inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_ ...@@ -22,7 +22,7 @@ inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_
} }
// Helper function to compute hash for tensor descriptors // Helper function to compute hash for tensor descriptors
inline size_t computeTensorDescHash(std::shared_ptr<Tensor> tensor) { inline size_t computeTensorDescHash(std::shared_ptr<Tensor> &tensor) {
size_t seed = 0; size_t seed = 0;
hash_combine(seed, tensor->dtype()); hash_combine(seed, tensor->dtype());
for (auto dim : tensor->shape()) { for (auto dim : tensor->shape()) {
......
...@@ -190,7 +190,8 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c, ...@@ -190,7 +190,8 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, std::shared_ptr<Tensor> b,
float alpha, float beta, float alpha, float beta,
std::shared_ptr<Tensor> residual) { std::shared_ptr<Tensor> residual,
std::shared_ptr<Tensor> bias) {
if (residual) { if (residual) {
if (residual->data() == c->data()) { if (residual->data() == c->data()) {
if (beta == 0.0) { if (beta == 0.0) {
...@@ -208,4 +209,13 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c, ...@@ -208,4 +209,13 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
} else { } else {
gemm(c, a, b, alpha, beta); gemm(c, a, b, alpha, beta);
} }
if (bias) {
int ndim_diff = c->ndim() - 1;
ASSERT_EQ(bias->ndim(), 1);
ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
std::vector<ptrdiff_t> strides(ndim_diff, 0);
strides.push_back(bias->strides()[0]);
add(c, c, bias->view_as(c->shape(), strides));
}
} }
...@@ -47,7 +47,8 @@ struct InferenceContext { ...@@ -47,7 +47,8 @@ struct InferenceContext {
std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, std::shared_ptr<Tensor> b,
float alpha, float beta, float alpha, float beta,
std::shared_ptr<Tensor> residual); std::shared_ptr<Tensor> residual,
std::shared_ptr<Tensor> bias);
}; };
namespace { namespace {
...@@ -103,6 +104,6 @@ inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> pr ...@@ -103,6 +104,6 @@ inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> pr
inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, float alpha, float beta, std::shared_ptr<Tensor> b, float alpha, float beta,
std::shared_ptr<Tensor> residual) { std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
getInferenceContext().linear(c, a, b, alpha, beta, residual); getInferenceContext().linear(c, a, b, alpha, beta, residual, bias);
} }
...@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
auto result_cpu = std::vector<int64_t>(nreq); auto result_cpu = std::vector<int64_t>(nreq);
auto qkv_rope = qkv_buf->viewReshaped({ntok, nh + nkvh * 2, dh}); auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
// Prepare inputs // Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok); auto batch_pos_ids = std::vector<uint32_t>(ntok);
...@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool); auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto q_rearrange = rearrange_q_buf->viewReshaped({nkvh, ngroup, max_seq_len, dh}); auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh});
auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto attn_val_gemm = attn_val_buf->viewReshaped({nkvh, ngroup, max_seq_len, dh}); auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh});
// MLP buffers // MLP buffers
auto gate_buf = gate_up_buf->slice(1, 0, di); auto gate_buf = gate_up_buf->slice(1, 0, di);
...@@ -197,10 +197,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -197,10 +197,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// rms norm // rms norm
rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon); rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
// qkv_proj // qkv_proj
if (has_qkv_bias) { linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
rearrange(qkv_buf, rsrc.b_attn_qkv[layer]->view({ntok, (nh + nkvh * 2) * dh}, {0, 1}));
}
linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, has_qkv_bias ? qkv_buf : nullptr);
// rope // rope
rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
...@@ -210,7 +207,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -210,7 +207,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto past_len = req_pos[req]; auto past_len = req_pos[req];
auto seq_len = req_lens[req]; auto seq_len = req_lens[req];
auto total_len = past_len + seq_len; auto total_len = past_len + seq_len;
auto o = o_buf->viewReshaped({ntok, nh, dh})->slice({{0, token_offset, seq_len}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3}); auto o = o_buf->view({ntok, nh, dh})->slice({{0, token_offset, seq_len}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3}); auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
...@@ -221,14 +218,14 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -221,14 +218,14 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk // qk
rearrange(q_rearrange, q); rearrange(q_rearrange, q);
auto qk_gemm = qk_buf->viewReshaped({nkvh, ngroup * seq_len, total_len}); auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr); linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr);
// softmax // softmax
auto qk_softmax = qk_buf->viewReshaped({nh, seq_len, total_len}); auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax, qk_softmax); causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr); linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
// rearrange attn val // rearrange attn val
rearrange(o, attn_val_gemm); rearrange(o, attn_val_gemm);
...@@ -236,7 +233,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -236,7 +233,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
} }
// o_proj // o_proj
linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
// All_reduce if distributed // All_reduce if distributed
if (rsrc.comm != nullptr) { if (rsrc.comm != nullptr) {
...@@ -247,9 +244,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -247,9 +244,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
} }
// 2. FFN // 2. FFN
rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon); rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr); linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr, nullptr);
swiglu(gate_buf, up_buf, gate_buf); swiglu(gate_buf, up_buf, gate_buf);
linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
// All_reduce if distributed // All_reduce if distributed
if (rsrc.comm != nullptr) { if (rsrc.comm != nullptr) {
...@@ -270,15 +267,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -270,15 +267,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rsrc.w_out_norm, rsrc.w_out_norm,
meta.epsilon); meta.epsilon);
} }
linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr); linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
std::random_device _rd; std::random_device _rd;
std::mt19937 gen(_rd()); std::mt19937 gen(_rd());
token_offset = 0; token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) { for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req]; auto seq_len = req_lens[req];
float random_val = std::uniform_real_distribution<float>(0, 1)(gen); float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
randomSample(result_buf->view({}, {}), randomSample(result_buf->memShare({}, result_buf->dtype()),
prob_buf->view({dvoc}, {1}), prob_buf->view_as({dvoc}, {1}),
random_val, topp[req], topk[req], temperature[req]); random_val, topp[req], topk[req], temperature[req]);
token_offset += seq_len; token_offset += seq_len;
} }
......
...@@ -128,9 +128,8 @@ public: ...@@ -128,9 +128,8 @@ public:
void debug() const; void debug() const;
std::string info() const; std::string info() const;
std::shared_ptr<Tensor> view() const; std::shared_ptr<Tensor> view(const std::vector<size_t> &new_shape) const;
std::shared_ptr<Tensor> view(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) const; std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const;
std::shared_ptr<Tensor> viewReshaped(const std::vector<size_t> &new_shape) const;
~Tensor(); ~Tensor();
}; };
......
...@@ -258,23 +258,7 @@ std::string Tensor::info() const { ...@@ -258,23 +258,7 @@ std::string Tensor::info() const {
return this->_desc->info(); return this->_desc->info();
} }
std::shared_ptr<Tensor> Tensor::view() const { std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), this->shape(), this->strides());
tensor->_offset = this->_offset;
return tensor;
}
std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
tensor->_offset = this->_offset;
return tensor;
}
std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> &new_shape) const {
// Calculate total elements in current and new shape // Calculate total elements in current and new shape
size_t current_elements = std::accumulate( size_t current_elements = std::accumulate(
_desc->shape().begin(), _desc->shape().end(), _desc->shape().begin(), _desc->shape().end(),
...@@ -340,6 +324,14 @@ std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> &new_shap ...@@ -340,6 +324,14 @@ std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> &new_shap
return result; return result;
} }
std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
tensor->_offset = this->_offset;
return tensor;
}
void Tensor::debug(const std::string &filename) const { void Tensor::debug(const std::string &filename) const {
RUN_INFINI(infinirtDeviceSynchronize()); RUN_INFINI(infinirtDeviceSynchronize());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment