Commit 8853663e authored by wooway777's avatar wooway777
Browse files

issue/21 - Improved viewReshaped implementation and calls

parent dd5dec97
...@@ -141,6 +141,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -141,6 +141,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
auto result_cpu = std::vector<int64_t>(nreq); auto result_cpu = std::vector<int64_t>(nreq);
auto qkv_rope = qkv_buf->viewReshaped({ntok, nh + nkvh * 2, dh});
// Prepare inputs // Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok); auto batch_pos_ids = std::vector<uint32_t>(ntok);
size_t req_start = 0; size_t req_start = 0;
...@@ -181,7 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -181,7 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool); auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto q_rearrange = rearrange_q_buf->viewReshaped({nkvh, ngroup, max_seq_len, dh});
auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto attn_val_gemm = attn_val_buf->viewReshaped({nkvh, ngroup, max_seq_len, dh});
// MLP buffers // MLP buffers
auto gate_buf = gate_up_buf->slice(1, 0, di); auto gate_buf = gate_up_buf->slice(1, 0, di);
...@@ -198,7 +202,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -198,7 +202,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
} }
linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, has_qkv_bias ? qkv_buf : nullptr); linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, has_qkv_bias ? qkv_buf : nullptr);
// rope // rope
auto qkv_rope = qkv_buf->viewReshaped({ntok, nh + nkvh * 2, dh});
rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
...@@ -217,7 +220,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -217,7 +220,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k);
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk // qk
auto q_rearrange = rearrange_q_buf->viewReshaped({nkvh, ngroup, seq_len, dh});
rearrange(q_rearrange, q); rearrange(q_rearrange, q);
auto qk_gemm = qk_buf->viewReshaped({nkvh, ngroup * seq_len, total_len}); auto qk_gemm = qk_buf->viewReshaped({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
...@@ -228,7 +230,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -228,7 +230,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr); linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr);
// rearrange attn val // rearrange attn val
auto attn_val_gemm = attn_val_buf->viewReshaped({nkvh, ngroup, max_seq_len, dh});
rearrange(o, attn_val_gemm); rearrange(o, attn_val_gemm);
token_offset += seq_len; token_offset += seq_len;
......
...@@ -130,7 +130,7 @@ public: ...@@ -130,7 +130,7 @@ public:
std::shared_ptr<Tensor> view() const; std::shared_ptr<Tensor> view() const;
std::shared_ptr<Tensor> view(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) const; std::shared_ptr<Tensor> view(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) const;
std::shared_ptr<Tensor> viewReshaped(const std::vector<size_t> new_shape) const; std::shared_ptr<Tensor> viewReshaped(const std::vector<size_t> &new_shape) const;
~Tensor(); ~Tensor();
}; };
......
...@@ -274,23 +274,69 @@ std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> new_shape, const ...@@ -274,23 +274,69 @@ std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> new_shape, const
return tensor; return tensor;
} }
std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> new_shape) const { std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> &new_shape) const {
// Create a copy of the current shape and strides // Calculate total elements in current and new shape
auto current_shape = _desc->shape(); size_t current_elements = std::accumulate(
_desc->shape().begin(), _desc->shape().end(),
// Start with the current tensor 1, std::multiplies<size_t>());
auto result = this->view(); size_t new_elements = std::accumulate(
new_shape.begin(), new_shape.end(),
1, std::multiplies<size_t>());
ASSERT_EQ(current_elements, new_elements);
const auto &old_shape = _desc->shape();
const auto &old_strides = _desc->strides();
// Special case: empty tensor
if (current_elements == 0) {
auto result = std::make_shared<Tensor>();
result->_storage = this->_storage;
result->_desc = TensorDesc::create(this->dtype(), new_shape, {});
result->_offset = this->_offset;
return result;
}
// Step 1: Merge all dimensions (if there are more than 1) // Special case: scalar to scalar
if (current_shape.size() > 1) { if (old_shape.empty() && new_shape.empty()) {
result = result->dimMerge(0, current_shape.size() - 1); auto result = std::make_shared<Tensor>();
result->_storage = this->_storage;
result->_desc = this->_desc;
result->_offset = this->_offset;
return result;
} }
// Step 2: Split into the new shape // Compute new strides
if (new_shape.size() > 1) { std::vector<ptrdiff_t> new_strides;
result = result->dimSplit(0, new_shape); if (!new_shape.empty()) {
new_strides.resize(new_shape.size());
// Compute strides for the new shape while preserving memory layout
// Start from the rightmost dimension
new_strides.back() = old_strides.back();
for (int i = new_shape.size() - 2; i >= 0; --i) {
new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
}
// Verify the new strides are compatible with the old memory layout
size_t offset = 0;
for (size_t i = 0; i < old_shape.size(); ++i) {
offset += (old_shape[i] - 1) * old_strides[i];
}
size_t new_offset = 0;
for (size_t i = 0; i < new_shape.size(); ++i) {
new_offset += (new_shape[i] - 1) * new_strides[i];
}
ASSERT_EQ(offset, new_offset);
} }
// Create and return the reshaped tensor
auto result = std::make_shared<Tensor>();
result->_storage = this->_storage;
result->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
result->_offset = this->_offset;
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment