Commit e641693d authored by wooway777

issue/21 - improved view implementation

parent 6998a8f1
@@ -192,6 +192,18 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
                                float alpha, float beta,
                                std::shared_ptr<Tensor> residual,
                                std::shared_ptr<Tensor> bias) {
+    bool residual_flag = residual != nullptr;
+    if (bias && !residual) {
+        int ndim_diff = c->ndim() - 1;
+        ASSERT_EQ(bias->ndim(), 1);
+        ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
+        std::vector<ptrdiff_t> strides(ndim_diff, 0);
+        strides.push_back(bias->strides()[0]);
+        rearrange(c, bias->view_as(c->shape(), strides));
+        residual = c;
+    }
     if (residual) {
         if (residual->data() == c->data()) {
             if (beta == 0.0) {
@@ -210,7 +222,7 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
         gemm(c, a, b, alpha, beta);
     }
-    if (bias) {
+    if (bias && residual_flag) {
         int ndim_diff = c->ndim() - 1;
         ASSERT_EQ(bias->ndim(), 1);
         ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
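Commentary on the hunk above, not part of the diff: when linear() now receives a bias but no residual, it broadcasts the 1-D bias into c through a view_as whose leading strides are all zero, copies it with rearrange, and then points residual at c so the GEMM result is accumulated on top of the bias through the existing residual path; residual_flag keeps the old post-GEMM bias branch from adding the bias a second time. A minimal standalone sketch of the zero-stride broadcast idea, with made-up shapes and a hand-written copy loop standing in for rearrange:

// Illustrative only: a 1-D bias viewed with strides {0, 1} over shape {rows, cols}
// repeats the same cols elements for every row, so a plain element-wise copy
// writes the bias into each row of the output buffer.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const size_t rows = 3, cols = 4;
    std::vector<float> c(rows * cols, 0.0f); // row-major output, like tensor c
    std::vector<float> bias = {1, 2, 3, 4};  // 1-D bias, length == cols

    const ptrdiff_t stride_row = 0, stride_col = 1; // zero stride -> broadcast dim
    for (size_t i = 0; i < rows; ++i) {
        for (size_t j = 0; j < cols; ++j) {
            c[i * cols + j] = bias[i * stride_row + j * stride_col];
        }
    }

    // c now holds the bias in every row; a subsequent GEMM accumulates a*b onto it.
    for (size_t i = 0; i < rows; ++i) {
        for (size_t j = 0; j < cols; ++j) {
            std::printf("%.0f ", c[i * cols + j]);
        }
        std::printf("\n");
    }
    return 0;
}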
...
@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
     auto result_cpu = std::vector<int64_t>(nreq);
-    auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
+    auto qkv_rope = qkv_buf->view_as({ntok, nh + nkvh * 2, dh});
     // Prepare inputs
     auto batch_pos_ids = std::vector<uint32_t>(ntok);
@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
     auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
-    auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh});
+    auto q_rearrange = rearrange_q_buf->view_as({nkvh, ngroup, max_seq_len, dh});
     auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
-    auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh});
+    auto attn_val_gemm = attn_val_buf->view_as({nkvh, ngroup, max_seq_len, dh});
     // MLP buffers
     auto gate_buf = gate_up_buf->slice(1, 0, di);
@@ -207,8 +207,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
         auto past_len = req_pos[req];
         auto seq_len = req_lens[req];
         auto total_len = past_len + seq_len;
-        auto o = o_buf->view({ntok, nh, dh})->slice({{0, token_offset, seq_len}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
-        auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
+        auto o = o_buf->slice({{0, token_offset, seq_len}})->view_as({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
+        auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
         auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
         auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
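An aside on the two rewritten lines above, not part of the diff: slicing o_buf along the leading token dimension leaves its elements packed back to back, so the plain contiguous reshape view_as({seq_len, nkvh, ngroup, dh}) appears sufficient, while q slices qkv_rope along the packed head dimension, which leaves a gap after every token row, so it goes through the reworked stride-preserving view() instead of the old dimSplit chain. A standalone sketch of that contiguity difference, with assumed sizes:

// Illustrative only: strides of a {ntok, heads, dh} row-major buffer and what
// survives after slicing it along different dimensions.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t nh = 4, nkvh = 2, dh = 16;
    const size_t heads = nh + 2 * nkvh; // packed q, k and v heads per token

    // Row-major strides of the full {ntok, heads, dh} buffer.
    const size_t s_tok = heads * dh, s_head = dh, s_elem = 1;

    // Slicing dim 0 (tokens) keeps the remaining rows back to back: the slice
    // is still densely packed, so a fresh contiguous view of it is valid.
    std::printf("token slice strides: {%zu, %zu, %zu}\n", s_tok, s_head, s_elem);

    // Slicing dim 1 (the first nh heads) keeps the token stride of heads*dh
    // but only uses nh*dh elements per token, so the data is no longer packed
    // and any reshape must preserve the original strides.
    std::printf("head slice uses %zu of %zu elements per token row\n",
                nh * dh, heads * dh);
    return 0;
}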
@@ -218,11 +218,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
         rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
         // qk
         rearrange(q_rearrange, q);
-        auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
+        auto qk_gemm = qk_buf->view_as({nkvh, ngroup * seq_len, total_len});
         auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
         linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr);
         // softmax
-        auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
+        auto qk_softmax = qk_buf->view_as({nh, seq_len, total_len});
         causalSoftmax(qk_softmax, qk_softmax);
         auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
         linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
@@ -322,7 +322,7 @@ inferBatch(struct JiugeModel *model,
 void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req,
                   infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
-    CacheManager cache_manager(256);
+    CacheManager cache_manager(100);
     InferenceContext ctx(rsrc, &cache_manager, rsrc->stream);
     // Set the inference context for this thread
...
@@ -130,6 +130,7 @@ public:
     std::shared_ptr<Tensor> view(const std::vector<size_t> &new_shape) const;
     std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const;
+    std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape) const;
     ~Tensor();
 };
...
@@ -259,69 +259,128 @@ std::string Tensor::info() const {
 }

 std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const {
-    // Calculate total elements in current and new shape
-    size_t current_elements = std::accumulate(
-        _desc->shape().begin(), _desc->shape().end(),
-        1, std::multiplies<size_t>());
-    size_t new_elements = std::accumulate(
-        new_shape.begin(), new_shape.end(),
-        1, std::multiplies<size_t>());
-    ASSERT_EQ(current_elements, new_elements);
-
-    const auto &old_shape = _desc->shape();
-    const auto &old_strides = _desc->strides();
-
-    // Special case: empty tensor
-    if (current_elements == 0) {
-        auto result = std::make_shared<Tensor>();
-        result->_storage = this->_storage;
-        result->_desc = TensorDesc::create(this->dtype(), new_shape, {});
-        result->_offset = this->_offset;
-        return result;
-    }
-
-    // Special case: scalar to scalar
-    if (old_shape.empty() && new_shape.empty()) {
-        auto result = std::make_shared<Tensor>();
-        result->_storage = this->_storage;
-        result->_desc = this->_desc;
-        result->_offset = this->_offset;
-        return result;
-    }
-
-    // Compute new strides
-    std::vector<ptrdiff_t> new_strides;
-    if (!new_shape.empty()) {
-        new_strides.resize(new_shape.size());
-        // Compute strides for the new shape while preserving memory layout
-        // Start from the rightmost dimension
-        new_strides.back() = old_strides.back();
-        for (int i = new_shape.size() - 2; i >= 0; --i) {
-            new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
-        }
-    }
-
-    // Verify the new strides are compatible with the old memory layout
-    size_t offset = 0;
-    for (size_t i = 0; i < old_shape.size(); ++i) {
-        offset += (old_shape[i] - 1) * old_strides[i];
-    }
-    size_t new_offset = 0;
-    for (size_t i = 0; i < new_shape.size(); ++i) {
-        new_offset += (new_shape[i] - 1) * new_strides[i];
-    }
-    ASSERT_EQ(offset, new_offset);
-
-    // Create and return the reshaped tensor
-    auto result = std::make_shared<Tensor>();
-    result->_storage = this->_storage;
-    result->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
-    result->_offset = this->_offset;
-    return result;
+    // Calculate total number of elements
+    size_t numel = 1;
+    for (auto s : shape()) {
+        numel *= s;
+    }
+    size_t new_numel = 1;
+    for (auto s : new_shape) {
+        new_numel *= s;
+    }
+    ASSERT(numel == new_numel);
+
+    // Handle empty tensors
+    if (numel == 0) {
+        return this->view_as(new_shape, {});
+    }
+
+    // Special case: view(-1) flattens the tensor
+    if (new_shape.size() == 1 && new_shape[0] == static_cast<size_t>(-1)) {
+        std::vector<size_t> flat_shape = {numel};
+        return this->view_as(flat_shape, {});
+    }
+
+    // Check for -1 in new_shape (infer dimension)
+    std::vector<size_t> inferred_shape = new_shape;
+    size_t infer_index = static_cast<size_t>(-1);
+    size_t known_elements = 1;
+    for (size_t i = 0; i < new_shape.size(); ++i) {
+        if (new_shape[i] == static_cast<size_t>(-1)) {
+            ASSERT(infer_index == static_cast<size_t>(-1)); // Only one -1 allowed
+            infer_index = i;
+        } else {
+            known_elements *= new_shape[i];
+        }
+    }
+    if (infer_index != static_cast<size_t>(-1)) {
+        ASSERT(numel % known_elements == 0);
+        inferred_shape[infer_index] = numel / known_elements;
+    }
+
+    // For contiguous tensors, compute standard row-major strides
+    if (this->isContigous()) {
+        std::vector<ptrdiff_t> new_strides(inferred_shape.size());
+        if (!inferred_shape.empty()) {
+            new_strides.back() = 1;
+            for (int i = static_cast<int>(inferred_shape.size()) - 2; i >= 0; --i) {
+                new_strides[i] = new_strides[i + 1] * static_cast<ptrdiff_t>(inferred_shape[i + 1]);
+            }
+        }
+        return this->view_as(inferred_shape, new_strides);
+    }
+
+    // For non-contiguous tensors
+    std::vector<size_t> old_shape = shape();
+    std::vector<ptrdiff_t> old_strides = strides();
+    std::vector<ptrdiff_t> new_strides(inferred_shape.size(), 0);
+
+    size_t old_idx = old_shape.size() - 1;
+    size_t new_idx = inferred_shape.size() - 1;
+    if (new_idx != static_cast<size_t>(-1)) {
+        new_strides[new_idx] = 1;
+    }
+    while (old_idx != static_cast<size_t>(-1) && new_idx != static_cast<size_t>(-1)) {
+        size_t old_size = old_shape[old_idx];
+        size_t new_size = inferred_shape[new_idx];
+        if (old_size == 1) {
+            old_idx--;
+        } else if (new_size == 1) {
+            new_strides[new_idx] = (new_idx == inferred_shape.size() - 1) ? 1 : new_strides[new_idx + 1];
+            new_idx--;
+        } else if (old_size == new_size) {
+            new_strides[new_idx] = old_strides[old_idx];
+            old_idx--;
+            new_idx--;
+        } else if (old_size < new_size) {
+            size_t combined_size = old_size;
+            ptrdiff_t combined_stride = old_strides[old_idx];
+            old_idx--;
+            while (old_idx != static_cast<size_t>(-1) && combined_size < new_size) {
+                ASSERT(static_cast<size_t>(old_strides[old_idx]) == old_shape[old_idx + 1] * static_cast<size_t>(old_strides[old_idx + 1]));
+                combined_size *= old_shape[old_idx];
+                combined_stride = old_strides[old_idx];
+                old_idx--;
+            }
+            ASSERT(combined_size == new_size);
+            new_strides[new_idx] = combined_stride;
+            new_idx--;
+        } else {
+            size_t remaining_size = old_size / new_size;
+            ASSERT(old_size % new_size == 0);
+            new_strides[new_idx] = old_strides[old_idx] * static_cast<ptrdiff_t>(remaining_size);
+            new_idx--;
+            if (remaining_size != 1) {
+                if (new_idx != static_cast<size_t>(-1)) {
+                    inferred_shape[new_idx] = remaining_size;
+                    new_strides[new_idx] = old_strides[old_idx];
+                    new_idx--;
+                } else {
+                    ASSERT(false);
+                }
+            }
+            old_idx--;
+        }
+    }
+
+    // Fill remaining dimensions (must be size 1)
+    while (new_idx != static_cast<size_t>(-1)) {
+        ASSERT(inferred_shape[new_idx] == 1);
+        new_strides[new_idx] = new_strides[new_idx + 1];
+        new_idx--;
+    }
+
+    return this->view_as(inferred_shape, new_strides);
 }

 std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const {
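Commentary on the rewritten view() above, not part of the diff: besides the empty-tensor and view(-1) special cases, it now resolves at most one -1 dimension from the element count, takes a fast path of freshly computed row-major strides for contiguous tensors, and for non-contiguous tensors walks old and new shapes from the back, reusing, merging or splitting old strides, which is what lets call sites reshape sliced tensors without an explicit dimSplit. A simplified standalone mirror of the two contiguous-path steps (shape inference and row-major strides), on plain vectors with assumed sizes rather than the repo's Tensor types:

// Simplified mirror of the contiguous path: resolve an optional -1 entry from
// the total element count, then derive packed row-major strides.
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<size_t> resolve_shape(size_t numel, std::vector<size_t> shape) {
    const size_t wildcard = static_cast<size_t>(-1);
    size_t at = wildcard, known = 1;
    for (size_t i = 0; i < shape.size(); ++i) {
        if (shape[i] == wildcard) {
            assert(at == wildcard); // only one -1 allowed
            at = i;
        } else {
            known *= shape[i];
        }
    }
    if (at != wildcard) {
        assert(known != 0 && numel % known == 0);
        shape[at] = numel / known;
    }
    return shape;
}

std::vector<ptrdiff_t> row_major_strides(const std::vector<size_t> &shape) {
    std::vector<ptrdiff_t> strides(shape.size(), 1);
    for (size_t i = shape.size(); i > 1; --i) {
        strides[i - 2] = strides[i - 1] * static_cast<ptrdiff_t>(shape[i - 1]);
    }
    return strides;
}

int main() {
    // e.g. viewing a buffer of 6 * 8 * 16 elements as {6, -1} resolves to {6, 128}.
    auto shape = resolve_shape(6 * 8 * 16, {6, static_cast<size_t>(-1)});
    auto strides = row_major_strides(shape);
    std::printf("shape {%zu, %zu}, strides {%td, %td}\n",
                shape[0], shape[1], strides[0], strides[1]);
    return 0;
}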
@@ -332,6 +391,14 @@ std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const {
     return tensor;
 }

+std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape) const {
+    std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
+    tensor->_storage = this->_storage;
+    tensor->_desc = TensorDesc::create(this->dtype(), new_shape);
+    tensor->_offset = this->_offset;
+    return tensor;
+}
+
 void Tensor::debug(const std::string &filename) const {
     RUN_INFINI(infinirtDeviceSynchronize());
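A closing note, not from the diff: the new one-argument view_as builds its descriptor from the shape alone, so it presumably assumes a packed row-major layout and ignores the source strides; the call sites switched to it above (qk_buf, rearrange_q_buf, attn_val_buf, and o_buf sliced along the token axis) are all densely packed buffers, while anything sliced along an inner dimension still goes through view(). A standalone sketch, with made-up numbers, of why applying packed strides to a non-packed slice would address the wrong elements:

// Illustrative only: offset of element (i, j) in a {4, 8} slice taken from a
// larger {4, 32} row-major buffer, under packed strides vs. preserved strides.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t cols = 8, parent_row = 32;
    const size_t i = 2, j = 3;

    // view_as-style packed strides {cols, 1} ignore the parent layout.
    const ptrdiff_t packed_off = static_cast<ptrdiff_t>(i * cols + j);       // 19
    // Stride-preserving strides {parent_row, 1} follow the parent layout.
    const ptrdiff_t kept_off = static_cast<ptrdiff_t>(i * parent_row + j);   // 67

    std::printf("packed offset = %td, stride-preserving offset = %td\n",
                packed_off, kept_off);
    // Only the stride-preserving offset lands on the intended element of the
    // parent buffer, which is why non-packed tensors must use view() instead.
    return 0;
}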
...