#include "jiuge_impl.hpp"
#include "jiuge_weight.hpp"

#include "../../tensor.hpp"
#include "../../utils.hpp"
#include "../inference_context.hpp"
#include "infinicore_infer.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <memory>
#include <mutex>
#include <random>
#include <thread>
#include <vector>

/// @brief Build all per-device state for rank `idev` of an `ndev`-way model.
///
/// Selects device `dev_id`, creates an infiniop handle and an infinirt stream,
/// loads the per-layer weights through the getAttn*/getFFN* helpers (the ones
/// taking idev/ndev presumably return this rank's shard), creates a memory
/// pool of 128 * 1024 * 1024 (presumably bytes), stores everything in `*rsrc`
/// and waits for the device to become idle.
///
/// @param rsrc    Output; overwritten with the newly created resources.
/// @param meta    Model hyper-parameters (`nlayer` is read here).
/// @param weights Host weights; per-layer QKV bias is loaded only when
///                `weights->attn_qkv_b` is non-null.
/// @param device  Device type passed to `infinirtSetDevice`.
/// @param idev    Rank of this device in the tensor-parallel group.
/// @param ndev    Size of the tensor-parallel group.
/// @param dev_id  Physical device id passed to `infinirtSetDevice`.
/// @param comm    Communicator for this rank; nullptr for single-device models
///                (see JiugeModel::JiugeModel).
void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
                          const JiugeWeights *weights,
                          infiniDevice_t device, int idev,
                          int ndev, int dev_id,
                          infinicclComm_t comm) {
    RUN_INFINI(infinirtSetDevice(device, dev_id));
    infiniopHandle_t handle;
    RUN_INFINI(infiniopCreateHandle(&handle));
    infinirtStream_t stream;
    RUN_INFINI(infinirtStreamCreate(&stream));

    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
        w_ffn_norm, w_ffn_gate_up, w_ffn_down;
    for (size_t layer = 0; layer < meta->nlayer; layer++) {
        w_attn_norm.push_back(getAttnNorm(meta, weights, layer));
        w_attn_qkv.push_back(getAttnQKV(meta, weights, layer, idev, ndev));
        // The bias list stays empty when the model has no QKV bias;
        // inferDeviceBatch uses its emptiness as the "no bias" flag.
        if (weights->attn_qkv_b != nullptr) {
            b_attn_qkv.push_back(getAttnQKVBias(meta, weights, layer, idev, ndev));
        }
        w_attn_out.push_back(getAttnO(meta, weights, layer, idev, ndev));
        w_ffn_norm.push_back(getFFNNorm(meta, weights, layer));
        w_ffn_gate_up.push_back(getFFNGateUp(meta, weights, layer, idev, ndev));
        w_ffn_down.push_back(getFFNDown(meta, weights, layer, idev, ndev));
    }

    auto memory_pool = std::make_shared<MemoryPool>(128 * 1024 * 1024);

    *rsrc = DeviceResource{
        device,
        dev_id,
        handle,
        getInEmbd(meta, weights),
        getOutNorm(meta, weights),
        getOutEmbd(meta, weights),
        getSinTable(meta),
        getCosTable(meta),
        w_attn_norm,
        w_attn_qkv,
        b_attn_qkv,
        w_attn_out,
        w_ffn_norm,
        w_ffn_gate_up,
        w_ffn_down,
        stream,
        comm,
        memory_pool,
    };
    RUN_INFINI(infinirtDeviceSynchronize());
}

/// @brief Release everything createDeviceResource stored in `res`.
///
/// Waits for the device to go idle, drops all weight tensors and the memory
/// pool reference, then destroys the handle, stream and (if present) the
/// communicator. Destroy-call return codes are intentionally ignored
/// (best-effort teardown).
void releaseDeviceResource(DeviceResource &res) {
    infinirtDeviceSynchronize();
    // Release individual Tensors
    res.w_in_embd.reset();
    res.w_out_norm.reset();
    res.w_out_embd.reset();
    res.sin_table.reset();
    res.cos_table.reset();
    // Clearing a vector of shared_ptrs drops every reference it holds.
    res.w_attn_norm.clear();
    res.w_attn_qkv.clear();
    res.b_attn_qkv.clear();
    res.w_attn_out.clear();
    res.w_ffn_norm.clear();
    res.w_ffn_gate_up.clear();
    res.w_ffn_down.clear();
    // Drop the pool reference here, on the same thread that selected the
    // device in createDeviceResource, rather than in ~DeviceResource later.
    res.memory_pool.reset();

    infiniopDestroyHandle(res.handle);
    res.handle = nullptr;
    infinirtStreamDestroy(res.stream);
    res.stream = nullptr;
    // Single-device models never create a communicator (comm stays nullptr).
    if (res.comm != nullptr) {
        infinicclCommDestroy(res.comm);
        res.comm = nullptr;
    }
}

/// @brief Run one forward step for a packed batch on rank `idev` of `ndev`.
///
/// The `ntok` tokens of all `nreq` requests are packed back-to-back in
/// `tokens`: request r owns `req_lens[r]` consecutive tokens whose positions
/// start at `req_pos[r]`. Each layer appends the new K/V rows to
/// `kv_caches[r]` at rows [req_pos[r], req_pos[r] + req_lens[r]). When a
/// communicator is present, partial o_proj / FFN-down results are summed
/// across ranks with an all-reduce. Only rank 0 computes logits and samples
/// one token per request into `output[0..nreq)`.
void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
                      uint32_t idev, uint32_t ndev,
                      const uint32_t *tokens, uint32_t ntok,
                      const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                      struct KVCache **kv_caches,
                      const float *temperature, const uint32_t *topk, const float *topp,
                      uint32_t *output) {
    auto nlayer = meta.nlayer;
    auto nkvh = meta.nkvh / ndev;
    auto nh = meta.nh / ndev;
    auto ngroup = nh / nkvh;
    // auto dctx = meta.dctx;
    auto dh = meta.dh;
    auto d = meta.d;
    auto dt_logits = meta.dt_logits;
    auto di = meta.di / ndev;
    auto dvoc = meta.dvoc;
    auto stream = rsrc.stream;
    bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;

    // Allocate buffers
    auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
    auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
    auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool);
    auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di}, rsrc.memory_pool);
    auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool);
    auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool);
    auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
    auto result_cpu = std::vector<int64_t>(nreq);

    auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});

    // Prepare inputs: absolute position of every packed token.
    auto batch_pos_ids = std::vector<uint32_t>(ntok);
    size_t req_start = 0;
    for (uint32_t req = 0; req < nreq; req++) {
        for (uint32_t i = 0; i < req_lens[req]; i++) {
            batch_pos_ids[req_start + i] = req_pos[req] + i;
        }
        req_start += req_lens[req];
    }

    std::shared_ptr<Tensor> pos_ids_buf;
    if (rsrc.device == INFINI_DEVICE_CPU) {
        pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok});
    } else {
        pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool);
        RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok,
                                       INFINIRT_MEMCPY_H2D, stream));
    }
    // Embedding lookup: copy row tokens[i] of the input embedding per token.
    for (uint32_t i = 0; i < ntok; i++) {
        RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d),
                                       rsrc.w_in_embd->data(tokens[i] * d),
                                       dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
    }

    // Attention scratch, sized for the largest request in the batch.
    // Products are widened to size_t before multiplying so long contexts
    // cannot overflow 32-bit arithmetic.
    size_t max_qk_size = 0;
    size_t max_seq_len = 0;
    for (uint32_t req = 0; req < nreq; req++) {
        size_t seq_len = req_lens[req];
        size_t total_len = size_t(req_pos[req]) + seq_len;
        max_qk_size = std::max(max_qk_size, seq_len * total_len);
        max_seq_len = std::max(max_seq_len, seq_len);
    }
    // Flat buffers: every request carves a contiguous, exactly-sized prefix
    // out of them, so requests of different lengths get matching shapes.
    auto qk_buf = Tensor::buffer(dt_logits, {nh * max_qk_size}, rsrc.memory_pool);
    auto rearrange_q_buf = Tensor::buffer(dt_logits, {nh * max_seq_len * dh}, rsrc.memory_pool);
    auto attn_val_buf = Tensor::buffer(dt_logits, {nh * max_seq_len * dh}, rsrc.memory_pool);

    // MLP buffers
    auto gate_buf = gate_up_buf->slice(1, 0, di);
    auto up_buf = gate_up_buf->slice(1, di, di);

    // Compute
    for (uint32_t layer = 0; layer < nlayer; layer++) {
        // 1. Attention
        // rms norm
        rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
        // qkv_proj
        linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr,
               has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
        // rope on the q heads and the k heads, in place
        rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
        rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);

        size_t token_offset = 0;
        for (uint32_t req = 0; req < nreq; req++) {
            auto past_len = req_pos[req];
            auto seq_len = req_lens[req];
            auto total_len = past_len + seq_len;
            auto o = o_buf->view({ntok, nh, dh})->slice({{0, token_offset, seq_len}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
            auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
            auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
            auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});

            // self attention
            // concat: append this step's k/v to the request's cache
            rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k);
            rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);

            // Request-sized views over the flat scratch buffers.
            auto q_flat = rearrange_q_buf->slice(0, 0, nh * seq_len * dh);
            auto qk_flat = qk_buf->slice(0, 0, nh * seq_len * total_len);
            auto attn_val_flat = attn_val_buf->slice(0, 0, nh * seq_len * dh);
            auto q_gemm = q_flat->view({nkvh, ngroup * seq_len, dh});
            auto qk_gemm = qk_flat->view({nkvh, ngroup * seq_len, total_len});
            auto attn_val_gemm = attn_val_flat->view({nkvh, ngroup * seq_len, dh});

            // qk
            rearrange(q_flat->view({nkvh, ngroup, seq_len, dh}), q);
            auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
            linear(qk_gemm, q_gemm, k_gemm, 1.0 / std::sqrt(static_cast<double>(dh)), 0.0, nullptr, nullptr);
            // softmax
            auto qk_softmax = qk_flat->view({nh, seq_len, total_len});
            causalSoftmax(qk_softmax, qk_softmax);
            auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
            linear(attn_val_gemm, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
            // rearrange attn val back into the token-major o_buf
            rearrange(o, attn_val_flat->view({nkvh, ngroup, seq_len, dh}));

            token_offset += seq_len;
        }

        // o_proj
        linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0,
               idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
        // All_reduce if distributed
        if (rsrc.comm != nullptr) {
            RUN_INFINI(infinicclAllReduce(
                logits_in->data(), logits_in->data(), ntok * d, dt_logits,
                INFINICCL_SUM, rsrc.comm, stream));
            RUN_INFINI(infinirtStreamSynchronize(stream));
        }

        // 2. FFN
        rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
        linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr, nullptr);
        swiglu(gate_buf, up_buf, gate_buf);
        linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0,
               idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
        // All_reduce if distributed
        if (rsrc.comm != nullptr) {
            RUN_INFINI(infinicclAllReduce(
                logits_in->data(), logits_in->data(), ntok * d, dt_logits,
                INFINICCL_SUM, rsrc.comm, stream));
            RUN_INFINI(infinirtStreamSynchronize(stream));
        }
    }

    // Sample and Output
    if (idev == 0) {
        // Normalize the last token of each request into row `req` of logits_out.
        size_t token_offset = 0;
        for (uint32_t req = 0; req < nreq; req++) {
            token_offset += req_lens[req];
            rmsnorm(logits_out->slice(0, req, 1),
                    logits_in->slice(0, token_offset - 1, 1),
                    rsrc.w_out_norm,
                    meta.epsilon);
        }
        linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);

        std::random_device _rd;
        std::mt19937 gen(_rd());
        for (uint32_t req = 0; req < nreq; req++) {
            float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
            // Each request samples from its own row of prob_buf and writes its
            // own element of result_buf.
            randomSample(result_buf->slice(0, req, 1)->view_as({}, {}),
                         prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}),
                         random_val, topp[req], topk[req], temperature[req]);
        }
        RUN_INFINI(infinirtStreamSynchronize(stream));
        RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
                                  sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
        for (uint32_t req = 0; req < nreq; req++) {
            output[req] = static_cast<uint32_t>(result_cpu[req]);
        }
    }
}

/// @brief Run one batched step on every device and block until all finish.
///
/// Publishes the request into `model->req`, sets `proceed` on each device's
/// state and wakes its worker, then waits until every worker has cleared
/// `proceed`. Sampled tokens are written to `output` by rank 0.
__C void
inferBatch(struct JiugeModel *model,
           const uint32_t *tokens, uint32_t ntok,
           const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
           struct KVCache **kv_caches,
           const float *temperature, const uint32_t *topk, const float *topp,
           uint32_t *output) {
    model->req.tokens = tokens;
    model->req.ntok = ntok;
    model->req.req_lens = req_lens;
    model->req.nreq = nreq;
    model->req.req_pos = req_pos;
    model->req.kv_caches = kv_caches;
    model->req.output = output;
    model->req.temperature = temperature;
    model->req.topk = topk;
    model->req.topp = topp;

    for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
        model->states[idev].proceed = true;
        lock.unlock();
        model->states[idev].cv_start.notify_one();
    }
    for (size_t i = model->dev_ids.size(); i > 0; i--) {
        auto idev = i - 1;
        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
        model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); });
        lock.unlock();
    }
}

/// @brief Worker-thread body for one device.
///
/// Creates the device resources, installs a thread-local InferenceContext,
/// signals `loaded`, then serves inferBatch requests until `exit_flag` is set,
/// and finally releases the device resources.
void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc,
                  InferState &state, InferRequest &req,
                  infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
    // Create Device Resource first: the context below copies rsrc->stream,
    // which only exists after createDeviceResource has run.
    createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);

    CacheManager cache_manager(256);
    InferenceContext ctx(rsrc, &cache_manager, rsrc->stream);
    // Set the inference context for this thread
    setInferenceContext(&ctx);

    {
        std::unique_lock<std::mutex> lock(state.mtx);
        state.loaded = true;
        lock.unlock();
        state.cv_load.notify_one();
    }

    // Infer Loop
    while (true) {
        std::unique_lock<std::mutex> lock(state.mtx);
        state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; });
        // quit if exit_flag is set
        if (state.exit_flag) {
            break;
        }

        inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok,
                         req.req_lens, req.nreq, req.req_pos, req.kv_caches,
                         req.temperature, req.topk, req.topp, req.output);

        state.proceed = false;
        lock.unlock();
        state.cv_done.notify_one();
    }

    // Clean-Up
    releaseDeviceResource(*rsrc);
    setInferenceContext(nullptr); // Clear the context when done
}

/// @brief Start one worker thread per device and wait until all have loaded.
///
/// A communicator group is created only when more than one device is used;
/// otherwise every rank receives a nullptr communicator.
JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights,
                       infiniDevice_t device_, std::vector<int> device_ids) : meta(*_meta) {
    int ndev = int(device_ids.size());
    device = device_;
    dev_ids = device_ids;
    dev_resources = std::vector<DeviceResource>(ndev);
    states = std::vector<InferState>(ndev);
    threads.resize(ndev);
    RUN_INFINI(infinirtInit());
    auto comms = std::vector<infinicclComm_t>(ndev, nullptr);
    if (ndev > 1) {
        RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data()));
    }

    for (int i = 0; i < ndev; i++) {
        threads[i] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[i],
                                 std::ref(states[i]), std::ref(req), device, i, ndev,
                                 dev_ids[i], comms[i]);
    }
    for (int i = 0; i < ndev; i++) {
        std::unique_lock<std::mutex> lock(states[i].mtx);
        states[i].cv_load.wait(lock, [&] { return states[i].loaded; });
        lock.unlock();
    }
}

/// @brief C entry point: construct a model over `ndev` devices listed in `dev_ids`.
/// @return Heap-allocated model; release it with destroyJiugeModel.
__C struct JiugeModel *
createJiugeModel(const JiugeMeta *meta,
                 const JiugeWeights *weights,
                 infiniDevice_t device,
                 int ndev,
                 const int *dev_ids) {
    std::vector<int> device_ids(dev_ids, dev_ids + ndev);
    JiugeModel *model = new JiugeModel(meta, weights, device, device_ids);
    return model;
}

/// @brief C entry point: stop all worker threads, join them and free the model.
/// Passing nullptr is a no-op.
__C void destroyJiugeModel(struct JiugeModel *model) {
    if (model == nullptr) {
        return;
    }
    auto ndev = model->dev_resources.size();

    for (size_t idev = 0; idev < ndev; idev++) {
        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
        model->states[idev].exit_flag = true;
        lock.unlock();
        model->states[idev].cv_start.notify_one();
    }

    for (size_t idev = 0; idev < ndev; idev++) {
        model->threads[idev].join();
    }

    delete model;
}