Commit 79bc0438 authored by wooway777's avatar wooway777
Browse files

issue/21 - replaced most dim operations and added a linear layer

parent 366386d3
......@@ -35,6 +35,7 @@ inline size_t computeTensorDescHash(std::shared_ptr<Tensor> tensor) {
}
enum class OperatorType {
ADD,
RMS_NORM,
GEMM,
ROPE,
......@@ -66,6 +67,9 @@ private:
void destroyDescriptor(DescriptorType &desc) {
switch (opType) {
case OperatorType::ADD:
infiniopDestroyAddDescriptor(desc);
break;
case OperatorType::RMS_NORM:
infiniopDestroyRMSNormDescriptor(desc);
break;
......@@ -178,6 +182,7 @@ class CacheManager {
private:
const size_t DEFAULT_CACHE_CAPACITY = 128;
LRUDescriptorCache<infiniopAddDescriptor_t> add_cache;
LRUDescriptorCache<infiniopRMSNormDescriptor_t> rms_norm_cache;
LRUDescriptorCache<infiniopGemmDescriptor_t> gemm_cache;
LRUDescriptorCache<infiniopRoPEDescriptor_t> rope_cache;
......@@ -187,7 +192,8 @@ private:
LRUDescriptorCache<infiniopRandomSampleDescriptor_t> random_sample_cache;
public:
CacheManager(size_t capacity = 100) : rms_norm_cache(capacity, OperatorType::RMS_NORM),
CacheManager(size_t capacity = 100) : add_cache(capacity, OperatorType::ADD),
rms_norm_cache(capacity, OperatorType::RMS_NORM),
gemm_cache(capacity, OperatorType::GEMM),
rope_cache(capacity, OperatorType::ROPE),
rearrange_cache(capacity, OperatorType::REARRANGE),
......@@ -195,6 +201,15 @@ public:
swiglu_cache(capacity, OperatorType::SWIGLU),
random_sample_cache(capacity, OperatorType::RANDOM_SAMPLE) {}
// Add operations
// Looks up a cached Add operator descriptor by its tensor-layout hash key.
// On a hit, writes the descriptor into `desc` and returns true; returns
// false on a miss (caller is expected to create and put the descriptor).
bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) {
return add_cache.get(key, desc);
}
// Inserts an Add operator descriptor into the LRU cache under `key`.
// Eviction/destruction of displaced descriptors is handled by the cache itself.
void putAddDescriptor(size_t key, const infiniopAddDescriptor_t &desc) {
add_cache.put(key, desc);
}
// RMSNorm operations
bool getRMSNormDescriptor(size_t key, infiniopRMSNormDescriptor_t &desc) {
return rms_norm_cache.get(key, desc);
......
......@@ -12,6 +12,28 @@ void InferenceContext::ensure_workspace(size_t required_size) {
}
}
void InferenceContext::add(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b) {
size_t key = CacheManager::createDescriptorKey(c, a, b,
nullptr, nullptr);
infiniopAddDescriptor_t desc;
if (!cache_manager->getAddDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateAddDescriptor(rsrc->handle, &desc, c->desc(), a->desc(), b->desc()));
cache_manager->putAddDescriptor(key, desc);
}
size_t workspace_size = 0;
RUN_INFINI(infiniopGetAddWorkspaceSize(desc, &workspace_size));
ensure_workspace(workspace_size);
void *workspace = workspace_storage->memory();
RUN_INFINI(infiniopAdd(
desc, workspace, workspace_size,
c->data(), a->data(), b->data(), stream));
}
void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w,
......@@ -165,3 +187,27 @@ void InferenceContext::randomSample(std::shared_ptr<Tensor> out,
random_val, top_p, top_k, temperature,
stream));
}
void InferenceContext::linear(std::shared_ptr<Tensor> c,
                              std::shared_ptr<Tensor> a,
                              std::shared_ptr<Tensor> b,
                              float alpha, float beta,
                              std::shared_ptr<Tensor> residual) {
    // Computes c = alpha * (a @ b) + beta * c, then adds `residual` if given.
    // Handles the case where `residual` aliases `c` (same underlying buffer).
    if (!residual) {
        gemm(c, a, b, alpha, beta);
        return;
    }

    const bool residual_is_c = (residual->data() == c->data());
    if (!residual_is_c) {
        // Distinct residual buffer: plain GEMM then an out-of-place add.
        gemm(c, a, b, alpha, beta);
        add(c, c, residual);
        return;
    }

    if (beta == 0.0) {
        // Residual is c itself and beta contributes nothing extra, so the
        // residual add folds directly into the GEMM accumulate (beta = 1).
        gemm(c, a, b, alpha, 1.0);
    } else {
        // c's old value is needed twice (scaled by beta inside the GEMM and
        // once more as the residual), so snapshot it before overwriting.
        auto snapshot = Tensor::buffer(c->dtype(), c->shape(), rsrc->memory_pool);
        snapshot->copyFrom(c, rsrc->handle, stream);
        gemm(c, a, b, alpha, beta);
        add(c, c, snapshot);
    }
}
......@@ -16,6 +16,9 @@ struct InferenceContext {
void ensure_workspace(size_t required_size);
void add(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b);
void rmsnorm(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w,
......@@ -39,4 +42,10 @@ struct InferenceContext {
void randomSample(std::shared_ptr<Tensor> out,
std::shared_ptr<Tensor> prob,
float random_val, float top_p, uint32_t top_k, float temperature);
void linear(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b,
float alpha, float beta,
std::shared_ptr<Tensor> residual);
};
......@@ -166,7 +166,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
// Attention
qkv_buf->dimSplit(1, {nh + nkvh * 2, dh}); // (ntok, nh + 2 * nkvh, dh)
// attention inner
size_t max_qk_size = 0;
size_t max_seq_len = 0;
......@@ -194,56 +193,49 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// rms norm
ctx.rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
// qkv_proj
qkv_buf->dimMerge(1, 2);
if (has_qkv_bias) {
ctx.rearrange(qkv_buf, rsrc.b_attn_qkv[layer]->reDesc({ntok, (nh + nkvh * 2) * dh}, {0, 1}));
ctx.rearrange(qkv_buf, rsrc.b_attn_qkv[layer]->view({ntok, (nh + nkvh * 2) * dh}, {0, 1}));
}
ctx.gemm(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, has_qkv_bias ? 1.0 : 0.0);
ctx.linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, has_qkv_bias ? qkv_buf : nullptr);
// rope
qkv_buf->dimSplit(1, {nh + nkvh * 2, dh});
ctx.rope(qkv_buf->slice(1, 0, nh), qkv_buf->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
ctx.rope(qkv_buf->slice(1, nh, nkvh), qkv_buf->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
auto qkv_rope = qkv_buf->viewReshaped({ntok, nh + nkvh * 2, dh});
ctx.rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
ctx.rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
size_t token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto past_len = req_pos[req];
auto seq_len = req_lens[req];
auto total_len = past_len + seq_len;
auto o = o_buf->dimSplit(1, {nh, dh})->slice({{0, token_offset, seq_len}});
auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
auto o = o_buf->viewReshaped({ntok, nh, dh})->slice({{0, token_offset, seq_len}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
// self attention
// concat
ctx.rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k);
ctx.rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk
ctx.rearrange(rearrange_q_buf->dimSplit(1, {ngroup, seq_len}),
q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3}));
qk_buf->dimSplit(1, {seq_len, total_len});
qk_buf->dimSplit(0, {nkvh, ngroup});
qk_buf->dimMerge(1, 2);
ctx.gemm(qk_buf, rearrange_q_buf->dimMerge(1, 2), kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}), 1. / sqrt(dh), 0.0);
auto q_rearrange = rearrange_q_buf->viewReshaped({nkvh, ngroup, seq_len, dh});
ctx.rearrange(q_rearrange, q);
auto qk_gemm = qk_buf->viewReshaped({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
ctx.linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr);
// softmax
qk_buf->dimSplit(1, {ngroup, seq_len});
qk_buf->dimMerge(0, 1);
ctx.causalSoftmax(qk_buf, qk_buf);
qk_buf->dimSplit(0, {nkvh, ngroup});
qk_buf->dimMerge(1, 2);
ctx.gemm(attn_val_buf, qk_buf, kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}), 1.0, 0.0);
qk_buf->dimSplit(1, {ngroup, seq_len});
qk_buf->dimMerge(2, 3);
qk_buf->dimMerge(0, 1);
auto qk_softmax = qk_buf->viewReshaped({nh, seq_len, total_len});
ctx.causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
ctx.linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr);
// rearrange attn val
attn_val_buf->dimSplit(0, {nkvh, ngroup});
ctx.rearrange(o->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3}), attn_val_buf);
attn_val_buf->dimMerge(0, 1);
auto attn_val_gemm = attn_val_buf->viewReshaped({nkvh, ngroup, max_seq_len, dh});
ctx.rearrange(o, attn_val_gemm);
token_offset += seq_len;
}
// o_proj
ctx.gemm(logits_in, o_buf->dimMerge(1, 2), rsrc.w_attn_out[layer], 1.0, idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
ctx.linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual
// All_reduce if distributed
if (rsrc.comm != nullptr) {
......@@ -253,11 +245,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI(infinirtStreamSynchronize(stream));
}
// 2. FFN
// rms_norm
ctx.rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
ctx.gemm(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0);
ctx.linear(gate_up_buf, logits_out, rsrc.w_ffn_gate_up[layer], 1.0, 0.0, nullptr);
ctx.swiglu(gate_buf, up_buf, gate_buf);
ctx.gemm(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
ctx.linear(logits_in, gate_buf, rsrc.w_ffn_down[layer], 1.0, 0.0, idev == 0 ? logits_in : nullptr); // only rank 0 adds residual
// All_reduce if distributed
if (rsrc.comm != nullptr) {
......@@ -278,15 +269,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rsrc.w_out_norm,
meta.epsilon);
}
ctx.gemm(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0);
ctx.linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr);
std::random_device _rd;
std::mt19937 gen(_rd());
token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req];
float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
ctx.randomSample(result_buf->reDesc({}, {}),
prob_buf->reDesc({dvoc}, {1}),
ctx.randomSample(result_buf->view({}, {}),
prob_buf->view({dvoc}, {1}),
random_val, topp[req], topk[req], temperature[req]);
token_offset += seq_len;
}
......
......@@ -78,7 +78,6 @@ public:
void dimMerge(size_t dim_start, size_t dim_end);
void dimSplit(size_t dim, const std::vector<size_t> &dims);
void permute(const std::vector<size_t> &order);
void reDesc(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides);
};
class Tensor : public std::enable_shared_from_this<Tensor> {
......@@ -111,7 +110,6 @@ public:
std::shared_ptr<Tensor> dimSplit(size_t dim,
const std::vector<size_t> &dims);
std::shared_ptr<Tensor> permute(const std::vector<size_t> &order);
std::shared_ptr<Tensor> reDesc(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides);
void *data(ptrdiff_t offset = 0);
void const *data(ptrdiff_t offset = 0) const;
void copyFrom(std::shared_ptr<Tensor const> src, infiniopHandle_t handle,
......@@ -130,6 +128,10 @@ public:
void debug() const;
std::string info() const;
std::shared_ptr<Tensor> view() const;
std::shared_ptr<Tensor> view(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) const;
std::shared_ptr<Tensor> viewReshaped(const std::vector<size_t> new_shape) const;
~Tensor();
};
......
......@@ -258,6 +258,49 @@ std::string Tensor::info() const {
return this->_desc->info();
}
std::shared_ptr<Tensor> Tensor::view() const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), this->shape(), this->strides());
tensor->_offset = this->_offset;
return tensor;
}
std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
tensor->_offset = this->_offset;
return tensor;
}
std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> new_shape) const {
    // Returns an alias of this tensor reinterpreted with `new_shape`.
    // Implemented as a merge of all current dimensions followed by a split
    // into the target dimensions, so it inherits whatever layout constraints
    // dimMerge/dimSplit enforce (e.g. mergeable strides).
    //
    // BUGFIX: use size_t{1} as the accumulate seed. The seed's type drives the
    // accumulation type, so a plain int seed computed the element-count
    // product in int, which can overflow/truncate for large tensors and make
    // the ASSERT_EQ below pass or fail incorrectly.
    const size_t current_elements =
        std::accumulate(_desc->shape().begin(), _desc->shape().end(),
                        size_t{1}, std::multiplies<size_t>());
    const size_t new_elements =
        std::accumulate(new_shape.begin(), new_shape.end(),
                        size_t{1}, std::multiplies<size_t>());
    // A reshape must preserve the total number of elements.
    ASSERT_EQ(current_elements, new_elements);

    auto current_shape = _desc->shape();
    // Work on a shallow alias so this tensor's descriptor is untouched.
    auto result = this->view();
    // Step 1: collapse everything into a single dimension (if not already 1-D).
    if (current_shape.size() > 1) {
        result = result->dimMerge(0, current_shape.size() - 1);
    }
    // Step 2: carve the flat dimension into the requested shape.
    if (new_shape.size() > 1) {
        result = result->dimSplit(0, new_shape);
    }
    return result;
}
void Tensor::debug(const std::string &filename) const {
RUN_INFINI(infinirtDeviceSynchronize());
......
......@@ -114,14 +114,3 @@ std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
this->_desc->permute(order);
return shared_from_this();
}
void TensorDesc::reDesc(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) {
this->_shape = new_shape;
this->_strides = new_strides;
this->resetDesc();
}
std::shared_ptr<Tensor> Tensor::reDesc(const std::vector<size_t> new_shape, const std::vector<ptrdiff_t> new_strides) {
this->_desc->reDesc(new_shape, new_strides);
return shared_from_this();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment