Commit 21ef8820 authored by wooway777's avatar wooway777
Browse files

issue/21 - fixed view() implementation

parent b3275d7c
...@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
auto result_cpu = std::vector<int64_t>(nreq); auto result_cpu = std::vector<int64_t>(nreq);
auto qkv_rope = qkv_buf->view_as({ntok, nh + nkvh * 2, dh}); auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
// Prepare inputs // Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok); auto batch_pos_ids = std::vector<uint32_t>(ntok);
...@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool); auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto q_rearrange = rearrange_q_buf->view_as({nkvh, ngroup, max_seq_len, dh}); auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh});
auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto attn_val_gemm = attn_val_buf->view_as({nkvh, ngroup, max_seq_len, dh}); auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh});
// MLP buffers // MLP buffers
auto gate_buf = gate_up_buf->slice(1, 0, di); auto gate_buf = gate_up_buf->slice(1, 0, di);
...@@ -207,7 +207,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -207,7 +207,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto past_len = req_pos[req]; auto past_len = req_pos[req];
auto seq_len = req_lens[req]; auto seq_len = req_lens[req];
auto total_len = past_len + seq_len; auto total_len = past_len + seq_len;
auto o = o_buf->slice({{0, token_offset, seq_len}})->view_as({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
...@@ -218,11 +218,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -218,11 +218,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk // qk
rearrange(q_rearrange, q); rearrange(q_rearrange, q);
auto qk_gemm = qk_buf->view_as({nkvh, ngroup * seq_len, total_len}); auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr); linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr);
// softmax // softmax
auto qk_softmax = qk_buf->view_as({nh, seq_len, total_len}); auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax, qk_softmax); causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr); linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
......
...@@ -273,128 +273,63 @@ size_t Tensor::seed() const { ...@@ -273,128 +273,63 @@ size_t Tensor::seed() const {
} }
std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const { std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const {
// Calculate total number of elements // Step 1: Validate total size
size_t numel = 1; size_t numel = 1;
for (auto s : shape()) { for (size_t dim : this->_desc->shape()) {
numel *= s; numel *= dim;
} }
size_t new_numel = 1; size_t new_numel = 1;
for (auto s : new_shape) { for (size_t dim : new_shape) {
new_numel *= s; new_numel *= dim;
} }
ASSERT(numel == new_numel); ASSERT_EQ(numel, new_numel);
// Handle empty tensors // Step 2: Get current shape and strides
if (numel == 0) { const std::vector<size_t> &old_shape = this->_desc->shape();
return this->view_as(new_shape, {}); const std::vector<ptrdiff_t> &old_strides = this->_desc->strides();
}
// Special case: view(-1) flattens the tensor // Step 3: Create merged shape and strides
if (new_shape.size() == 1 && new_shape[0] == static_cast<size_t>(-1)) { std::vector<size_t> merged_shape;
std::vector<size_t> flat_shape = {numel}; std::vector<ptrdiff_t> merged_strides;
return this->view_as(flat_shape, {});
}
// Check for -1 in new_shape (infer dimension) if (!old_shape.empty()) {
std::vector<size_t> inferred_shape = new_shape; merged_shape.push_back(old_shape[0]);
size_t infer_index = static_cast<size_t>(-1); merged_strides.push_back(old_strides[0]);
size_t known_elements = 1;
for (size_t i = 0; i < new_shape.size(); ++i) { for (size_t i = 1; i < old_shape.size(); ++i) {
if (new_shape[i] == static_cast<size_t>(-1)) { if (old_strides[i] * static_cast<ptrdiff_t>(old_shape[i]) == merged_strides.back()) {
ASSERT(infer_index == static_cast<size_t>(-1)); // Only one -1 allowed merged_shape.back() *= old_shape[i];
infer_index = i; merged_strides.back() = old_strides[i];
} else { } else {
known_elements *= new_shape[i]; merged_shape.push_back(old_shape[i]);
} merged_strides.push_back(old_strides[i]);
}
if (infer_index != static_cast<size_t>(-1)) {
ASSERT(numel % known_elements == 0);
inferred_shape[infer_index] = numel / known_elements;
}
// For contiguous tensors, compute standard row-major strides
if (this->isContigous()) {
std::vector<ptrdiff_t> new_strides(inferred_shape.size());
if (!inferred_shape.empty()) {
new_strides.back() = 1;
for (int i = static_cast<int>(inferred_shape.size()) - 2; i >= 0; --i) {
new_strides[i] = new_strides[i + 1] * static_cast<ptrdiff_t>(inferred_shape[i + 1]);
} }
} }
return this->view_as(inferred_shape, new_strides);
} }
// For non-contiguous tensors // Step 4: Compute new strides by splitting merged dimensions
std::vector<size_t> old_shape = shape(); std::vector<ptrdiff_t> new_strides(new_shape.size());
std::vector<ptrdiff_t> old_strides = strides(); size_t merged_idx = 0;
std::vector<ptrdiff_t> new_strides(inferred_shape.size(), 0); ptrdiff_t current_stride = merged_strides[0];
size_t remaining_size = merged_shape[0];
size_t old_idx = old_shape.size() - 1;
size_t new_idx = inferred_shape.size() - 1;
if (new_idx != static_cast<size_t>(-1)) { for (size_t i = 0; i < new_shape.size(); ++i) {
new_strides[new_idx] = 1; // Find which merged dimension contains this new dimension
} while (new_shape[i] > remaining_size) {
ASSERT(++merged_idx < merged_shape.size());
while (old_idx != static_cast<size_t>(-1) && new_idx != static_cast<size_t>(-1)) { current_stride = merged_strides[merged_idx];
size_t old_size = old_shape[old_idx]; remaining_size = merged_shape[merged_idx];
size_t new_size = inferred_shape[new_idx];
if (old_size == 1) {
old_idx--;
} else if (new_size == 1) {
new_strides[new_idx] = (new_idx == inferred_shape.size() - 1) ? 1 : new_strides[new_idx + 1];
new_idx--;
} else if (old_size == new_size) {
new_strides[new_idx] = old_strides[old_idx];
old_idx--;
new_idx--;
} else if (old_size < new_size) {
size_t combined_size = old_size;
ptrdiff_t combined_stride = old_strides[old_idx];
old_idx--;
while (old_idx != static_cast<size_t>(-1) && combined_size < new_size) {
ASSERT(static_cast<size_t>(old_strides[old_idx]) == old_shape[old_idx + 1] * static_cast<size_t>(old_strides[old_idx + 1]));
combined_size *= old_shape[old_idx];
combined_stride = old_strides[old_idx];
old_idx--;
}
ASSERT(combined_size == new_size);
new_strides[new_idx] = combined_stride;
new_idx--;
} else {
size_t remaining_size = old_size / new_size;
ASSERT(old_size % new_size == 0);
new_strides[new_idx] = old_strides[old_idx] * static_cast<ptrdiff_t>(remaining_size);
new_idx--;
if (remaining_size != 1) {
if (new_idx != static_cast<size_t>(-1)) {
inferred_shape[new_idx] = remaining_size;
new_strides[new_idx] = old_strides[old_idx];
new_idx--;
} else {
ASSERT(false);
}
}
old_idx--;
} }
}
// Fill remaining dimensions (must be size 1) ASSERT_EQ(remaining_size % new_shape[i], 0);
while (new_idx != static_cast<size_t>(-1)) {
ASSERT(inferred_shape[new_idx] == 1); new_strides[i] = current_stride * (remaining_size / new_shape[i]);
new_strides[new_idx] = new_strides[new_idx + 1]; remaining_size /= new_shape[i];
new_idx--;
} }
return this->view_as(inferred_shape, new_strides); return this->view_as(new_shape, new_strides);
} }
std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape) const { std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape) const {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment