Unverified Commit 1ab1e668 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

fix: use new rope (#47)

parents a179dcc3 36890e40
...@@ -156,7 +156,6 @@ public: ...@@ -156,7 +156,6 @@ public:
DECLARE_OP_CACHE(RMSNorm) DECLARE_OP_CACHE(RMSNorm)
DECLARE_OP_CACHE(Gemm) DECLARE_OP_CACHE(Gemm)
DECLARE_OP_CACHE(RoPE) DECLARE_OP_CACHE(RoPE)
DECLARE_OP_CACHE(RoPEv2)
DECLARE_OP_CACHE(Rearrange) DECLARE_OP_CACHE(Rearrange)
DECLARE_OP_CACHE(CausalSoftmax) DECLARE_OP_CACHE(CausalSoftmax)
DECLARE_OP_CACHE(Topkrouter) DECLARE_OP_CACHE(Topkrouter)
...@@ -169,7 +168,6 @@ public: ...@@ -169,7 +168,6 @@ public:
RMSNorm_cache(capacity, DESTROY_FUNC(RMSNorm)), RMSNorm_cache(capacity, DESTROY_FUNC(RMSNorm)),
Gemm_cache(capacity, DESTROY_FUNC(Gemm)), Gemm_cache(capacity, DESTROY_FUNC(Gemm)),
RoPE_cache(capacity, DESTROY_FUNC(RoPE)), RoPE_cache(capacity, DESTROY_FUNC(RoPE)),
RoPEv2_cache(capacity, DESTROY_FUNC(RoPEv2)),
Rearrange_cache(capacity, DESTROY_FUNC(Rearrange)), Rearrange_cache(capacity, DESTROY_FUNC(Rearrange)),
CausalSoftmax_cache(capacity, DESTROY_FUNC(CausalSoftmax)), CausalSoftmax_cache(capacity, DESTROY_FUNC(CausalSoftmax)),
Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)), Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)),
......
...@@ -99,14 +99,16 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q, ...@@ -99,14 +99,16 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q,
std::shared_ptr<Tensor> k, std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> pos,
std::shared_ptr<Tensor> sin, std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos) { std::shared_ptr<Tensor> cos,
infiniopRoPEAlgo_t algo) {
size_t key = CacheManager::createDescriptorKey(q, k, pos, sin, cos); size_t key = CacheManager::createDescriptorKey(q, k, pos, sin, cos);
hash_combine(key, std::hash<int>()(algo));
infiniopRoPEDescriptor_t desc; infiniopRoPEDescriptor_t desc;
if (!cache_manager->getRoPEDescriptor(key, desc)) { if (!cache_manager->getRoPEDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateRoPEDescriptor( RUN_INFINI(infiniopCreateRoPEDescriptor(
op_handle, &desc, q->desc(), k->desc(), op_handle, &desc, q->desc(), k->desc(),
pos->desc(), sin->desc(), cos->desc())); pos->desc(), sin->desc(), cos->desc(), algo));
cache_manager->putRoPEDescriptor(key, desc); cache_manager->putRoPEDescriptor(key, desc);
} }
...@@ -121,32 +123,6 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q, ...@@ -121,32 +123,6 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q,
sin->data(), cos->data(), stream)); sin->data(), cos->data(), stream));
} }
void InferenceContext::rope_v2(std::shared_ptr<Tensor> q,
std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos,
std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos) {
size_t key = CacheManager::createDescriptorKey(q, k, pos, sin, cos);
infiniopRoPEv2Descriptor_t desc;
if (!cache_manager->getRoPEv2Descriptor(key, desc)) {
RUN_INFINI(infiniopCreateRoPEv2Descriptor(
op_handle, &desc, q->desc(), k->desc(),
pos->desc(), sin->desc(), cos->desc()));
cache_manager->putRoPEv2Descriptor(key, desc);
}
size_t workspace_size = 0;
RUN_INFINI(infiniopGetRoPEv2WorkspaceSize(desc, &workspace_size));
ensure_workspace(workspace_size);
void *workspace = workspace_storage->memory();
RUN_INFINI(infiniopRoPEv2(
desc, workspace, workspace_size,
q->data(), k->data(), pos->data(),
sin->data(), cos->data(), stream));
}
void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y, void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x) { std::shared_ptr<Tensor> x) {
size_t key = CacheManager::createDescriptorKey(y, x); size_t key = CacheManager::createDescriptorKey(y, x);
......
...@@ -33,12 +33,8 @@ struct InferenceContext { ...@@ -33,12 +33,8 @@ struct InferenceContext {
std::shared_ptr<Tensor> k, std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> pos,
std::shared_ptr<Tensor> sin, std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos); std::shared_ptr<Tensor> cos,
void rope_v2(std::shared_ptr<Tensor> q, infiniopRoPEAlgo_t algo);
std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos,
std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos);
void causalSoftmax(std::shared_ptr<Tensor> y, void causalSoftmax(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x); std::shared_ptr<Tensor> x);
...@@ -102,13 +98,13 @@ inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src) ...@@ -102,13 +98,13 @@ inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src)
inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k, inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin, std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos) { std::shared_ptr<Tensor> cos) {
getInferenceContext().rope(q, k, pos, sin, cos); getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_J);
} }
inline void rope_v2(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k, inline void rope_v2(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin, std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
std::shared_ptr<Tensor> cos) { std::shared_ptr<Tensor> cos) {
getInferenceContext().rope_v2(q, k, pos, sin, cos); getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_NEOX);
} }
inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) { inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
......
#ifndef INFINICORE_INFER_UTILS_H #ifndef INFINICORE_INFER_UTILS_H
#define INFINICORE_INFER_UTILS_H #define INFINICORE_INFER_UTILS_H
#include <infinicore.h> #include <infinirt.h>
#include <cstring> #include <cstring>
#include <iostream> #include <iostream>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment