Commit 2a2ddc57 authored by wooway777
Browse files

issue/21 - Improved Descriptor Destruction Logic

parent ebd35e7e
......@@ -34,15 +34,23 @@ inline size_t computeTensorDescHash(std::shared_ptr<Tensor> tensor) {
return seed;
}
enum class OperatorType {
ADD,
RMS_NORM,
GEMM,
ROPE,
REARRANGE,
CAUSAL_SOFTMAX,
SWIGLU,
RANDOM_SAMPLE
/// Type-erased interface for releasing a descriptor.
///
/// Implementations receive the address of the descriptor object as an untyped
/// pointer and are responsible for interpreting it as the concrete type.
class IDescriptorDestroyer {
public:
    /// Releases the descriptor whose address is given.
    /// @param descriptor  Untyped pointer to the descriptor object to release.
    virtual void destroy(void *descriptor) = 0;

    /// Virtual so that an implementation can be deleted through a pointer to
    /// this interface.
    virtual ~IDescriptorDestroyer() = default;
};
template <typename DescriptorType>
class DescriptorDestroyer : public IDescriptorDestroyer {
using DestroyFunc = infiniStatus_t (*)(DescriptorType);
DestroyFunc destroyFunc;
public:
DescriptorDestroyer(DestroyFunc func) : destroyFunc(func) {}
void destroy(void *descriptor) override {
destroyFunc(*static_cast<DescriptorType *>(descriptor));
}
};
template <typename DescriptorType>
......@@ -63,43 +71,14 @@ private:
CacheNode *tail;
const size_t capacity;
size_t size;
const OperatorType opType;
// Releases `desc` through the infiniop destroy routine that corresponds to
// this cache's opType.
// Throws std::runtime_error when opType matches none of the handled operators.
void destroyDescriptor(DescriptorType &desc) {
    switch (opType) {
    case OperatorType::ADD:
        infiniopDestroyAddDescriptor(desc);
        return;
    case OperatorType::RMS_NORM:
        infiniopDestroyRMSNormDescriptor(desc);
        return;
    case OperatorType::GEMM:
        infiniopDestroyGemmDescriptor(desc);
        return;
    case OperatorType::ROPE:
        infiniopDestroyRoPEDescriptor(desc);
        return;
    case OperatorType::REARRANGE:
        infiniopDestroyRearrangeDescriptor(desc);
        return;
    case OperatorType::CAUSAL_SOFTMAX:
        infiniopDestroyCausalSoftmaxDescriptor(desc);
        return;
    case OperatorType::SWIGLU:
        infiniopDestroySwiGLUDescriptor(desc);
        return;
    case OperatorType::RANDOM_SAMPLE:
        infiniopDestroyRandomSampleDescriptor(desc);
        return;
    }
    // Reached only when opType holds a value outside the cases above.
    throw std::runtime_error("Unknown descriptor type");
}
std::unique_ptr<IDescriptorDestroyer> destroyer;
void removeNode(CacheNode *node) {
node->prev->next = node->next;
node->next->prev = node->prev;
destroyDescriptor(node->desc);
if (destroyer) {
destroyer->destroy(&node->desc);
}
cache.erase(node->key);
delete node;
--size;
......@@ -126,7 +105,9 @@ private:
}
public:
LRUDescriptorCache(size_t c, OperatorType t) : capacity(c), size(0), opType(t) {
template <typename DestroyFunc>
LRUDescriptorCache(size_t c, DestroyFunc destroyFunc)
: capacity(c), size(0), destroyer(std::make_unique<DescriptorDestroyer<DescriptorType>>(destroyFunc)) {
head = new CacheNode();
tail = new CacheNode();
head->next = tail;
......@@ -158,7 +139,9 @@ public:
if (it != cache.end()) {
// Key already exists, update the descriptor
CacheNode *node = it->second;
destroyDescriptor(node->desc);
if (destroyer) {
destroyer->destroy(&node->desc);
}
node->desc = descriptor;
moveToTop(node);
return;
......@@ -192,14 +175,15 @@ private:
LRUDescriptorCache<infiniopRandomSampleDescriptor_t> random_sample_cache;
public:
CacheManager(size_t capacity = 100) : add_cache(capacity, OperatorType::ADD),
rms_norm_cache(capacity, OperatorType::RMS_NORM),
gemm_cache(capacity, OperatorType::GEMM),
rope_cache(capacity, OperatorType::ROPE),
rearrange_cache(capacity, OperatorType::REARRANGE),
causal_softmax_cache(capacity, OperatorType::CAUSAL_SOFTMAX),
swiglu_cache(capacity, OperatorType::SWIGLU),
random_sample_cache(capacity, OperatorType::RANDOM_SAMPLE) {}
CacheManager(size_t capacity = 100)
: add_cache(capacity, infiniopDestroyAddDescriptor),
rms_norm_cache(capacity, infiniopDestroyRMSNormDescriptor),
gemm_cache(capacity, infiniopDestroyGemmDescriptor),
rope_cache(capacity, infiniopDestroyRoPEDescriptor),
rearrange_cache(capacity, infiniopDestroyRearrangeDescriptor),
causal_softmax_cache(capacity, infiniopDestroyCausalSoftmaxDescriptor),
swiglu_cache(capacity, infiniopDestroySwiGLUDescriptor),
random_sample_cache(capacity, infiniopDestroyRandomSampleDescriptor) {}
// Add operations
bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) {
......
......@@ -409,4 +409,4 @@ __C void destroyJiugeModel(struct JiugeModel *model) {
}
delete model;
}
\ No newline at end of file
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment