Unverified Commit 07aa6990 authored by PanZezhong1725, committed by GitHub

Merge pull request #22 from InfiniTensor/issue/21

Issue/21 - Inference Process Modularization
parents be0e66ef bfae3bbb
#ifndef CACHE_MANAGER_HPP
#define CACHE_MANAGER_HPP
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>
#include "../tensor.hpp"
#include "../utils.hpp"
#include "infinicore_infer.h"
class IDescriptorDestroyer {
public:
virtual ~IDescriptorDestroyer() = default;
virtual void destroy(void *descriptor) = 0;
};
template <typename DescriptorType>
class DescriptorDestroyer : public IDescriptorDestroyer {
using DestroyFunc = infiniStatus_t (*)(DescriptorType);
DestroyFunc destroyFunc;
public:
DescriptorDestroyer(DestroyFunc func) : destroyFunc(func) {}
void destroy(void *descriptor) override {
destroyFunc(*static_cast<DescriptorType *>(descriptor));
}
};
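`DescriptorDestroyer` type-erases the C-style destroy function: the cache stores only an `IDescriptorDestroyer` pointer, passes `&node->desc` as `void *`, and the template casts it back to the typed descriptor before calling the destroy function. A minimal sketch of the same erasure done with `std::function` instead of a class hierarchy (an alternative for illustration, not what this PR does; `<functional>` is already included above):

```cpp
// Sketch: capture the typed destroy function in a type-erased callable.
template <typename D>
std::function<void(void *)> makeDestroyer(infiniStatus_t (*destroyFn)(D)) {
    return [destroyFn](void *p) {
        destroyFn(*static_cast<D *>(p)); // recover the typed descriptor
    };
}
```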
template <typename DescriptorType>
class LRUDescriptorCache {
private:
struct CacheNode {
size_t key;
DescriptorType desc;
CacheNode *prev;
CacheNode *next;
CacheNode() : key(0), desc(), prev(nullptr), next(nullptr) {}
CacheNode(size_t k, const DescriptorType &d) : key(k), desc(d), prev(nullptr), next(nullptr) {}
};
std::unordered_map<size_t, CacheNode *> cache;
CacheNode *head;
CacheNode *tail;
const size_t capacity;
size_t size;
std::unique_ptr<IDescriptorDestroyer> destroyer;
void removeNode(CacheNode *node) {
node->prev->next = node->next;
node->next->prev = node->prev;
if (destroyer) {
destroyer->destroy(&node->desc);
}
cache.erase(node->key);
delete node;
--size;
}
void addToTop(CacheNode *node) {
node->next = head->next;
node->next->prev = node;
node->prev = head;
head->next = node;
cache[node->key] = node;
if (++size > capacity) {
removeNode(tail->prev);
}
}
void moveToTop(CacheNode *node) {
node->prev->next = node->next;
node->next->prev = node->prev;
node->next = head->next;
node->next->prev = node;
node->prev = head;
head->next = node;
}
public:
template <typename DestroyFunc>
LRUDescriptorCache(size_t c, DestroyFunc destroyFunc)
: capacity(c), size(0), destroyer(std::make_unique<DescriptorDestroyer<DescriptorType>>(destroyFunc)) {
head = new CacheNode();
tail = new CacheNode();
head->next = tail;
tail->prev = head;
}
~LRUDescriptorCache() {
while (head->next != tail) {
removeNode(head->next);
}
delete head;
delete tail;
}
bool get(size_t key, DescriptorType &out_desc) {
auto it = cache.find(key);
if (it == cache.end()) {
return false;
}
CacheNode *node = it->second;
moveToTop(node);
out_desc = node->desc;
return true;
}
void put(size_t key, const DescriptorType &descriptor) {
auto it = cache.find(key);
if (it != cache.end()) {
// Key already exists, update the descriptor
CacheNode *node = it->second;
if (destroyer) {
destroyer->destroy(&node->desc);
}
node->desc = descriptor;
moveToTop(node);
return;
}
// Check if we need to evict
if (size >= capacity) {
removeNode(tail->prev);
}
// Create new node and add to top
CacheNode *node = new CacheNode(key, descriptor);
addToTop(node);
}
LRUDescriptorCache(const LRUDescriptorCache &) = delete;
LRUDescriptorCache &operator=(const LRUDescriptorCache &) = delete;
};
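A minimal usage sketch of `LRUDescriptorCache` (the `FakeDesc`/`fakeDestroy` pair is illustrative only; real callers pass infiniop descriptor types and their matching `infiniopDestroy*` functions):

```cpp
using FakeDesc = int *; // stand-in for an infiniop descriptor handle
infiniStatus_t fakeDestroy(FakeDesc) { return infiniStatus_t{}; } // assume 0 == success

void lruExample() {
    LRUDescriptorCache<FakeDesc> cache(/*capacity=*/2, fakeDestroy);
    cache.put(1, nullptr);
    cache.put(2, nullptr);
    FakeDesc d;
    cache.get(1, d);            // hit: key 1 becomes most recently used
    cache.put(3, nullptr);      // at capacity: evicts key 2 (LRU) and destroys it
    bool hit = cache.get(2, d); // false: key 2 was evicted
    (void)hit;
}
```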
class CacheManager {
private:
static constexpr size_t DEFAULT_CACHE_CAPACITY = 128;
LRUDescriptorCache<infiniopAddDescriptor_t> add_cache;
LRUDescriptorCache<infiniopRMSNormDescriptor_t> rms_norm_cache;
LRUDescriptorCache<infiniopGemmDescriptor_t> gemm_cache;
LRUDescriptorCache<infiniopRoPEDescriptor_t> rope_cache;
LRUDescriptorCache<infiniopRearrangeDescriptor_t> rearrange_cache;
LRUDescriptorCache<infiniopCausalSoftmaxDescriptor_t> causal_softmax_cache;
LRUDescriptorCache<infiniopSwiGLUDescriptor_t> swiglu_cache;
LRUDescriptorCache<infiniopRandomSampleDescriptor_t> random_sample_cache;
public:
explicit CacheManager(size_t capacity = DEFAULT_CACHE_CAPACITY)
: add_cache(capacity, infiniopDestroyAddDescriptor),
rms_norm_cache(capacity, infiniopDestroyRMSNormDescriptor),
gemm_cache(capacity, infiniopDestroyGemmDescriptor),
rope_cache(capacity, infiniopDestroyRoPEDescriptor),
rearrange_cache(capacity, infiniopDestroyRearrangeDescriptor),
causal_softmax_cache(capacity, infiniopDestroyCausalSoftmaxDescriptor),
swiglu_cache(capacity, infiniopDestroySwiGLUDescriptor),
random_sample_cache(capacity, infiniopDestroyRandomSampleDescriptor) {}
// Add operations
bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) {
return add_cache.get(key, desc);
}
void putAddDescriptor(size_t key, const infiniopAddDescriptor_t &desc) {
add_cache.put(key, desc);
}
// RMSNorm operations
bool getRMSNormDescriptor(size_t key, infiniopRMSNormDescriptor_t &desc) {
return rms_norm_cache.get(key, desc);
}
void putRMSNormDescriptor(size_t key, const infiniopRMSNormDescriptor_t &desc) {
rms_norm_cache.put(key, desc);
}
// GEMM operations
bool getGemmDescriptor(size_t key, infiniopGemmDescriptor_t &desc) {
return gemm_cache.get(key, desc);
}
void putGemmDescriptor(size_t key, const infiniopGemmDescriptor_t &desc) {
gemm_cache.put(key, desc);
}
// RoPE operations
bool getRoPEDescriptor(size_t key, infiniopRoPEDescriptor_t &desc) {
return rope_cache.get(key, desc);
}
void putRoPEDescriptor(size_t key, const infiniopRoPEDescriptor_t &desc) {
rope_cache.put(key, desc);
}
// Rearrange operations
bool getRearrangeDescriptor(size_t key, infiniopRearrangeDescriptor_t &desc) {
return rearrange_cache.get(key, desc);
}
void putRearrangeDescriptor(size_t key, const infiniopRearrangeDescriptor_t &desc) {
rearrange_cache.put(key, desc);
}
// Softmax operations
bool getCausalSoftmaxDescriptor(size_t key, infiniopCausalSoftmaxDescriptor_t &desc) {
return causal_softmax_cache.get(key, desc);
}
void putCausalSoftmaxDescriptor(size_t key, const infiniopCausalSoftmaxDescriptor_t &desc) {
causal_softmax_cache.put(key, desc);
}
// SwiGLU operations
bool getSwiGLUDescriptor(size_t key, infiniopSwiGLUDescriptor_t &desc) {
return swiglu_cache.get(key, desc);
}
void putSwiGLUDescriptor(size_t key, const infiniopSwiGLUDescriptor_t &desc) {
swiglu_cache.put(key, desc);
}
// Random Sample operations
bool getRandomSampleDescriptor(size_t key, infiniopRandomSampleDescriptor_t &desc) {
return random_sample_cache.get(key, desc);
}
void putRandomSampleDescriptor(size_t key, const infiniopRandomSampleDescriptor_t &desc) {
random_sample_cache.put(key, desc);
}
template <typename... Tensors>
static size_t createDescriptorKey(Tensors... tensors) {
size_t seed = 0;
(..., (tensors ? hash_combine(seed, tensors->seed()) : (void)0));
return seed;
}
};
#endif // CACHE_MANAGER_HPP
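`createDescriptorKey` folds the per-tensor `seed()` values (a hash of dtype, shape, and strides; see `TensorDesc::computeTensorDesHash` further down) into a single cache key, skipping null tensors. For three tensors the fold expands to the equivalent of:

```cpp
// Conceptual expansion of createDescriptorKey(c, a, b):
size_t seed = 0;
if (c) hash_combine(seed, c->seed());
if (a) hash_combine(seed, a->seed());
if (b) hash_combine(seed, b->seed());
// Calls whose tensors share dtype/shape/strides map to the same key, so a
// cached descriptor is reused even when the data pointers differ.
```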
#include "inference_context.hpp"
#include "../tensor.hpp"
#include "../utils.hpp"
InferenceContext::InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager, infinirtStream_t stream)
: rsrc(rsrc), cache_manager(cache_manager), stream(stream) {}
void InferenceContext::ensure_workspace(size_t required_size) {
if (required_size > current_workspace_size || !workspace_storage) {
workspace_storage = Storage::createFromPool(required_size, rsrc->memory_pool);
current_workspace_size = required_size;
}
}
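`ensure_workspace` grows monotonically: it reallocates from the device pool only when an op requests more than the current high-water mark, so steady-state decode steps reuse one buffer instead of churning allocations. Effect across calls (a sketch; `ctx` is an `InferenceContext`):

```cpp
ctx.ensure_workspace(1 << 20); // first request: allocates 1 MiB from the pool
ctx.ensure_workspace(4096);    // no-op: smaller than the current buffer
ctx.ensure_workspace(2 << 20); // grows: replaces the buffer with a 2 MiB one
```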
void InferenceContext::add(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b) {
size_t key = CacheManager::createDescriptorKey(c, a, b);
infiniopAddDescriptor_t desc;
if (!cache_manager->getAddDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateAddDescriptor(rsrc->handle, &desc, c->desc(), a->desc(), b->desc()));
cache_manager->putAddDescriptor(key, desc);
}
size_t workspace_size = 0;
RUN_INFINI(infiniopGetAddWorkspaceSize(desc, &workspace_size));
ensure_workspace(workspace_size);
void *workspace = workspace_storage->memory();
RUN_INFINI(infiniopAdd(
desc, workspace, workspace_size,
c->data(), a->data(), b->data(), stream));
}
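`add` establishes the pattern every op below repeats: hash the tensor descriptors into a key, look the descriptor up in the cache, create and insert it on a miss, query and reserve workspace, then launch on the stream. A hedged sketch of how the lookup-or-create step could be factored (an illustrative refactoring, not part of this PR):

```cpp
// Illustrative helper: the cache-miss path shared by all ops.
template <typename Desc, typename GetFn, typename PutFn, typename CreateFn>
Desc getOrCreate(size_t key, GetFn get, PutFn put, CreateFn create) {
    Desc desc;
    if (!get(key, desc)) { // miss: build a fresh infiniop descriptor
        desc = create();
        put(key, desc);    // cache it for later calls with matching shapes
    }
    return desc;
}
```

`add` would then reduce to one `getOrCreate<infiniopAddDescriptor_t>` call with lambdas over `cache_manager`, followed by the unchanged workspace and launch steps.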
void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w,
float epsilon) {
size_t key = CacheManager::createDescriptorKey(y, x, w);
infiniopRMSNormDescriptor_t desc;
if (!cache_manager->getRMSNormDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateRMSNormDescriptor(
rsrc->handle, &desc, y->desc(), x->desc(), w->desc(), epsilon));
cache_manager->putRMSNormDescriptor(key, desc);
}
size_t workspace_size = 0;
RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc, &workspace_size));
ensure_workspace(workspace_size);
void *workspace = workspace_storage->memory();
RUN_INFINI(infiniopRMSNorm(
desc, workspace, workspace_size,
y->data(), x->data(), w->data(), stream));
}
void InferenceContext::gemm(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b,
float alpha, float beta) {
size_t key = CacheManager::createDescriptorKey(c, a, b);
infiniopGemmDescriptor_t desc;
if (!cache_manager->getGemmDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateGemmDescriptor(rsrc->handle, &desc, c->desc(), a->desc(), b->desc()));
cache_manager->putGemmDescriptor(key, desc);
}
size_t workspace_size = 0;
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc, &workspace_size));
ensure_workspace(workspace_size);
void *workspace = workspace_storage->memory();
RUN_INFINI(infiniopGemm(
desc, workspace, workspace_size,
c->data(), a->data(), b->data(), alpha, beta, stream));
}
void InferenceContext::rearrange(std::shared_ptr<Tensor> dst,
std::shared_ptr<Tensor> src) {
size_t key = CacheManager::createDescriptorKey(dst, src);
infiniopRearrangeDescriptor_t desc;
if (!cache_manager->getRearrangeDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc->handle, &desc, dst->desc(), src->desc()));
cache_manager->putRearrangeDescriptor(key, desc);
}
RUN_INFINI(infiniopRearrange(
desc,
dst->data(),
src->data(),
stream));
}
void InferenceContext::rope(std::shared_ptr<Tensor> q,
std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos,
std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos) {
size_t key = CacheManager::createDescriptorKey(q, k, pos, sin, cos);
infiniopRoPEDescriptor_t desc;
if (!cache_manager->getRoPEDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateRoPEDescriptor(
rsrc->handle, &desc, q->desc(), k->desc(),
pos->desc(), sin->desc(), cos->desc()));
cache_manager->putRoPEDescriptor(key, desc);
}
size_t workspace_size = 0;
RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc, &workspace_size));
ensure_workspace(workspace_size);
void *workspace = workspace_storage->memory();
RUN_INFINI(infiniopRoPE(
desc, workspace, workspace_size,
q->data(), k->data(), pos->data(),
sin->data(), cos->data(), stream));
}
void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x) {
size_t key = CacheManager::createDescriptorKey(y, x);
infiniopCausalSoftmaxDescriptor_t desc;
if (!cache_manager->getCausalSoftmaxDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor(
rsrc->handle, &desc, y->desc(), x->desc()));
cache_manager->putCausalSoftmaxDescriptor(key, desc);
}
size_t workspace_size = 0;
RUN_INFINI(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size));
ensure_workspace(workspace_size);
void *workspace = workspace_storage->memory();
RUN_INFINI(infiniopCausalSoftmax(desc, workspace, workspace_size,
y->data(), x->data(), stream));
}
void InferenceContext::swiglu(std::shared_ptr<Tensor> out,
std::shared_ptr<Tensor> up,
std::shared_ptr<Tensor> gate) {
size_t key = CacheManager::createDescriptorKey(out, up, gate);
infiniopSwiGLUDescriptor_t desc;
if (!cache_manager->getSwiGLUDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateSwiGLUDescriptor(
rsrc->handle, &desc, out->desc(), up->desc(), gate->desc()));
cache_manager->putSwiGLUDescriptor(key, desc);
}
size_t workspace_size = 0;
RUN_INFINI(infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size));
ensure_workspace(workspace_size);
void *workspace = workspace_storage->memory();
RUN_INFINI(infiniopSwiGLU(desc, workspace, workspace_size,
out->data(), up->data(), gate->data(), stream));
}
void InferenceContext::randomSample(std::shared_ptr<Tensor> out,
std::shared_ptr<Tensor> prob,
float random_val, float top_p, uint32_t top_k, float temperature) {
size_t key = CacheManager::createDescriptorKey(out, prob);
infiniopRandomSampleDescriptor_t desc;
if (!cache_manager->getRandomSampleDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateRandomSampleDescriptor(
rsrc->handle, &desc, out->desc(), prob->desc()));
cache_manager->putRandomSampleDescriptor(key, desc);
}
size_t workspace_size = 0;
RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc, &workspace_size));
ensure_workspace(workspace_size);
void *workspace = workspace_storage->memory();
RUN_INFINI(infiniopRandomSample(
desc, workspace, workspace_size,
out->data(), prob->data(),
random_val, top_p, top_k, temperature,
stream));
}
void InferenceContext::linear(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b,
float alpha, float beta,
std::shared_ptr<Tensor> residual,
std::shared_ptr<Tensor> bias) {
bool residual_flag = residual != nullptr;
if (bias && !residual) {
int ndim_diff = c->ndim() - 1;
ASSERT_EQ(bias->ndim(), 1);
ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
std::vector<ptrdiff_t> strides(ndim_diff, 0);
strides.push_back(bias->strides()[0]);
rearrange(c, bias->view_as(c->shape(), strides));
residual = c;
}
if (residual) {
if (residual->data() == c->data()) {
if (beta == 0.0) {
gemm(c, a, b, alpha, 1.0);
} else {
auto c_copy = Tensor::buffer(c->dtype(), c->shape(), rsrc->memory_pool);
c_copy->copyFrom(c, rsrc->handle, stream);
gemm(c, a, b, alpha, beta);
add(c, c, c_copy);
}
} else {
gemm(c, a, b, alpha, beta);
add(c, c, residual);
}
} else {
gemm(c, a, b, alpha, beta);
}
if (bias && residual_flag) {
int ndim_diff = c->ndim() - 1;
ASSERT_EQ(bias->ndim(), 1);
ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
std::vector<ptrdiff_t> strides(ndim_diff, 0);
strides.push_back(bias->strides()[0]);
add(c, c, bias->view_as(c->shape(), strides));
}
}
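`linear` folds three concerns into one call: the GEMM itself, an optional residual (handled in place when `residual` aliases `c`, or via a temporary copy when `beta != 0`), and an optional bias broadcast. The bias path works through a zero-stride view: a 1-D bias of length `d` is viewed with the output's shape but stride 0 on every leading dimension, so each output row reads the same bias elements. A sketch of the no-residual bias path for `c` of shape `{ntok, d}` (assumed shapes, for illustration):

```cpp
// bias:  shape {d},       strides {1}
// view:  shape {ntok, d}, strides {0, 1}  -> element (i, j) reads bias[j] for all i
// linear(c, a, w, 1.0f, 0.0f, /*residual=*/nullptr, /*bias=*/bias) then does:
//   rearrange(c, bias->view_as(c->shape(), {0, 1})); // write broadcast bias into c
//   gemm(c, a, w, 1.0f, 1.0f);                       // accumulate a @ w on top of it
```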
#pragma once
#include "cache_manager.hpp"
#include "jiuge/jiuge_impl.hpp"
#include "jiuge/jiuge_weight.hpp"
#include <cassert>
struct InferenceContext {
DeviceResource *rsrc;
CacheManager *cache_manager;
infinirtStream_t stream;
std::shared_ptr<Storage> workspace_storage;
size_t current_workspace_size = 0;
InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager, infinirtStream_t stream);
void ensure_workspace(size_t required_size);
void add(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b);
void rmsnorm(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w,
float epsilon);
void gemm(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b,
float alpha, float beta);
void rearrange(std::shared_ptr<Tensor> dst,
std::shared_ptr<Tensor> src);
void rope(std::shared_ptr<Tensor> q,
std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos,
std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos);
void causalSoftmax(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x);
void swiglu(std::shared_ptr<Tensor> out,
std::shared_ptr<Tensor> up,
std::shared_ptr<Tensor> gate);
void randomSample(std::shared_ptr<Tensor> out,
std::shared_ptr<Tensor> prob,
float random_val, float top_p, uint32_t top_k, float temperature);
void linear(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b,
float alpha, float beta,
std::shared_ptr<Tensor> residual,
std::shared_ptr<Tensor> bias);
};
// C++17 inline thread-local: a single per-thread variable shared across all
// translation units. An anonymous namespace in a header would give every TU
// its own copy, so a context set in one .cpp would be invisible in another.
inline thread_local InferenceContext *tls_inference_context = nullptr;
inline InferenceContext &getInferenceContext() {
assert(tls_inference_context != nullptr && "InferenceContext not set for this thread");
return *tls_inference_context;
}
inline void setInferenceContext(InferenceContext *ctx) {
tls_inference_context = ctx;
}
inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
getInferenceContext().add(c, a, b);
}
inline void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w, float epsilon) {
getInferenceContext().rmsnorm(y, x, w, epsilon);
}
inline void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, float alpha, float beta) {
getInferenceContext().gemm(c, a, b, alpha, beta);
}
inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src) {
getInferenceContext().rearrange(dst, src);
}
inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos) {
getInferenceContext().rope(q, k, pos, sin, cos);
}
inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
getInferenceContext().causalSoftmax(y, x);
}
inline void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up,
std::shared_ptr<Tensor> gate) {
getInferenceContext().swiglu(out, up, gate);
}
inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob,
float random_val, float top_p, uint32_t top_k, float temperature) {
getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature);
}
inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, float alpha, float beta,
std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
getInferenceContext().linear(c, a, b, alpha, beta, residual, bias);
}
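The thread-local context plus these free-function wrappers let per-device worker threads write operator calls without threading `rsrc`, the cache, and the stream through every signature. Typical setup, mirroring `launchDevice` below (a sketch; `rsrc`, `out`, `in`, `weight`, and `epsilon` are assumed to exist):

```cpp
CacheManager cache(128);
InferenceContext ctx(rsrc, &cache, rsrc->stream);
setInferenceContext(&ctx);         // bind the context to this worker thread
rmsnorm(out, in, weight, epsilon); // free function dispatches via tls_inference_context
setInferenceContext(nullptr);      // unbind before the thread exits
```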
@@ -3,6 +3,7 @@
#include "../../tensor.hpp"
#include "../../utils.hpp"
#include "../inference_context.hpp"
#include "infinicore_infer.h"
#include <random>
@@ -140,6 +141,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
auto result_cpu = std::vector<int64_t>(nreq);
auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
// Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok);
size_t req_start = 0;
@@ -164,239 +167,73 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
}
// Prepare operators and workspace
size_t workspace_size = 0, temp_size = 0;
// attn & mlp rmsnorm
infiniopRMSNormDescriptor_t desc_norm;
RUN_INFINI(infiniopCreateRMSNormDescriptor(
rsrc.handle, &desc_norm, logits_in->desc(),
logits_out->desc(), rsrc.w_attn_norm[0]->desc(),
meta.epsilon));
RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc_norm, &workspace_size));
workspace_size = std::max(workspace_size, temp_size);
// Attention
infiniopGemmDescriptor_t desc_attn_qkv, desc_attn_o;
infiniopRearrangeDescriptor_t desc_qkv_bias;
if (has_qkv_bias) {
RUN_INFINI(infiniopCreateRearrangeDescriptor(
rsrc.handle, &desc_qkv_bias, qkv_buf->desc(),
TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh}, {0, 1})->desc()));
}
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_attn_qkv, qkv_buf->desc(),
logits_in->desc(), rsrc.w_attn_qkv[0]->desc()));
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_attn_o, logits_in->desc(),
o_buf->desc(), rsrc.w_attn_out[0]->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_qkv, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_o, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
infiniopRoPEDescriptor_t desc_rope_q, desc_rope_k;
qkv_buf->dimSplit(1, {nh + nkvh * 2, dh}); // (ntok, nh + 2 * nkvh, dh)
auto qkv_buf_q = qkv_buf->slice(1, 0, nh);
auto qkv_buf_k = qkv_buf->slice(1, nh, nkvh);
RUN_INFINI(infiniopCreateRoPEDescriptor(
rsrc.handle, &desc_rope_q, qkv_buf_q->desc(), qkv_buf_q->desc(),
pos_ids_buf->desc(), rsrc.sin_table->desc(),
rsrc.cos_table->desc()));
RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc_rope_q, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
RUN_INFINI(infiniopCreateRoPEDescriptor(
rsrc.handle, &desc_rope_k, qkv_buf_k->desc(), qkv_buf_k->desc(),
pos_ids_buf->desc(), rsrc.sin_table->desc(),
rsrc.cos_table->desc()));
RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc_rope_k, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
// attention inner
auto desc_kv_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
auto desc_q_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
auto desc_qk_gemms = std::vector<infiniopGemmDescriptor_t>(nreq);
auto desc_qk_softmaxs = std::vector<infiniopCausalSoftmaxDescriptor_t>(nreq);
auto desc_attn_v_gemms = std::vector<infiniopGemmDescriptor_t>(nreq);
auto desc_attn_v_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
size_t token_offset = 0;
size_t max_qk_size = 0;
size_t max_seq_len = 0;
o_buf->dimSplit(1, {nh, dh});
for (uint32_t req = 0; req < nreq; req++) {
auto past_len = req_pos[req];
auto seq_len = req_lens[req];
auto total_len = past_len + seq_len;
auto o = o_buf->slice({{0, token_offset, seq_len}});
auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
// auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
// kv cache tensors can share the same descriptor
// [nkvh, dh, total_len]
auto full_kv = kv_caches[req]->k[idev][0]->slice(0, 0, total_len)->permute({1, 2, 0});
auto cache_kv = kv_caches[req]->k[idev][0]->slice(0, past_len, seq_len);
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_kv_rearranges[req],
cache_kv->desc(), k->desc()));
// [nkvh, ngroup, seq_len, dh]
q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
auto q_t = TensorDesc::create(dt_logits, {nkvh, ngroup, seq_len, dh});
// [seq_len, nkvh, ngroup, dh] -> [nkvh, ngroup, seq_len, dh]
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_q_rearranges[req],
q_t->desc(), q->desc()));
// [nkvh, ngroup, seq_len, dh] -> [seq_len, nkvh, ngroup, dh]
auto attn_v_t = q_t;
auto attn_v = TensorDesc::createWithOrder(dt_logits, {nkvh, ngroup, seq_len, dh}, {1, 2, 0, 3});
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_attn_v_rearranges[req],
attn_v->desc(), attn_v_t->desc()));
q_t = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, dh});
auto qk = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, total_len});
max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len));
max_seq_len = std::max(max_seq_len, size_t(seq_len));
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_qk_gemms[req], qk->desc(), q_t->desc(), full_kv->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_qk_gemms[req], &temp_size));
workspace_size = std::max(workspace_size, temp_size);
// [nkvh, total_len, dh]
auto full_v = kv_caches[req]->v[idev][0]->slice(0, 0, total_len)->permute({1, 0, 2});
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_attn_v_gemms[req], q_t->desc(), qk->desc(), full_v->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_v_gemms[req], &temp_size));
workspace_size = std::max(workspace_size, temp_size);
qk = TensorDesc::create(dt_logits, {nkvh * ngroup, seq_len, total_len});
RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor(
rsrc.handle, &desc_qk_softmaxs[req], qk->desc(), qk->desc()));
RUN_INFINI(infiniopGetCausalSoftmaxWorkspaceSize(desc_qk_softmaxs[req], &temp_size));
workspace_size = std::max(workspace_size, temp_size);
token_offset += seq_len;
}
auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto attn_val_buf = Tensor::buffer(dt_logits, {nh, max_seq_len, dh}, rsrc.memory_pool);
// MLP descriptors
infiniopGemmDescriptor_t desc_ffn_gate_up, desc_ffn_down;
infiniopSwiGLUDescriptor_t desc_swiglu;
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_ffn_gate_up, gate_up_buf->desc(),
logits_out->desc(), rsrc.w_ffn_gate_up[0]->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_ffn_gate_up, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh});
auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh});
// MLP buffers
auto gate_buf = gate_up_buf->slice(1, 0, di);
auto up_buf = gate_up_buf->slice(1, di, di);
RUN_INFINI(infiniopCreateSwiGLUDescriptor(
rsrc.handle, &desc_swiglu, gate_buf->desc(), up_buf->desc(), gate_buf->desc()));
RUN_INFINI(infiniopGetSwiGLUWorkspaceSize(desc_swiglu, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_ffn_down, logits_in->desc(),
gate_buf->desc(), rsrc.w_ffn_down[0]->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_ffn_down, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
// Output and sample
infiniopRMSNormDescriptor_t desc_norm_out;
RUN_INFINI(infiniopCreateRMSNormDescriptor(
rsrc.handle, &desc_norm_out, logits_out->slice(0, 0, 1)->desc(),
logits_out->slice(0, 0, 1)->desc(),
rsrc.w_out_norm->desc(), meta.epsilon));
RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc_norm_out, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
infiniopGemmDescriptor_t desc_out_embd;
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_out_embd, prob_buf->desc(),
logits_out->slice(0, 0, nreq)->desc(),
rsrc.w_out_embd->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_out_embd, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
infiniopRandomSampleDescriptor_t desc_sample;
RUN_INFINI(infiniopCreateRandomSampleDescriptor(
rsrc.handle, &desc_sample,
TensorDesc::create(INFINI_DTYPE_I64, {}, {})->desc(),
TensorDesc::create(dt_logits, {dvoc}, {1})->desc()));
RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc_sample, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
// Allocate workspace
std::shared_ptr<Storage> workspace_storage = Storage::createFromPool(workspace_size, rsrc.memory_pool);
void *workspace = workspace_storage->memory();
// Compute
for (uint32_t layer = 0; layer < nlayer; layer++) {
// 1. Attention
// rms norm
RUN_INFINI(infiniopRMSNorm(
desc_norm, workspace, workspace_size,
logits_out->data(), logits_in->data(),
rsrc.w_attn_norm[layer]->data(), stream));
rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
// qkv_proj
if (has_qkv_bias) {
RUN_INFINI(infiniopRearrange(
desc_qkv_bias,
qkv_buf->data(), rsrc.b_attn_qkv[layer]->data(), stream));
}
RUN_INFINI(infiniopGemm(
desc_attn_qkv, workspace, workspace_size,
qkv_buf->data(), logits_out->data(),
rsrc.w_attn_qkv[layer]->data(), 1.0, has_qkv_bias ? 1.0 : 0.0, stream));
linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
// rope
RUN_INFINI(infiniopRoPE(
desc_rope_q, workspace, workspace_size,
qkv_buf->data(), qkv_buf->data(),
pos_ids_buf->data(),
rsrc.sin_table->data(),
rsrc.cos_table->data(), stream));
RUN_INFINI(infiniopRoPE(
desc_rope_k, workspace, workspace_size,
qkv_buf->data(nh * dh), qkv_buf->data(nh * dh),
pos_ids_buf->data(),
rsrc.sin_table->data(),
rsrc.cos_table->data(),
stream));
rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
size_t token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto past_len = req_pos[req];
auto seq_len = req_lens[req];
auto o = o_buf->slice({{0, token_offset, seq_len}});
auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
auto total_len = past_len + seq_len;
auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
// self attention
// concat
RUN_INFINI(infiniopRearrange(
desc_kv_rearranges[req],
kv_caches[req]->k[idev][layer]->data(past_len * nkvh * dh),
k->data(), stream));
RUN_INFINI(infiniopRearrange(
desc_kv_rearranges[req],
kv_caches[req]->v[idev][layer]->data(past_len * nkvh * dh),
v->data(), stream));
rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k);
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk
RUN_INFINI(infiniopRearrange(desc_q_rearranges[req], rearrange_q_buf->data(), q->data(), stream));
RUN_INFINI(infiniopGemm(
desc_qk_gemms[req], workspace, workspace_size,
qk_buf->data(), rearrange_q_buf->data(), kv_caches[req]->k[idev][layer]->data(), 1. / sqrt(dh), 0.0, stream));
rearrange(q_rearrange, q);
auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr);
// softmax
RUN_INFINI(infiniopCausalSoftmax(
desc_qk_softmaxs[req], workspace, workspace_size,
qk_buf->data(), qk_buf->data(), stream));
// attn val
RUN_INFINI(infiniopGemm(
desc_attn_v_gemms[req], workspace, workspace_size,
attn_val_buf->data(), qk_buf->data(), kv_caches[req]->v[idev][layer]->data(), 1.0, 0.0, stream));
auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
// rearrange attn val
RUN_INFINI(infiniopRearrange(
desc_attn_v_rearranges[req],
o->data(),
attn_val_buf->data(), stream));
rearrange(o, attn_val_gemm);
token_offset += seq_len;
}
// o_proj
RUN_INFINI(infiniopGemm(
desc_attn_o, workspace, workspace_size,
logits_in->data(), o_buf->data(),
rsrc.w_attn_out[layer]->data(), 1.0, idev == 0 ? 1.0 : 0.0, stream)); // only rank 0 adds residual
linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
// All_reduce if distributed
if (rsrc.comm != nullptr) {
@@ -406,22 +243,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI(infinirtStreamSynchronize(stream));
}
// 2. FFN
// rms_norm
RUN_INFINI(infiniopRMSNorm(
desc_norm, workspace, workspace_size,
logits_out->data(), logits_in->data(),
rsrc.w_ffn_norm[layer]->data(), stream));
RUN_INFINI(infiniopGemm(
desc_ffn_gate_up, workspace, workspace_size,
gate_up_buf->data(), logits_out->data(), rsrc.w_ffn_gate_up[layer]->data(),
1.0, 0.0, stream));
RUN_INFINI(infiniopSwiGLU(
desc_swiglu, workspace, workspace_size,
gate_buf->data(), up_buf->data(), gate_buf->data(), stream));
RUN_INFINI(infiniopGemm(
desc_ffn_down, workspace, workspace_size,
logits_in->data(), gate_buf->data(),
rsrc.w_ffn_down[layer]->data(), 1.0, idev == 0 ? 1.0 : 0.0, stream)); // only rank 0 adds residual
rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr, nullptr);
swiglu(gate_buf, up_buf, gate_buf);
linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
// All_reduce if distributed
if (rsrc.comm != nullptr) {
@@ -437,31 +262,21 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req];
token_offset += seq_len;
RUN_INFINI(infiniopRMSNorm(
desc_norm_out, workspace, workspace_size,
logits_out->data(req * d),
logits_in->data((token_offset - 1) * d),
rsrc.w_out_norm->data(), stream));
}
RUN_INFINI(infiniopGemm(
desc_out_embd, workspace, workspace_size,
prob_buf->data(), logits_out->data(),
rsrc.w_out_embd->data(), 1.0, 0.0, stream));
rmsnorm(logits_out->slice(0, req, 1),
logits_in->slice(0, token_offset - 1, 1),
rsrc.w_out_norm,
meta.epsilon);
}
linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
std::random_device _rd;
std::mt19937 gen(_rd());
token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req];
float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
// prob_buf->debug();
RUN_INFINI(infiniopRandomSample(
desc_sample, workspace, workspace_size,
result_buf->data(req),
prob_buf->data(req * dvoc),
random_val,
topp[req], topk[req], temperature[req],
stream));
// result_buf->debug();
randomSample(result_buf->memShare({}, result_buf->dtype()),
prob_buf->view_as({dvoc}, {1}),
random_val, topp[req], topk[req], temperature[req]);
token_offset += seq_len;
}
RUN_INFINI(infinirtStreamSynchronize(stream));
@@ -471,30 +286,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
output[req] = result_cpu[req];
}
}
// Clean up
infiniopDestroyRMSNormDescriptor(desc_norm);
if (has_qkv_bias) {
infiniopDestroyRearrangeDescriptor(desc_qkv_bias);
}
infiniopDestroyGemmDescriptor(desc_attn_qkv);
infiniopDestroyGemmDescriptor(desc_attn_o);
infiniopDestroyRoPEDescriptor(desc_rope_q);
infiniopDestroyRoPEDescriptor(desc_rope_k);
for (uint32_t req = 0; req < nreq; req++) {
infiniopDestroyRearrangeDescriptor(desc_kv_rearranges[req]);
infiniopDestroyRearrangeDescriptor(desc_q_rearranges[req]);
infiniopDestroyGemmDescriptor(desc_qk_gemms[req]);
infiniopDestroyCausalSoftmaxDescriptor(desc_qk_softmaxs[req]);
infiniopDestroyGemmDescriptor(desc_attn_v_gemms[req]);
infiniopDestroyRearrangeDescriptor(desc_attn_v_rearranges[req]);
}
infiniopDestroyGemmDescriptor(desc_ffn_gate_up);
infiniopDestroySwiGLUDescriptor(desc_swiglu);
infiniopDestroyGemmDescriptor(desc_ffn_down);
infiniopDestroyRMSNormDescriptor(desc_norm_out);
infiniopDestroyGemmDescriptor(desc_out_embd);
infiniopDestroyRandomSampleDescriptor(desc_sample);
}
__C void
@@ -531,6 +322,12 @@ inferBatch(struct JiugeModel *model,
void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req,
infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
CacheManager cache_manager(100);
InferenceContext ctx(rsrc, &cache_manager, rsrc->stream);
// Set the inference context for this thread
setInferenceContext(&ctx);
// Create Device Resource
createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);
{
@@ -549,7 +346,9 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
break;
}
inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, req.temperature, req.topk, req.topp, req.output);
inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok,
req.req_lens, req.nreq, req.req_pos, req.kv_caches,
req.temperature, req.topk, req.topp, req.output);
state.proceed = false;
lock.unlock();
@@ -558,6 +357,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
// Clean-Up
releaseDeviceResource(*rsrc);
setInferenceContext(nullptr); // Clear the context when done
}
JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infiniDevice_t device_, std::vector<int> device_ids) : meta(*_meta) {
......
@@ -51,10 +51,12 @@ private:
std::vector<size_t> _shape;
std::vector<ptrdiff_t> _strides;
infiniopTensorDescriptor_t _desc;
size_t _seed;
TensorDesc(infiniDtype_t dtype, const std::vector<size_t> &shape,
const std::vector<ptrdiff_t> &strides) : _dtype(dtype), _shape(shape), _strides(strides), _desc(nullptr) {}
const std::vector<ptrdiff_t> &strides) : _dtype(dtype), _shape(shape), _strides(strides), _desc(nullptr) { computeTensorDesHash(); }
void resetDesc();
void computeTensorDesHash();
public:
~TensorDesc();
@@ -74,6 +76,7 @@ public:
infiniopTensorDescriptor_t desc() const;
bool isContigous() const;
std::string info() const;
size_t seed() const { return _seed; }
void dimMerge(size_t dim_start, size_t dim_end);
void dimSplit(size_t dim, const std::vector<size_t> &dims);
@@ -83,7 +86,7 @@ public:
class Tensor : public std::enable_shared_from_this<Tensor> {
private:
std::shared_ptr<Storage> _storage;
std::shared_ptr<TensorDesc> _desc;
std::shared_ptr<const TensorDesc> _desc;
ptrdiff_t _offset;
@@ -127,6 +130,11 @@ public:
void debug(const std::string &filename) const;
void debug() const;
std::string info() const;
size_t seed() const;
std::shared_ptr<Tensor> view(const std::vector<size_t> &new_shape) const;
std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape) const;
std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const;
~Tensor();
};
......
@@ -62,6 +62,16 @@ void TensorDesc::resetDesc() {
}
}
void TensorDesc::computeTensorDesHash() {
_seed = 0;
// Include the dtype so two descriptors that differ only in element type
// (same shape and strides) do not collide to the same cache key.
hash_combine(_seed, static_cast<size_t>(this->dtype()));
for (auto dim : this->shape()) {
hash_combine(_seed, dim);
}
for (auto stride : this->strides()) {
hash_combine(_seed, static_cast<size_t>(stride));
}
}
bool TensorDesc::isContigous() const {
auto ndim = this->ndim();
auto shape = this->shape();
@@ -258,6 +268,86 @@ std::string Tensor::info() const {
return this->_desc->info();
}
size_t Tensor::seed() const {
return this->_desc->seed();
}
std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const {
// Step 1: Validate total size
size_t numel = 1;
for (size_t dim : this->_desc->shape()) {
numel *= dim;
}
size_t new_numel = 1;
for (size_t dim : new_shape) {
new_numel *= dim;
}
ASSERT_EQ(numel, new_numel);
// Step 2: Get current shape and strides
const std::vector<size_t> &old_shape = this->_desc->shape();
const std::vector<ptrdiff_t> &old_strides = this->_desc->strides();
// Step 3: Create merged shape and strides
std::vector<size_t> merged_shape;
std::vector<ptrdiff_t> merged_strides;
if (!old_shape.empty()) {
merged_shape.push_back(old_shape[0]);
merged_strides.push_back(old_strides[0]);
for (size_t i = 1; i < old_shape.size(); ++i) {
if (old_strides[i] * static_cast<ptrdiff_t>(old_shape[i]) == merged_strides.back()) {
merged_shape.back() *= old_shape[i];
merged_strides.back() = old_strides[i];
} else {
merged_shape.push_back(old_shape[i]);
merged_strides.push_back(old_strides[i]);
}
}
}
// Step 4: Compute new strides by splitting merged dimensions
std::vector<ptrdiff_t> new_strides(new_shape.size());
size_t merged_idx = 0;
ptrdiff_t current_stride = merged_strides[0];
size_t remaining_size = merged_shape[0];
for (size_t i = 0; i < new_shape.size(); ++i) {
// Find which merged dimension contains this new dimension
while (new_shape[i] > remaining_size) {
ASSERT(++merged_idx < merged_shape.size());
current_stride = merged_strides[merged_idx];
remaining_size = merged_shape[merged_idx];
}
ASSERT_EQ(remaining_size % new_shape[i], 0);
new_strides[i] = current_stride * (remaining_size / new_shape[i]);
remaining_size /= new_shape[i];
}
return this->view_as(new_shape, new_strides);
}
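`view` works in two passes: it first merges adjacent dimensions that are contiguous relative to each other (the check `old_strides[i] * old_shape[i] == merged_strides.back()`), then splits the merged runs to derive strides for the requested shape. A worked trace for viewing a contiguous `{2, 3, 4}` tensor as `{6, 4}`:

```cpp
// old:    shape {2, 3, 4}, strides {12, 4, 1}
// merge:  12 == 4*3 and 4 == 1*4, so everything collapses -> shape {24}, strides {1}
// split:  {24} -> {6, 4}: strides {1 * (24/6), 1 * (4/4)} = {4, 1}
// A non-mergeable layout (e.g. a permuted tensor) trips the ASSERT in the split loop.
```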
std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), new_shape);
tensor->_offset = this->_offset;
return tensor;
}
std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
tensor->_offset = this->_offset;
return tensor;
}
void Tensor::debug(const std::string &filename) const {
RUN_INFINI(infinirtDeviceSynchronize());
......
@@ -63,11 +63,18 @@ void TensorDesc::dimMerge(size_t dim_start, size_t dim_end) {
this->_shape = new_shape;
this->_strides = new_strides;
this->resetDesc();
this->computeTensorDesHash();
}
std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
this->_desc->dimMerge(dim_start, dim_end);
return shared_from_this();
auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
new_desc->dimMerge(dim_start, dim_end);
auto tensor = std::make_shared<Tensor>();
tensor->_storage = _storage;
tensor->_desc = new_desc;
tensor->_offset = _offset;
return tensor;
}
void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) {
@@ -89,11 +96,18 @@ void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) {
this->_shape = new_shape;
this->_strides = new_strides;
this->resetDesc();
this->computeTensorDesHash();
}
std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) {
this->_desc->dimSplit(dim, dims);
return shared_from_this();
auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
new_desc->dimSplit(dim, dims);
auto tensor = std::make_shared<Tensor>();
tensor->_storage = _storage;
tensor->_desc = new_desc;
tensor->_offset = _offset;
return tensor;
}
void TensorDesc::permute(const std::vector<size_t> &order) {
@@ -108,9 +122,16 @@ void TensorDesc::permute(const std::vector<size_t> &order) {
this->_shape = new_shape;
this->_strides = new_strides;
this->resetDesc();
this->computeTensorDesHash();
}
std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
this->_desc->permute(order);
return shared_from_this();
auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
new_desc->permute(order);
auto tensor = std::make_shared<Tensor>();
tensor->_storage = _storage;
tensor->_desc = new_desc;
tensor->_offset = _offset;
return tensor;
}
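Note the semantic change across these three hunks: `dimMerge`, `dimSplit`, and `permute` previously mutated the tensor's own `TensorDesc` in place and returned `shared_from_this()`, so every existing `shared_ptr` handle to that tensor observed the new shape; they now build a fresh descriptor and return a new `Tensor` aliasing the same storage, leaving the original untouched. This is also what allows `Tensor::_desc` to become a `shared_ptr<const TensorDesc>` in the header hunk above. The difference, as a sketch (`dt` and `pool` assumed):

```cpp
auto t = Tensor::buffer(dt, {2, 3, 4}, pool);
auto p = t->permute({2, 0, 1});
// before: t and p were the same reshaped object, so t's shape was now {4, 2, 3} too
// after:  t keeps {2, 3, 4}; only p sees {4, 2, 3}; both alias the same Storage
```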
@@ -119,4 +119,9 @@ inline uint16_t f32_to_bf16(float val) {
return bf16_bits;
}
// Hash combine utility (similar to boost::hash_combine)
inline void hash_combine(size_t &seed, size_t value) {
seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
#endif
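The constant `0x9e3779b9` is the 32-bit golden-ratio constant used by `boost::hash_combine`; together with the shifts it spreads each input across the seed's bits and makes the combine order-sensitive. A quick illustration:

```cpp
size_t a = 0, b = 0;
hash_combine(a, 3); hash_combine(a, 4);
hash_combine(b, 4); hash_combine(b, 3);
// a != b: order matters, so shapes {3, 4} and {4, 3} hash to different seeds.
```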
@@ -12,6 +12,7 @@ target("infinicore_infer")
set_languages("cxx17")
set_warnings("all", "error")
add_files("src/models/*.cpp")
add_files("src/models/*/*.cpp")
add_files("src/tensor/*.cpp")
add_files("src/allocator/*.cpp")
......