Commit 2a2ddc57 authored by wooway777's avatar wooway777
Browse files

issue/21 - Improved Descriptor Destruction Logic

parent ebd35e7e
...@@ -34,15 +34,23 @@ inline size_t computeTensorDescHash(std::shared_ptr<Tensor> tensor) { ...@@ -34,15 +34,23 @@ inline size_t computeTensorDescHash(std::shared_ptr<Tensor> tensor) {
return seed; return seed;
} }
enum class OperatorType { class IDescriptorDestroyer {
ADD, public:
RMS_NORM, virtual ~IDescriptorDestroyer() = default;
GEMM, virtual void destroy(void *descriptor) = 0;
ROPE, };
REARRANGE,
CAUSAL_SOFTMAX, template <typename DescriptorType>
SWIGLU, class DescriptorDestroyer : public IDescriptorDestroyer {
RANDOM_SAMPLE using DestroyFunc = infiniStatus_t (*)(DescriptorType);
DestroyFunc destroyFunc;
public:
DescriptorDestroyer(DestroyFunc func) : destroyFunc(func) {}
void destroy(void *descriptor) override {
destroyFunc(*static_cast<DescriptorType *>(descriptor));
}
}; };
template <typename DescriptorType> template <typename DescriptorType>
...@@ -63,43 +71,14 @@ private: ...@@ -63,43 +71,14 @@ private:
CacheNode *tail; CacheNode *tail;
const size_t capacity; const size_t capacity;
size_t size; size_t size;
const OperatorType opType; std::unique_ptr<IDescriptorDestroyer> destroyer;
void destroyDescriptor(DescriptorType &desc) {
switch (opType) {
case OperatorType::ADD:
infiniopDestroyAddDescriptor(desc);
break;
case OperatorType::RMS_NORM:
infiniopDestroyRMSNormDescriptor(desc);
break;
case OperatorType::GEMM:
infiniopDestroyGemmDescriptor(desc);
break;
case OperatorType::ROPE:
infiniopDestroyRoPEDescriptor(desc);
break;
case OperatorType::REARRANGE:
infiniopDestroyRearrangeDescriptor(desc);
break;
case OperatorType::CAUSAL_SOFTMAX:
infiniopDestroyCausalSoftmaxDescriptor(desc);
break;
case OperatorType::SWIGLU:
infiniopDestroySwiGLUDescriptor(desc);
break;
case OperatorType::RANDOM_SAMPLE:
infiniopDestroyRandomSampleDescriptor(desc);
break;
default:
throw std::runtime_error("Unknown descriptor type");
}
}
void removeNode(CacheNode *node) { void removeNode(CacheNode *node) {
node->prev->next = node->next; node->prev->next = node->next;
node->next->prev = node->prev; node->next->prev = node->prev;
destroyDescriptor(node->desc); if (destroyer) {
destroyer->destroy(&node->desc);
}
cache.erase(node->key); cache.erase(node->key);
delete node; delete node;
--size; --size;
...@@ -126,7 +105,9 @@ private: ...@@ -126,7 +105,9 @@ private:
} }
public: public:
LRUDescriptorCache(size_t c, OperatorType t) : capacity(c), size(0), opType(t) { template <typename DestroyFunc>
LRUDescriptorCache(size_t c, DestroyFunc destroyFunc)
: capacity(c), size(0), destroyer(std::make_unique<DescriptorDestroyer<DescriptorType>>(destroyFunc)) {
head = new CacheNode(); head = new CacheNode();
tail = new CacheNode(); tail = new CacheNode();
head->next = tail; head->next = tail;
...@@ -158,7 +139,9 @@ public: ...@@ -158,7 +139,9 @@ public:
if (it != cache.end()) { if (it != cache.end()) {
// Key already exists, update the descriptor // Key already exists, update the descriptor
CacheNode *node = it->second; CacheNode *node = it->second;
destroyDescriptor(node->desc); if (destroyer) {
destroyer->destroy(&node->desc);
}
node->desc = descriptor; node->desc = descriptor;
moveToTop(node); moveToTop(node);
return; return;
...@@ -192,14 +175,15 @@ private: ...@@ -192,14 +175,15 @@ private:
LRUDescriptorCache<infiniopRandomSampleDescriptor_t> random_sample_cache; LRUDescriptorCache<infiniopRandomSampleDescriptor_t> random_sample_cache;
public: public:
CacheManager(size_t capacity = 100) : add_cache(capacity, OperatorType::ADD), CacheManager(size_t capacity = 100)
rms_norm_cache(capacity, OperatorType::RMS_NORM), : add_cache(capacity, infiniopDestroyAddDescriptor),
gemm_cache(capacity, OperatorType::GEMM), rms_norm_cache(capacity, infiniopDestroyRMSNormDescriptor),
rope_cache(capacity, OperatorType::ROPE), gemm_cache(capacity, infiniopDestroyGemmDescriptor),
rearrange_cache(capacity, OperatorType::REARRANGE), rope_cache(capacity, infiniopDestroyRoPEDescriptor),
causal_softmax_cache(capacity, OperatorType::CAUSAL_SOFTMAX), rearrange_cache(capacity, infiniopDestroyRearrangeDescriptor),
swiglu_cache(capacity, OperatorType::SWIGLU), causal_softmax_cache(capacity, infiniopDestroyCausalSoftmaxDescriptor),
random_sample_cache(capacity, OperatorType::RANDOM_SAMPLE) {} swiglu_cache(capacity, infiniopDestroySwiGLUDescriptor),
random_sample_cache(capacity, infiniopDestroyRandomSampleDescriptor) {}
// Add operations // Add operations
bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) { bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) {
......
...@@ -409,4 +409,4 @@ __C void destroyJiugeModel(struct JiugeModel *model) { ...@@ -409,4 +409,4 @@ __C void destroyJiugeModel(struct JiugeModel *model) {
} }
delete model; delete model;
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment