Commit 21ef8820 authored by wooway777's avatar wooway777
Browse files

issue/21 - fixed view() implementation

parent b3275d7c
......@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
auto result_cpu = std::vector<int64_t>(nreq);
auto qkv_rope = qkv_buf->view_as({ntok, nh + nkvh * 2, dh});
auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
// Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok);
......@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto q_rearrange = rearrange_q_buf->view_as({nkvh, ngroup, max_seq_len, dh});
auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh});
auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto attn_val_gemm = attn_val_buf->view_as({nkvh, ngroup, max_seq_len, dh});
auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh});
// MLP buffers
auto gate_buf = gate_up_buf->slice(1, 0, di);
......@@ -207,7 +207,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto past_len = req_pos[req];
auto seq_len = req_lens[req];
auto total_len = past_len + seq_len;
auto o = o_buf->slice({{0, token_offset, seq_len}})->view_as({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
......@@ -218,11 +218,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk
rearrange(q_rearrange, q);
auto qk_gemm = qk_buf->view_as({nkvh, ngroup * seq_len, total_len});
auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr);
// softmax
auto qk_softmax = qk_buf->view_as({nh, seq_len, total_len});
auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
......
......@@ -273,128 +273,63 @@ size_t Tensor::seed() const {
}
std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const {
    // Reinterpret this tensor with `new_shape` without copying data.
    // Unlike a plain reshape of contiguous storage, this also works on
    // strided (e.g. sliced/permuted) tensors as long as every requested
    // dimension splits a run of dimensions that is itself contiguous in
    // memory; otherwise an ASSERT fires (same contract as torch's view()).
    //
    // Step 1: Validate total size — the new shape must describe exactly as
    // many elements as the current one.
    size_t numel = 1;
    for (size_t dim : this->_desc->shape()) {
        numel *= dim;
    }
    size_t new_numel = 1;
    for (size_t dim : new_shape) {
        new_numel *= dim;
    }
    ASSERT_EQ(numel, new_numel);
    // NOTE(review): zero-sized tensors are not special-cased; a 0 in
    // new_shape would make the modulo below divide by zero — confirm callers
    // never request empty views.

    // Step 2: Get current shape and strides.
    const std::vector<size_t> &old_shape = this->_desc->shape();
    const std::vector<ptrdiff_t> &old_strides = this->_desc->strides();

    // Step 3: Coalesce adjacent dimensions that are contiguous with each
    // other. Dims (i-1, i) merge when stride(i) * shape(i) == stride(i-1);
    // the merged dim keeps the inner (smaller) stride. Each resulting
    // "chunk" is an independently contiguous run of elements.
    std::vector<size_t> merged_shape;
    std::vector<ptrdiff_t> merged_strides;
    if (!old_shape.empty()) {
        merged_shape.push_back(old_shape[0]);
        merged_strides.push_back(old_strides[0]);
        for (size_t i = 1; i < old_shape.size(); ++i) {
            if (old_strides[i] * static_cast<ptrdiff_t>(old_shape[i]) == merged_strides.back()) {
                merged_shape.back() *= old_shape[i];
                merged_strides.back() = old_strides[i];
            } else {
                merged_shape.push_back(old_shape[i]);
                merged_strides.push_back(old_strides[i]);
            }
        }
    }

    // Step 4: Compute new strides by splitting merged chunks. Within a
    // chunk of size S and element stride s, a leading split of extent d
    // gets stride s * (S / d); the remainder of the chunk is carved up by
    // the following dimensions.
    std::vector<ptrdiff_t> new_strides(new_shape.size());
    size_t merged_idx = 0;
    // Guard the 0-d/scalar case: with no merged chunks, the numel check
    // above guarantees every new dim is 1, so a unit chunk is correct.
    ptrdiff_t current_stride = merged_shape.empty() ? 1 : merged_strides[0];
    size_t remaining_size = merged_shape.empty() ? 1 : merged_shape[0];
    for (size_t i = 0; i < new_shape.size(); ++i) {
        // Advance to the merged chunk that contains this new dimension.
        // Running out of chunks means the view would have to cross a
        // non-contiguous boundary, which is impossible without a copy.
        while (new_shape[i] > remaining_size) {
            ASSERT(++merged_idx < merged_shape.size());
            current_stride = merged_strides[merged_idx];
            remaining_size = merged_shape[merged_idx];
        }
        // The new dimension must split what is left of the chunk evenly.
        ASSERT_EQ(remaining_size % new_shape[i], 0);
        new_strides[i] = current_stride * (remaining_size / new_shape[i]);
        remaining_size /= new_shape[i];
    }
    return this->view_as(new_shape, new_strides);
}
std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape) const {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment