Commit 176b0b71 authored by PanZezhong

refactor attention, use BSND layout

parent 51b1aade
@@ -119,6 +119,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     auto nlayer = meta.nlayer;
     auto nkvh = meta.nkvh / ndev;
     auto nh = meta.nh / ndev;
+    auto ngroup = nh / nkvh;
     // auto dctx = meta.dctx;
     auto dh = meta.dh;
     auto d = meta.d;
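Here `ngroup` is the grouped-query-attention group size: the number of query heads that share one KV head on each device. A standalone sketch of the arithmetic, with illustrative numbers that are not taken from any model config:

#include <cstddef>

int main() {
    // e.g. 32 query heads and 8 KV heads, sharded over 2 devices
    std::size_t nh = 32 / 2;        // query heads per device
    std::size_t nkvh = 8 / 2;       // KV heads per device
    std::size_t ngroup = nh / nkvh; // 4 query heads share each KV head
    return ngroup == 4 ? 0 : 1;
}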
@@ -208,30 +209,70 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc_rope_k, &temp_size));
     workspace_size = std::max(workspace_size, temp_size);
     // attention inner
-    auto desc_attns = std::vector<infiniopAttentionDescriptor_t>(nreq);
+    auto desc_kv_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
+    auto desc_q_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
+    auto desc_qk_gemms = std::vector<infiniopGemmDescriptor_t>(nreq);
+    auto desc_qk_softmaxs = std::vector<infiniopCausalSoftmaxDescriptor_t>(nreq);
+    auto desc_attn_v_gemms = std::vector<infiniopGemmDescriptor_t>(nreq);
+    auto desc_attn_v_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
     size_t token_offset = 0;
+    size_t max_qk_size = 0;
+    size_t max_seq_len = 0;
     o_buf->dimSplit(1, {nh, dh});
     for (uint32_t req = 0; req < nreq; req++) {
         auto past_len = req_pos[req];
         auto seq_len = req_lens[req];
+        auto total_len = past_len + seq_len;
         auto o = o_buf->slice({{0, token_offset, seq_len}});
-        auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}})
-                     ->permute({1, 0, 2});
-        auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}})
-                     ->permute({1, 0, 2});
-        auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}})
-                     ->permute({1, 0, 2});
-        auto k_cache = kv_caches[req]->k[idev][0];
-        auto v_cache = kv_caches[req]->v[idev][0];
-        RUN_INFINI(infiniopCreateAttentionDescriptor(
-            rsrc.handle, &desc_attns[req], o->desc()->get(), q->desc()->get(),
-            k->desc()->get(), v->desc()->get(), k_cache->desc()->get(),
-            v_cache->desc()->get(), past_len));
-        RUN_INFINI(
-            infiniopGetAttentionWorkspaceSize(desc_attns[req], &temp_size));
+        auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
+        auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
+        // auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
+        // kv cache tensors can share the same descriptor
+        // [nkvh, dh, total_len]
+        auto full_kv = kv_caches[req]->k[idev][0]->slice(0, 0, total_len)->permute({1, 2, 0});
+        auto cache_kv = kv_caches[req]->k[idev][0]->slice(0, past_len, seq_len);
+        RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_kv_rearranges[req],
+                                                     cache_kv->desc()->get(), k->desc()->get()));
+
+        // [nkvh, ngroup, seq_len, dh]
+        q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
+        auto q_t = TensorDesc::create(dt_logits, {nkvh, ngroup, seq_len, dh});
+        // [seq_len, nkvh, ngroup, dh] -> [nkvh, ngroup, seq_len, dh]
+        RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_q_rearranges[req],
+                                                     q_t->get(), q->desc()->get()));
+        // [nkvh, ngroup, seq_len, dh] -> [seq_len, nkvh, ngroup, dh]
+        auto attn_v_t = q_t;
+        auto attn_v = TensorDesc::createWithOrder(dt_logits, {nkvh, ngroup, seq_len, dh}, {1, 2, 0, 3});
+        RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_attn_v_rearranges[req],
+                                                     attn_v->get(), attn_v_t->get()));
+        q_t = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, dh});
+        auto qk = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, total_len});
+        max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len));
+        max_seq_len = std::max(max_seq_len, size_t(seq_len));
+        RUN_INFINI(infiniopCreateGemmDescriptor(
+            rsrc.handle, &desc_qk_gemms[req], qk->get(), q_t->get(), full_kv->desc()->get()));
+        RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_qk_gemms[req], &temp_size));
         workspace_size = std::max(workspace_size, temp_size);
+        // [nkvh, total_len, dh]
+        auto full_v = kv_caches[req]->v[idev][0]->slice(0, 0, total_len)->permute({1, 0, 2});
+        RUN_INFINI(infiniopCreateGemmDescriptor(
+            rsrc.handle, &desc_attn_v_gemms[req], q_t->get(), qk->get(), full_v->desc()->get()));
+        RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_v_gemms[req], &temp_size));
+        workspace_size = std::max(workspace_size, temp_size);
+        qk = TensorDesc::create(dt_logits, {nkvh * ngroup, seq_len, total_len});
+        RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor(
+            rsrc.handle, &desc_qk_softmaxs[req], qk->get(), qk->get()));
+        RUN_INFINI(infiniopGetCausalSoftmaxWorkspaceSize(desc_qk_softmaxs[req], &temp_size));
+        workspace_size = std::max(workspace_size, temp_size);
         token_offset += seq_len;
     }
+    auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, stream);
+    auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, stream);
+    auto attn_val_buf = Tensor::buffer(dt_logits, {nh, max_seq_len, dh}, stream);
     // MLP descriptors
     infiniopGemmDescriptor_t desc_ffn_gate_up, desc_ffn_down;
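The two batched GEMMs set up in this hunk implement grouped-query attention over the BSND cache: Q is regrouped to [nkvh, ngroup*seq_len, dh] and multiplied with the K cache viewed as [nkvh, dh, total_len], giving scores of shape [nkvh, ngroup*seq_len, total_len]; after the causal softmax, the scores multiply the V cache viewed as [nkvh, total_len, dh]. A minimal standalone check of that shape bookkeeping, with illustrative sizes that are not from any model config:

#include <cassert>
#include <cstddef>

int main() {
    // Illustrative sizes only.
    std::size_t nkvh = 4, ngroup = 8, seq_len = 16, past_len = 48, dh = 128;
    std::size_t total_len = past_len + seq_len;

    // GEMM 1: Q [nkvh, ngroup*seq_len, dh] x K^T view [nkvh, dh, total_len]
    //         -> scores [nkvh, ngroup*seq_len, total_len]
    // GEMM 2: scores x V [nkvh, total_len, dh] -> out [nkvh, ngroup*seq_len, dh]
    std::size_t scores = nkvh * (ngroup * seq_len) * total_len;

    // The causal-softmax descriptor views the very same buffer as
    // [nkvh*ngroup, seq_len, total_len]; the element count is identical.
    assert(scores == (nkvh * ngroup) * seq_len * total_len);

    // qk_buf is sized as {nh, max(seq_len * total_len)}, which covers both views.
    std::size_t nh = nkvh * ngroup;
    assert(scores == nh * (seq_len * total_len));
    return 0;
}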
@@ -313,17 +354,40 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     size_t token_offset = 0;
     for (uint32_t req = 0; req < nreq; req++) {
+        auto past_len = req_pos[req];
         auto seq_len = req_lens[req];
+        auto o = o_buf->slice({{0, token_offset, seq_len}});
+        auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
+        auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
+        auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
         // self attention
-        RUN_INFINI(infiniopAttention(
-            desc_attns[req], workspace, workspace_size,
-            o_buf->data(token_offset * nh * dh),
-            qkv_buf->data(token_offset * (nh + nkvh * 2) * dh),
-            qkv_buf->data(token_offset * (nh + nkvh * 2) * dh + nh * dh),
-            qkv_buf->data(token_offset * (nh + nkvh * 2) * dh + (nh + nkvh) * dh),
-            kv_caches[req]->k[idev][layer]->data(),
-            kv_caches[req]->v[idev][layer]->data(),
-            stream));
+        // concat: append this step's K and V at position past_len of the cache
+        RUN_INFINI(infiniopRearrange(
+            desc_kv_rearranges[req],
+            kv_caches[req]->k[idev][layer]->data(past_len * nkvh * dh),
+            k->data(), stream));
+        RUN_INFINI(infiniopRearrange(
+            desc_kv_rearranges[req],
+            kv_caches[req]->v[idev][layer]->data(past_len * nkvh * dh),
+            v->data(), stream));
+        // qk: regroup Q, then scores = (Q K^T) / sqrt(dh)
+        RUN_INFINI(infiniopRearrange(desc_q_rearranges[req], rearrange_q_buf->data(), q->data(), stream));
+        RUN_INFINI(infiniopGemm(
+            desc_qk_gemms[req], workspace, workspace_size,
+            qk_buf->data(), rearrange_q_buf->data(),
+            kv_caches[req]->k[idev][layer]->data(), 1. / sqrt(dh), 0.0, stream));
+        // softmax (causal, in place)
+        RUN_INFINI(infiniopCausalSoftmax(
+            desc_qk_softmaxs[req], workspace, workspace_size,
+            qk_buf->data(), qk_buf->data(), stream));
+        // attn val = scores x V
+        RUN_INFINI(infiniopGemm(
+            desc_attn_v_gemms[req], workspace, workspace_size,
+            attn_val_buf->data(), qk_buf->data(),
+            kv_caches[req]->v[idev][layer]->data(), 1.0, 0.0, stream));
+        // rearrange attn val back to [seq_len, nh, dh]; o is already the
+        // token_offset slice of o_buf, so no extra offset is applied here
+        RUN_INFINI(infiniopRearrange(
+            desc_attn_v_rearranges[req],
+            o->data(),
+            attn_val_buf->data(), stream));
         token_offset += seq_len;
     }
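Taken together, the five kernels above compute standard causal attention per request and per KV head group; the 1/sqrt(dh) scale is folded into the alpha argument of the first GEMM rather than run as a separate kernel. As a reference formula (not code from the commit):

O = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_h}} + M\right) V,
\qquad
M_{ij} =
\begin{cases}
0 & j \le i + \text{past\_len},\\
-\infty & \text{otherwise.}
\end{cases}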
@@ -407,13 +471,20 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     // Clean up
     infiniopDestroyRMSNormDescriptor(desc_norm);
-    infiniopDestroyRearrangeDescriptor(desc_qkv_bias);
+    if (has_qkv_bias) {
+        infiniopDestroyRearrangeDescriptor(desc_qkv_bias);
+    }
     infiniopDestroyGemmDescriptor(desc_attn_qkv);
     infiniopDestroyGemmDescriptor(desc_attn_o);
     infiniopDestroyRoPEDescriptor(desc_rope_q);
     infiniopDestroyRoPEDescriptor(desc_rope_k);
     for (uint32_t req = 0; req < nreq; req++) {
-        infiniopDestroyAttentionDescriptor(desc_attns[req]);
+        infiniopDestroyRearrangeDescriptor(desc_kv_rearranges[req]);
+        infiniopDestroyRearrangeDescriptor(desc_q_rearranges[req]);
+        infiniopDestroyGemmDescriptor(desc_qk_gemms[req]);
+        infiniopDestroyCausalSoftmaxDescriptor(desc_qk_softmaxs[req]);
+        infiniopDestroyGemmDescriptor(desc_attn_v_gemms[req]);
+        infiniopDestroyRearrangeDescriptor(desc_attn_v_rearranges[req]);
     }
     infiniopDestroyGemmDescriptor(desc_ffn_gate_up);
     infiniopDestroySwiGLUDescriptor(desc_swiglu);
...
@@ -6,7 +6,7 @@ __C struct KVCache *createKVCache(const JiugeModel *model) {
     auto nkvh = model->meta.nkvh / ndev;
     auto max_len = model->meta.dctx;
     auto dh = model->meta.dh;
-    auto shape = std::vector<size_t>{nkvh, max_len, dh};
+    auto shape = std::vector<size_t>{max_len, nkvh, dh};
     for (unsigned int idev = 0; idev < ndev; idev++) {
         RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
         auto kcache = std::vector<std::shared_ptr<Tensor>>();
@@ -27,18 +27,20 @@ __C struct KVCache *duplicateKVCache(const JiugeModel *model,
                                      unsigned int seq_len) {
     auto new_kv_cache = createKVCache(model);
     auto ndev = model->dev_resources.size();
+    auto nkvh = model->meta.nkvh / ndev;
+    auto dh = model->meta.dh;
+    auto dt_size = dsize(model->meta.dt_logits);
     for (unsigned int idev = 0; idev < ndev; idev++) {
         RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
         for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) {
-            new_kv_cache->k[idev][layer]
-                ->slice(1, 0, seq_len)
-                ->copyFrom(kv_cache->k[idev][layer]->slice(1, 0, seq_len),
-                           model->dev_resources[idev].handle);
-            new_kv_cache->v[idev][layer]
-                ->slice(1, 0, seq_len)
-                ->copyFrom(kv_cache->v[idev][layer]->slice(1, 0, seq_len),
-                           model->dev_resources[idev].handle);
+            RUN_INFINI(infinirtMemcpy(new_kv_cache->k[idev][layer]->data(),
+                                      kv_cache->k[idev][layer]->data(),
+                                      seq_len * nkvh * dh * dt_size,
+                                      INFINIRT_MEMCPY_D2D));
+            RUN_INFINI(infinirtMemcpy(new_kv_cache->v[idev][layer]->data(),
+                                      kv_cache->v[idev][layer]->data(),
+                                      seq_len * nkvh * dh * dt_size,
+                                      INFINIRT_MEMCPY_D2D));
         }
     }
     return new_kv_cache;
...
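The cache layout change above is the point of the commit: with shape [max_len, nkvh, dh], the sequence dimension is outermost, so the first seq_len positions of each layer's cache form one contiguous block, and a flat device-to-device memcpy can replace the old strided, per-dimension copyFrom. A sketch of the size arithmetic, with illustrative values:

#include <cstddef>

int main() {
    // Illustrative values, not from any model config.
    std::size_t seq_len = 100, nkvh = 4, dh = 128, dt_size = 2; // fp16
    // Old layout [nkvh, max_len, dh]: the first seq_len positions are nkvh
    // separate strips, each max_len*dh elements apart.
    // New layout [max_len, nkvh, dh]: one block of seq_len*nkvh*dh elements.
    std::size_t bytes = seq_len * nkvh * dh * dt_size;
    return bytes == 102400 ? 0 : 1;
}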
@@ -25,6 +25,15 @@ struct SliceParams {
     size_t len;
 };

+template <typename... Args>
+std::vector<size_t> __shape(Args... args) {
+    return std::vector<size_t>{static_cast<size_t>(args)...};
+}
+
+template <typename... Args>
+std::vector<ptrdiff_t> __strides(Args... args) {
+    return std::vector<ptrdiff_t>{static_cast<ptrdiff_t>(args)...};
+}
+
 class TensorDesc {
 private:
     infiniopTensorDescriptor_t _desc;
@@ -33,6 +42,11 @@ public:
     static std::shared_ptr<TensorDesc>
     create(infiniDtype_t dtype, const std::vector<size_t> &shape,
            const std::vector<ptrdiff_t> &strides);
+    static std::shared_ptr<TensorDesc>
+    create(infiniDtype_t dtype, const std::vector<size_t> &shape);
+    static std::shared_ptr<TensorDesc>
+    createWithOrder(infiniDtype_t dtype, const std::vector<size_t> &shape,
+                    const std::vector<size_t> &order);
     infiniopTensorDescriptor_t get() const { return _desc; };
     ~TensorDesc();
 };
@@ -58,6 +72,8 @@ public:
     static std::shared_ptr<Tensor> weight(void *host_data,
                                           infiniDtype_t dtype,
                                           const std::vector<size_t> &shape);
+    std::shared_ptr<Tensor> memShare(const std::vector<size_t> &shape,
+                                     infiniDtype_t dtype = INFINI_DTYPE_INVALID) const;
     std::shared_ptr<Tensor> slice(size_t dim, size_t start, size_t len);
     std::shared_ptr<Tensor const> slice(size_t dim, size_t start,
                                         size_t len) const;
...
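The header above adds two descriptor factories (contiguous strides, and strides derived from a memory-order permutation), variadic shape/stride builders, and Tensor::memShare. A hypothetical usage sketch against this header; the dtype name is assumed from the InfiniCore enum, and the sizes are illustrative:

#include "../tensor.hpp"

void desc_examples() {
    // contiguous row-major descriptor for a [4, 8, 16, 128] fp16 tensor
    auto a = TensorDesc::create(INFINI_DTYPE_F16, {4, 8, 16, 128});
    // same logical shape, but stored in memory as [dim1, dim2, dim0, dim3]
    auto b = TensorDesc::createWithOrder(INFINI_DTYPE_F16, {4, 8, 16, 128},
                                         {1, 2, 0, 3});
    // variadic builders for ad-hoc shapes/strides
    auto shp = __shape(4, 8, 16, 128);
    auto st = __strides(8 * 16 * 128, 16 * 128, 128, 1);
    auto c = TensorDesc::create(INFINI_DTYPE_F16, shp, st);
    (void)a; (void)b; (void)c;
}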
#include "../tensor.hpp" #include "../tensor.hpp"
#include "../utils.hpp" #include "../utils.hpp"
#include <algorithm>
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -14,6 +15,39 @@ TensorDesc::create(infiniDtype_t dtype, const std::vector<size_t> &shape, ...@@ -14,6 +15,39 @@ TensorDesc::create(infiniDtype_t dtype, const std::vector<size_t> &shape,
return desc; return desc;
} }
std::shared_ptr<TensorDesc>
TensorDesc::create(infiniDtype_t dtype, const std::vector<size_t> &shape) {
auto ndim = shape.size();
auto strides = std::vector<ptrdiff_t>(ndim);
if (ndim > 0) {
strides[ndim - 1] = 1;
for (int i = ndim - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * shape[i + 1];
}
}
return create(dtype, shape, strides);
}
+std::shared_ptr<TensorDesc>
+TensorDesc::createWithOrder(infiniDtype_t dtype, const std::vector<size_t> &shape,
+                            const std::vector<size_t> &order) {
+    // order[i] is the memory position of logical dimension i: the dimension
+    // whose order value is ndim-1 is innermost (stride 1), and strides grow
+    // outward from there.
+    ASSERT_EQ(shape.size(), order.size());
+    auto ndim = shape.size();
+    if (ndim == 0) {
+        return create(dtype, shape);
+    }
+    auto strides = std::vector<ptrdiff_t>(order.size());
+    auto idx = std::find(order.begin(), order.end(), size_t(ndim - 1));
+    strides[std::distance(order.begin(), idx)] = 1;
+    for (int i = (int)ndim - 2; i >= 0; i--) {
+        auto prev_dim = shape[std::distance(order.begin(), idx)];
+        auto prev_stride = strides[std::distance(order.begin(), idx)];
+        idx = std::find(order.begin(), order.end(), size_t(i));
+        strides[std::distance(order.begin(), idx)] = prev_stride * prev_dim;
+    }
+    return create(dtype, shape, strides);
+}
+
 TensorDesc::~TensorDesc() {
     infiniopDestroyTensorDescriptor(this->_desc);
 }
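Both helpers compute strides in elements, not bytes. A standalone check of the two stride rules with illustrative shapes; the ordered case re-implements the loop from createWithOrder above:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    // Contiguous (row-major) strides: shape {2, 3, 4} -> {12, 4, 1}.
    std::vector<size_t> shape{2, 3, 4};
    std::vector<ptrdiff_t> strides(3, 1);
    for (int i = 1; i >= 0; i--)
        strides[i] = strides[i + 1] * (ptrdiff_t)shape[i + 1];
    assert((strides == std::vector<ptrdiff_t>{12, 4, 1}));

    // Ordered strides: shape {2, 3, 4, 5} with order {1, 2, 0, 3} describes
    // memory laid out as [dim2, dim0, dim1, dim3] = [4, 2, 3, 5],
    // so the expected strides are {3*5, 5, 2*3*5, 1} = {15, 5, 30, 1}.
    std::vector<size_t> shp{2, 3, 4, 5}, order{1, 2, 0, 3};
    int ndim = (int)shp.size();
    std::vector<ptrdiff_t> st(ndim);
    auto idx = std::find(order.begin(), order.end(), size_t(ndim - 1));
    st[std::distance(order.begin(), idx)] = 1;
    for (int i = ndim - 2; i >= 0; i--) {
        auto prev_dim = shp[std::distance(order.begin(), idx)];
        auto prev_stride = st[std::distance(order.begin(), idx)];
        idx = std::find(order.begin(), order.end(), size_t(i));
        st[std::distance(order.begin(), idx)] = prev_stride * (ptrdiff_t)prev_dim;
    }
    assert((st == std::vector<ptrdiff_t>{15, 5, 30, 1}));
    return 0;
}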
@@ -60,7 +94,6 @@ std::shared_ptr<Tensor> Tensor::buffer(infiniDtype_t dtype,
 std::shared_ptr<Tensor> Tensor::weight(void *data, infiniDtype_t dtype,
                                        const std::vector<size_t> &shape) {
     std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
-    ;
     tensor->_dtype = dtype;
     auto ndim = shape.size();
     tensor->_shape = std::vector<size_t>(shape);
@@ -83,6 +116,29 @@ std::shared_ptr<Tensor> Tensor::weight(void *data, infiniDtype_t dtype,
     return tensor;
 }
+std::shared_ptr<Tensor> Tensor::memShare(const std::vector<size_t> &shape, infiniDtype_t dtype) const {
+    // resolve the effective dtype first; the default INFINI_DTYPE_INVALID
+    // means "keep this tensor's dtype"
+    auto dt = dtype == INFINI_DTYPE_INVALID ? this->_dtype : dtype;
+    size_t size = std::accumulate(shape.begin(), shape.end(), dsize(dt), std::multiplies<size_t>());
+    ASSERT(size <= this->_storage->size);
+    std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
+    tensor->_dtype = dt;
+    tensor->_shape = std::vector<size_t>(shape);
+    auto ndim = shape.size();
+    auto strides = std::vector<ptrdiff_t>(ndim);
+    if (ndim > 0) {
+        strides[ndim - 1] = 1;
+        for (int i = (int)ndim - 2; i >= 0; i--) {
+            strides[i] = strides[i + 1] * shape[i + 1];
+        }
+    }
+    tensor->_strides = strides;
+    tensor->_storage = this->_storage;
+    infiniopCreateTensorDescriptor(&tensor->_desc, ndim, tensor->_shape.data(),
+                                   tensor->_strides.data(), tensor->_dtype);
+    tensor->_offset = 0;
+    return tensor;
+}
 void *Tensor::dataImpl(ptrdiff_t offset) const {
     return (char *)(this->_data) + offset * dsize(this->dtype());
 }
...
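Tensor::memShare creates a new contiguous view over an existing tensor's storage without copying, optionally reinterpreting the dtype; the ASSERT guards against the view outgrowing the backing allocation. A hypothetical call, with names, sizes, and the f32 dtype constant assumed for illustration:

#include "../tensor.hpp"

void memshare_example(std::shared_ptr<Tensor> big_buf) {
    // view the same storage as a [4, 16, 128] tensor, keeping big_buf's dtype
    auto view = big_buf->memShare({4, 16, 128});
    // reinterpret part of the same storage as 1024 f32 elements
    auto as_f32 = big_buf->memShare({1024}, INFINI_DTYPE_F32);
    (void)view; (void)as_f32;
}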