"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "88e97c822c988eaa9f8bcbaa1ea5d702ffd7d384"
Commit 366386d3 authored by wooway777's avatar wooway777
Browse files

issue/21 - removed descriptor overwrite

parent 726f444f
...@@ -21,23 +21,14 @@ inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_ ...@@ -21,23 +21,14 @@ inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_
hash_combine(seed, static_cast<size_t>(value)); hash_combine(seed, static_cast<size_t>(value));
} }
// Overload of hash_combine for float values.
// Hashes the raw 32-bit pattern of the float rather than its numeric value,
// then delegates to another hash_combine overload (defined outside this view)
// with the bits widened to size_t.
inline void hash_combine(size_t &seed, float value) {
    // The byte copy below is only valid if float is exactly 32 bits wide.
    static_assert(sizeof(float) == sizeof(uint32_t), "Size mismatch");
    uint32_t bits;
    // memcpy is used for the reinterpretation to avoid pointer-cast type punning.
    std::memcpy(&bits, &value, sizeof(bits));
    hash_combine(seed, static_cast<size_t>(bits));
}
// Helper function to compute hash for tensor descriptors // Helper function to compute hash for tensor descriptors
inline size_t computeTensorDescHash(std::shared_ptr<TensorDesc> desc) { inline size_t computeTensorDescHash(std::shared_ptr<Tensor> tensor) {
size_t seed = 0; size_t seed = 0;
hash_combine(seed, desc->dtype()); hash_combine(seed, tensor->dtype());
for (auto dim : desc->shape()) { for (auto dim : tensor->shape()) {
hash_combine(seed, dim); hash_combine(seed, dim);
} }
for (auto stride : desc->strides()) { for (auto stride : tensor->strides()) {
hash_combine(seed, static_cast<size_t>(stride)); hash_combine(seed, static_cast<size_t>(stride));
} }
return seed; return seed;
...@@ -185,7 +176,7 @@ public: ...@@ -185,7 +176,7 @@ public:
class CacheManager { class CacheManager {
private: private:
const size_t DEFAULT_CACHE_CAPACITY = 100; const size_t DEFAULT_CACHE_CAPACITY = 128;
LRUDescriptorCache<infiniopRMSNormDescriptor_t> rms_norm_cache; LRUDescriptorCache<infiniopRMSNormDescriptor_t> rms_norm_cache;
LRUDescriptorCache<infiniopGemmDescriptor_t> gemm_cache; LRUDescriptorCache<infiniopGemmDescriptor_t> gemm_cache;
...@@ -267,11 +258,11 @@ public: ...@@ -267,11 +258,11 @@ public:
random_sample_cache.put(key, desc); random_sample_cache.put(key, desc);
} }
static size_t createDescriptorKey(std::shared_ptr<TensorDesc> desc0, static size_t createDescriptorKey(std::shared_ptr<Tensor> desc0,
std::shared_ptr<TensorDesc> desc1, std::shared_ptr<Tensor> desc1,
std::shared_ptr<TensorDesc> desc2, std::shared_ptr<Tensor> desc2,
std::shared_ptr<TensorDesc> desc3, std::shared_ptr<Tensor> desc3,
std::shared_ptr<TensorDesc> desc4) { std::shared_ptr<Tensor> desc4) {
size_t seed = 0; size_t seed = 0;
if (desc0) { if (desc0) {
hash_combine(seed, computeTensorDescHash(desc0)); hash_combine(seed, computeTensorDescHash(desc0));
......
...@@ -16,7 +16,7 @@ void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y, ...@@ -16,7 +16,7 @@ void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w, std::shared_ptr<Tensor> w,
float epsilon) { float epsilon) {
size_t key = CacheManager::createDescriptorKey(y->tdesc(), x->tdesc(), w->tdesc(), nullptr, nullptr); size_t key = CacheManager::createDescriptorKey(y, x, w, nullptr, nullptr);
infiniopRMSNormDescriptor_t desc; infiniopRMSNormDescriptor_t desc;
if (!cache_manager->getRMSNormDescriptor(key, desc)) { if (!cache_manager->getRMSNormDescriptor(key, desc)) {
...@@ -35,23 +35,16 @@ void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y, ...@@ -35,23 +35,16 @@ void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
y->data(), x->data(), w->data(), stream)); y->data(), x->data(), w->data(), stream));
} }
void InferenceContext::gemm(std::shared_ptr<Tensor> c, std::shared_ptr<TensorDesc> c_desc_overwrite, void InferenceContext::gemm(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a, std::shared_ptr<TensorDesc> a_desc_overwrite, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, std::shared_ptr<TensorDesc> b_desc_overwrite, std::shared_ptr<Tensor> b,
float alpha, float beta) { float alpha, float beta) {
size_t key = CacheManager::createDescriptorKey( size_t key = CacheManager::createDescriptorKey(c, a, b,
c_desc_overwrite ? c_desc_overwrite : c->tdesc(), nullptr, nullptr);
a_desc_overwrite ? a_desc_overwrite : a->tdesc(),
b_desc_overwrite ? b_desc_overwrite : b->tdesc(),
nullptr, nullptr);
infiniopGemmDescriptor_t desc; infiniopGemmDescriptor_t desc;
if (!cache_manager->getGemmDescriptor(key, desc)) { if (!cache_manager->getGemmDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateGemmDescriptor( RUN_INFINI(infiniopCreateGemmDescriptor(rsrc->handle, &desc, c->desc(), a->desc(), b->desc()));
rsrc->handle, &desc,
c_desc_overwrite ? c_desc_overwrite->desc() : c->desc(),
a_desc_overwrite ? a_desc_overwrite->desc() : a->desc(),
b_desc_overwrite ? b_desc_overwrite->desc() : b->desc()));
cache_manager->putGemmDescriptor(key, desc); cache_manager->putGemmDescriptor(key, desc);
} }
...@@ -65,19 +58,13 @@ void InferenceContext::gemm(std::shared_ptr<Tensor> c, std::shared_ptr<TensorDes ...@@ -65,19 +58,13 @@ void InferenceContext::gemm(std::shared_ptr<Tensor> c, std::shared_ptr<TensorDes
c->data(), a->data(), b->data(), alpha, beta, stream)); c->data(), a->data(), b->data(), alpha, beta, stream));
} }
void InferenceContext::rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<TensorDesc> dst_desc_overwrite, void InferenceContext::rearrange(std::shared_ptr<Tensor> dst,
std::shared_ptr<Tensor> src, std::shared_ptr<TensorDesc> src_desc_overwrite) { std::shared_ptr<Tensor> src) {
size_t key = CacheManager::createDescriptorKey( size_t key = CacheManager::createDescriptorKey(dst, src, nullptr, nullptr, nullptr);
dst_desc_overwrite ? dst_desc_overwrite : dst->tdesc(),
src_desc_overwrite ? src_desc_overwrite : src->tdesc(),
nullptr, nullptr, nullptr);
infiniopRearrangeDescriptor_t desc; infiniopRearrangeDescriptor_t desc;
if (!cache_manager->getRearrangeDescriptor(key, desc)) { if (!cache_manager->getRearrangeDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateRearrangeDescriptor( RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc->handle, &desc, dst->desc(), src->desc()));
rsrc->handle, &desc,
dst_desc_overwrite ? dst_desc_overwrite->desc() : dst->desc(),
src_desc_overwrite ? src_desc_overwrite->desc() : src->desc()));
cache_manager->putRearrangeDescriptor(key, desc); cache_manager->putRearrangeDescriptor(key, desc);
} }
...@@ -93,7 +80,7 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q, ...@@ -93,7 +80,7 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q,
std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> pos,
std::shared_ptr<Tensor> sin, std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos) { std::shared_ptr<Tensor> cos) {
size_t key = CacheManager::createDescriptorKey(q->tdesc(), k->tdesc(), pos->tdesc(), sin->tdesc(), cos->tdesc()); size_t key = CacheManager::createDescriptorKey(q, k, pos, sin, cos);
infiniopRoPEDescriptor_t desc; infiniopRoPEDescriptor_t desc;
if (!cache_manager->getRoPEDescriptor(key, desc)) { if (!cache_manager->getRoPEDescriptor(key, desc)) {
...@@ -114,19 +101,14 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q, ...@@ -114,19 +101,14 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q,
sin->data(), cos->data(), stream)); sin->data(), cos->data(), stream));
} }
void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<TensorDesc> y_desc_overwrite, void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x, std::shared_ptr<TensorDesc> x_desc_overwrite) { std::shared_ptr<Tensor> x) {
size_t key = CacheManager::createDescriptorKey( size_t key = CacheManager::createDescriptorKey(y, x, nullptr, nullptr, nullptr);
y_desc_overwrite ? y_desc_overwrite : y->tdesc(),
x_desc_overwrite ? x_desc_overwrite : x->tdesc(),
nullptr, nullptr, nullptr);
infiniopCausalSoftmaxDescriptor_t desc; infiniopCausalSoftmaxDescriptor_t desc;
if (!cache_manager->getCausalSoftmaxDescriptor(key, desc)) { if (!cache_manager->getCausalSoftmaxDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor( RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor(
rsrc->handle, &desc, rsrc->handle, &desc, y->desc(), x->desc()));
y_desc_overwrite ? y_desc_overwrite->desc() : y->desc(),
x_desc_overwrite ? x_desc_overwrite->desc() : x->desc()));
cache_manager->putCausalSoftmaxDescriptor(key, desc); cache_manager->putCausalSoftmaxDescriptor(key, desc);
} }
...@@ -139,8 +121,10 @@ void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr< ...@@ -139,8 +121,10 @@ void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<
y->data(), x->data(), stream)); y->data(), x->data(), stream));
} }
void InferenceContext::swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up, std::shared_ptr<Tensor> gate) { void InferenceContext::swiglu(std::shared_ptr<Tensor> out,
size_t key = CacheManager::createDescriptorKey(out->tdesc(), up->tdesc(), gate->tdesc(), nullptr, nullptr); std::shared_ptr<Tensor> up,
std::shared_ptr<Tensor> gate) {
size_t key = CacheManager::createDescriptorKey(out, up, gate, nullptr, nullptr);
infiniopSwiGLUDescriptor_t desc; infiniopSwiGLUDescriptor_t desc;
if (!cache_manager->getSwiGLUDescriptor(key, desc)) { if (!cache_manager->getSwiGLUDescriptor(key, desc)) {
...@@ -158,20 +142,15 @@ void InferenceContext::swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tenso ...@@ -158,20 +142,15 @@ void InferenceContext::swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tenso
out->data(), up->data(), gate->data(), stream)); out->data(), up->data(), gate->data(), stream));
} }
void InferenceContext::randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<TensorDesc> out_desc_overwrite, void InferenceContext::randomSample(std::shared_ptr<Tensor> out,
std::shared_ptr<Tensor> prob, std::shared_ptr<TensorDesc> prob_desc_overwrite, std::shared_ptr<Tensor> prob,
float random_val, float top_p, uint32_t top_k, float temperature) { float random_val, float top_p, uint32_t top_k, float temperature) {
size_t key = CacheManager::createDescriptorKey( size_t key = CacheManager::createDescriptorKey(out, prob, nullptr, nullptr, nullptr);
out_desc_overwrite ? out_desc_overwrite : out->tdesc(),
prob_desc_overwrite ? prob_desc_overwrite : prob->tdesc(),
nullptr, nullptr, nullptr);
infiniopRandomSampleDescriptor_t desc; infiniopRandomSampleDescriptor_t desc;
if (!cache_manager->getRandomSampleDescriptor(key, desc)) { if (!cache_manager->getRandomSampleDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateRandomSampleDescriptor( RUN_INFINI(infiniopCreateRandomSampleDescriptor(
rsrc->handle, &desc, rsrc->handle, &desc, out->desc(), prob->desc()));
out_desc_overwrite ? out_desc_overwrite->desc() : out->desc(),
prob_desc_overwrite ? prob_desc_overwrite->desc() : prob->desc()));
cache_manager->putRandomSampleDescriptor(key, desc); cache_manager->putRandomSampleDescriptor(key, desc);
} }
......
...@@ -15,25 +15,28 @@ struct InferenceContext { ...@@ -15,25 +15,28 @@ struct InferenceContext {
InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager, infinirtStream_t stream); InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager, infinirtStream_t stream);
void ensure_workspace(size_t required_size); void ensure_workspace(size_t required_size);
void rmsnorm(std::shared_ptr<Tensor> y, void rmsnorm(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w, std::shared_ptr<Tensor> w,
float epsilon); float epsilon);
void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<TensorDesc> c_desc_overwrite, void gemm(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a, std::shared_ptr<TensorDesc> a_desc_overwrite, std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b, std::shared_ptr<TensorDesc> b_desc_overwrite, std::shared_ptr<Tensor> b,
float alpha, float beta); float alpha, float beta);
void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<TensorDesc> dst_desc_overwrite, void rearrange(std::shared_ptr<Tensor> dst,
std::shared_ptr<Tensor> src, std::shared_ptr<TensorDesc> src_desc_overwrite); std::shared_ptr<Tensor> src);
void rope(std::shared_ptr<Tensor> q, void rope(std::shared_ptr<Tensor> q,
std::shared_ptr<Tensor> k, std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> pos,
std::shared_ptr<Tensor> sin, std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos); std::shared_ptr<Tensor> cos);
void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<TensorDesc> y_desc_overwrite, void causalSoftmax(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x, std::shared_ptr<TensorDesc> x_desc_overwrite); std::shared_ptr<Tensor> x);
void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up, std::shared_ptr<Tensor> gate); void swiglu(std::shared_ptr<Tensor> out,
void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<TensorDesc> out_desc_overwrite, std::shared_ptr<Tensor> up,
std::shared_ptr<Tensor> prob, std::shared_ptr<TensorDesc> prob_desc_overwrite, std::shared_ptr<Tensor> gate);
void randomSample(std::shared_ptr<Tensor> out,
std::shared_ptr<Tensor> prob,
float random_val, float top_p, uint32_t top_k, float temperature); float random_val, float top_p, uint32_t top_k, float temperature);
}; };
...@@ -166,15 +166,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -166,15 +166,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
} }
// Attention // Attention
auto qkv_desc = TensorDesc::create(dt_logits, qkv_buf->shape(), qkv_buf->strides());
auto b_attn_qkv_desc = TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh}, {0, 1});
auto o_desc = TensorDesc::create(dt_logits, o_buf->shape(), o_buf->strides());
qkv_buf->dimSplit(1, {nh + nkvh * 2, dh}); // (ntok, nh + 2 * nkvh, dh) qkv_buf->dimSplit(1, {nh + nkvh * 2, dh}); // (ntok, nh + 2 * nkvh, dh)
// attention inner // attention inner
size_t max_qk_size = 0; size_t max_qk_size = 0;
size_t max_seq_len = 0; size_t max_seq_len = 0;
o_buf->dimSplit(1, {nh, dh});
for (uint32_t req = 0; req < nreq; req++) { for (uint32_t req = 0; req < nreq; req++) {
auto past_len = req_pos[req]; auto past_len = req_pos[req];
...@@ -193,24 +188,19 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -193,24 +188,19 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto gate_buf = gate_up_buf->slice(1, 0, di); auto gate_buf = gate_up_buf->slice(1, 0, di);
auto up_buf = gate_up_buf->slice(1, di, di); auto up_buf = gate_up_buf->slice(1, di, di);
// Output and sample
auto result_desc = TensorDesc::create(INFINI_DTYPE_I64, {}, {});
auto prob_desc = TensorDesc::create(dt_logits, {dvoc}, {1});
// Compute // Compute
for (uint32_t layer = 0; layer < nlayer; layer++) { for (uint32_t layer = 0; layer < nlayer; layer++) {
// 1. Attention // 1. Attention
// rms norm // rms norm
ctx.rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon); ctx.rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
// qkv_proj // qkv_proj
qkv_buf->dimMerge(1, 2);
if (has_qkv_bias) { if (has_qkv_bias) {
ctx.rearrange(qkv_buf, qkv_desc, rsrc.b_attn_qkv[layer], b_attn_qkv_desc); ctx.rearrange(qkv_buf, rsrc.b_attn_qkv[layer]->reDesc({ntok, (nh + nkvh * 2) * dh}, {0, 1}));
} }
ctx.gemm(qkv_buf, qkv_desc, ctx.gemm(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, has_qkv_bias ? 1.0 : 0.0);
logits_out, nullptr,
rsrc.w_attn_qkv[layer], nullptr,
1.0, has_qkv_bias ? 1.0 : 0.0);
// rope // rope
qkv_buf->dimSplit(1, {nh + nkvh * 2, dh});
ctx.rope(qkv_buf->slice(1, 0, nh), qkv_buf->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); ctx.rope(qkv_buf->slice(1, 0, nh), qkv_buf->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
ctx.rope(qkv_buf->slice(1, nh, nkvh), qkv_buf->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); ctx.rope(qkv_buf->slice(1, nh, nkvh), qkv_buf->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
...@@ -219,43 +209,41 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -219,43 +209,41 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto past_len = req_pos[req]; auto past_len = req_pos[req];
auto seq_len = req_lens[req]; auto seq_len = req_lens[req];
auto total_len = past_len + seq_len; auto total_len = past_len + seq_len;
auto o = o_buf->slice({{0, token_offset, seq_len}}); auto o = o_buf->dimSplit(1, {nh, dh})->slice({{0, token_offset, seq_len}});
auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}}); auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
auto qt_rearrange_desc = TensorDesc::create(dt_logits, {nkvh, ngroup, seq_len, dh});
auto qt_gemm_desc = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, dh});
auto qk_gemm_desc = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, total_len});
// self attention // self attention
// concat // concat
ctx.rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), nullptr, k, nullptr); ctx.rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k);
ctx.rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), nullptr, v, nullptr); ctx.rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk // qk
ctx.rearrange(rearrange_q_buf, qt_rearrange_desc, ctx.rearrange(rearrange_q_buf->dimSplit(1, {ngroup, seq_len}),
q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3}), nullptr); q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3}));
ctx.gemm(qk_buf, qk_gemm_desc, qk_buf->dimSplit(1, {seq_len, total_len});
rearrange_q_buf, qt_gemm_desc, qk_buf->dimSplit(0, {nkvh, ngroup});
kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}), nullptr, qk_buf->dimMerge(1, 2);
1. / sqrt(dh), 0.0); ctx.gemm(qk_buf, rearrange_q_buf->dimMerge(1, 2), kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}), 1. / sqrt(dh), 0.0);
// softmax // softmax
auto qk_desc = TensorDesc::create(dt_logits, {nkvh * ngroup, seq_len, total_len}); qk_buf->dimSplit(1, {ngroup, seq_len});
ctx.causalSoftmax(qk_buf, qk_desc, qk_buf, qk_desc); qk_buf->dimMerge(0, 1);
ctx.gemm(attn_val_buf, qt_gemm_desc, ctx.causalSoftmax(qk_buf, qk_buf);
qk_buf, qk_gemm_desc, qk_buf->dimSplit(0, {nkvh, ngroup});
kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}), nullptr, qk_buf->dimMerge(1, 2);
1.0, 0.0); ctx.gemm(attn_val_buf, qk_buf, kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}), 1.0, 0.0);
qk_buf->dimSplit(1, {ngroup, seq_len});
qk_buf->dimMerge(2, 3);
qk_buf->dimMerge(0, 1);
// rearrange attn val // rearrange attn val
ctx.rearrange(o, TensorDesc::createWithOrder(dt_logits, {nkvh, ngroup, seq_len, dh}, {1, 2, 0, 3}), attn_val_buf->dimSplit(0, {nkvh, ngroup});
attn_val_buf, qt_rearrange_desc); ctx.rearrange(o->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3}), attn_val_buf);
attn_val_buf->dimMerge(0, 1);
token_offset += seq_len; token_offset += seq_len;
} }
// o_proj // o_proj
ctx.gemm(logits_in, nullptr, ctx.gemm(logits_in, o_buf->dimMerge(1, 2), rsrc.w_attn_out[layer], 1.0, idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
o_buf, o_desc,
rsrc.w_attn_out[layer], nullptr,
1.0, idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
// All_reduce if distributed // All_reduce if distributed
if (rsrc.comm != nullptr) { if (rsrc.comm != nullptr) {
...@@ -267,15 +255,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -267,15 +255,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// 2. FFN // 2. FFN
// rms_norm // rms_norm
ctx.rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon); ctx.rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
ctx.gemm(gate_up_buf, nullptr, ctx.gemm(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0);
logits_out, nullptr,
rsrc.w_ffn_gate_up[layer], nullptr,
1.0, 0.0);
ctx.swiglu(gate_buf, up_buf, gate_buf); ctx.swiglu(gate_buf, up_buf, gate_buf);
ctx.gemm(logits_in, nullptr, ctx.gemm(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
gate_buf, nullptr,
rsrc.w_ffn_down[layer], nullptr,
1.0, idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
// All_reduce if distributed // All_reduce if distributed
if (rsrc.comm != nullptr) { if (rsrc.comm != nullptr) {
...@@ -296,18 +278,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -296,18 +278,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rsrc.w_out_norm, rsrc.w_out_norm,
meta.epsilon); meta.epsilon);
} }
ctx.gemm(prob_buf, nullptr, ctx.gemm(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0);
logits_out->slice(0, 0, nreq), nullptr,
rsrc.w_out_embd, nullptr,
1.0, 0.0);
std::random_device _rd; std::random_device _rd;
std::mt19937 gen(_rd()); std::mt19937 gen(_rd());
token_offset = 0; token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) { for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req]; auto seq_len = req_lens[req];
float random_val = std::uniform_real_distribution<float>(0, 1)(gen); float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
ctx.randomSample(result_buf->slice(0, req, 1), result_desc, ctx.randomSample(result_buf->reDesc({}, {}),
prob_buf->slice(0, req, 1), prob_desc, prob_buf->reDesc({dvoc}, {1}),
random_val, topp[req], topk[req], temperature[req]); random_val, topp[req], topk[req], temperature[req]);
token_offset += seq_len; token_offset += seq_len;
} }
...@@ -354,7 +333,7 @@ inferBatch(struct JiugeModel *model, ...@@ -354,7 +333,7 @@ inferBatch(struct JiugeModel *model,
void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req, void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req,
infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) { infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
CacheManager cache_manager(100); CacheManager cache_manager(256);
InferenceContext ctx(rsrc, &cache_manager, rsrc->stream); InferenceContext ctx(rsrc, &cache_manager, rsrc->stream);
// Create Device Resource // Create Device Resource
......
...@@ -78,6 +78,7 @@ public: ...@@ -78,6 +78,7 @@ public:
void dimMerge(size_t dim_start, size_t dim_end); void dimMerge(size_t dim_start, size_t dim_end);
void dimSplit(size_t dim, const std::vector<size_t> &dims); void dimSplit(size_t dim, const std::vector<size_t> &dims);
void permute(const std::vector<size_t> &order); void permute(const std::vector<size_t> &order);
void reDesc(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides);
}; };
class Tensor : public std::enable_shared_from_this<Tensor> { class Tensor : public std::enable_shared_from_this<Tensor> {
...@@ -110,6 +111,7 @@ public: ...@@ -110,6 +111,7 @@ public:
std::shared_ptr<Tensor> dimSplit(size_t dim, std::shared_ptr<Tensor> dimSplit(size_t dim,
const std::vector<size_t> &dims); const std::vector<size_t> &dims);
std::shared_ptr<Tensor> permute(const std::vector<size_t> &order); std::shared_ptr<Tensor> permute(const std::vector<size_t> &order);
std::shared_ptr<Tensor> reDesc(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides);
void *data(ptrdiff_t offset = 0); void *data(ptrdiff_t offset = 0);
void const *data(ptrdiff_t offset = 0) const; void const *data(ptrdiff_t offset = 0) const;
void copyFrom(std::shared_ptr<Tensor const> src, infiniopHandle_t handle, void copyFrom(std::shared_ptr<Tensor const> src, infiniopHandle_t handle,
...@@ -120,7 +122,6 @@ public: ...@@ -120,7 +122,6 @@ public:
infiniDtype_t dtype() const; infiniDtype_t dtype() const;
bool isContigous() const; bool isContigous() const;
infiniopTensorDescriptor_t desc() const; infiniopTensorDescriptor_t desc() const;
std::shared_ptr<TensorDesc> tdesc() const;
ptrdiff_t dataOffset() const; ptrdiff_t dataOffset() const;
infiniDevice_t deviceType() const; infiniDevice_t deviceType() const;
int deviceId() const; int deviceId() const;
......
...@@ -108,7 +108,6 @@ ptrdiff_t Tensor::dataOffset() const { ...@@ -108,7 +108,6 @@ ptrdiff_t Tensor::dataOffset() const {
} }
infiniopTensorDescriptor_t Tensor::desc() const { return _desc->desc(); } infiniopTensorDescriptor_t Tensor::desc() const { return _desc->desc(); }
std::shared_ptr<TensorDesc> Tensor::tdesc() const { return _desc; }
std::shared_ptr<Tensor> Tensor::buffer(infiniDtype_t dtype, std::shared_ptr<Tensor> Tensor::buffer(infiniDtype_t dtype,
const std::vector<size_t> &shape, const std::vector<size_t> &shape,
......
...@@ -114,3 +114,14 @@ std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) { ...@@ -114,3 +114,14 @@ std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
this->_desc->permute(order); this->_desc->permute(order);
return shared_from_this(); return shared_from_this();
} }
// Replaces this descriptor's shape and strides wholesale, then calls
// resetDesc() so the underlying descriptor is rebuilt from the new values.
//
// new_shape   - shape to store in _shape.
// new_strides - strides to store in _strides.
//
// No validation is performed here: the two vectors are stored as given.
// NOTE(review): presumably both vectors must have the same length and must
// describe memory inside the existing allocation - confirm against callers.
//
// The parameters are already private copies (passed by value), so they are
// swapped into the members instead of being copied a second time. Top-level
// const on a by-value parameter is not part of the function type, so this
// definition still matches the declaration in the class.
void TensorDesc::reDesc(std::vector<size_t> new_shape, std::vector<ptrdiff_t> new_strides) {
    this->_shape.swap(new_shape);
    this->_strides.swap(new_strides);
    this->resetDesc();
}
std::shared_ptr<Tensor> Tensor::reDesc(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) {
this->_desc->reDesc(new_shape, new_strides);
return shared_from_this();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment