Unverified Commit 5330d5fa authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

issue/39 - Used macro on op cache (#40)

parents 22804eaa c219718e
......@@ -137,133 +137,45 @@ public:
LRUDescriptorCache &operator=(const LRUDescriptorCache &) = delete;
};
class CacheManager {
private:
const size_t DEFAULT_CACHE_CAPACITY = 128;
LRUDescriptorCache<infiniopAddDescriptor_t> add_cache;
LRUDescriptorCache<infiniopRMSNormDescriptor_t> rms_norm_cache;
LRUDescriptorCache<infiniopGemmDescriptor_t> gemm_cache;
LRUDescriptorCache<infiniopRoPEDescriptor_t> rope_cache;
LRUDescriptorCache<infiniopRoPEv2Descriptor_t> rope_v2_cache;
LRUDescriptorCache<infiniopRearrangeDescriptor_t> rearrange_cache;
LRUDescriptorCache<infiniopCausalSoftmaxDescriptor_t> causal_softmax_cache;
LRUDescriptorCache<infiniopTopkrouterDescriptor_t> causal_topkrouter_cache;
LRUDescriptorCache<infiniopSwiGLUDescriptor_t> swiglu_cache;
LRUDescriptorCache<infiniopRandomSampleDescriptor_t> random_sample_cache;
LRUDescriptorCache<infiniopDequantizeDescriptor_t> dequantize_cache;
public:
CacheManager(size_t capacity = 100)
: add_cache(capacity, infiniopDestroyAddDescriptor),
rms_norm_cache(capacity, infiniopDestroyRMSNormDescriptor),
gemm_cache(capacity, infiniopDestroyGemmDescriptor),
rope_cache(capacity, infiniopDestroyRoPEDescriptor),
rope_v2_cache(capacity, infiniopDestroyRoPEv2Descriptor),
rearrange_cache(capacity, infiniopDestroyRearrangeDescriptor),
causal_softmax_cache(capacity, infiniopDestroyCausalSoftmaxDescriptor),
causal_topkrouter_cache(capacity, infiniopDestroyTopkrouterDescriptor),
swiglu_cache(capacity, infiniopDestroySwiGLUDescriptor),
random_sample_cache(capacity, infiniopDestroyRandomSampleDescriptor),
dequantize_cache(capacity, infiniopDestroyDequantizeDescriptor) {}
// Add operations
bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) {
return add_cache.get(key, desc);
}
void putAddDescriptor(size_t key, const infiniopAddDescriptor_t &desc) {
add_cache.put(key, desc);
}
// RMSNorm operations
bool getRMSNormDescriptor(size_t key, infiniopRMSNormDescriptor_t &desc) {
return rms_norm_cache.get(key, desc);
}
void putRMSNormDescriptor(size_t key, const infiniopRMSNormDescriptor_t &desc) {
rms_norm_cache.put(key, desc);
}
// GEMM operations
bool getGemmDescriptor(size_t key, infiniopGemmDescriptor_t &desc) {
return gemm_cache.get(key, desc);
}
void putGemmDescriptor(size_t key, const infiniopGemmDescriptor_t &desc) {
gemm_cache.put(key, desc);
}
// RoPE operations
bool getRoPEDescriptor(size_t key, infiniopRoPEDescriptor_t &desc) {
return rope_cache.get(key, desc);
}
void putRoPEDescriptor(size_t key, const infiniopRoPEDescriptor_t &desc) {
rope_cache.put(key, desc);
}
bool getRoPEv2Descriptor(size_t key, infiniopRoPEv2Descriptor_t &desc) {
return rope_v2_cache.get(key, desc);
}
void putRoPEv2Descriptor(size_t key, const infiniopRoPEv2Descriptor_t &desc) {
rope_v2_cache.put(key, desc);
}
// Rearrange operations
bool getRearrangeDescriptor(size_t key, infiniopRearrangeDescriptor_t &desc) {
return rearrange_cache.get(key, desc);
}
void putRearrangeDescriptor(size_t key, const infiniopRearrangeDescriptor_t &desc) {
rearrange_cache.put(key, desc);
}
// Softmax operations
bool getCausalSoftmaxDescriptor(size_t key, infiniopCausalSoftmaxDescriptor_t &desc) {
return causal_softmax_cache.get(key, desc);
}
void putCausalSoftmaxDescriptor(size_t key, const infiniopCausalSoftmaxDescriptor_t &desc) {
causal_softmax_cache.put(key, desc);
}
// Topkrouter operations
bool getTopkrouterDescriptor(size_t key, infiniopTopkrouterDescriptor_t &desc) {
return causal_topkrouter_cache.get(key, desc);
}
void putTopkrouterDescriptor(size_t key, const infiniopTopkrouterDescriptor_t &desc) {
causal_topkrouter_cache.put(key, desc);
}
// Helper macro to generate the destroy function name
#define DESTROY_FUNC(OpType) infiniopDestroy##OpType##Descriptor
// SwiGLU operations
bool getSwiGLUDescriptor(size_t key, infiniopSwiGLUDescriptor_t &desc) {
return swiglu_cache.get(key, desc);
// Declare cache and access functions
#define DECLARE_OP_CACHE(OpType) \
LRUDescriptorCache<infiniop##OpType##Descriptor_t> OpType##_cache; \
bool get##OpType##Descriptor(size_t key, infiniop##OpType##Descriptor_t &desc) { \
return OpType##_cache.get(key, desc); \
} \
void put##OpType##Descriptor(size_t key, const infiniop##OpType##Descriptor_t &desc) { \
OpType##_cache.put(key, desc); \
}
void putSwiGLUDescriptor(size_t key, const infiniopSwiGLUDescriptor_t &desc) {
swiglu_cache.put(key, desc);
}
// Random Sample operations
bool getRandomSampleDescriptor(size_t key, infiniopRandomSampleDescriptor_t &desc) {
return random_sample_cache.get(key, desc);
}
void putRandomSampleDescriptor(size_t key, const infiniopRandomSampleDescriptor_t &desc) {
random_sample_cache.put(key, desc);
}
// Dequantize operations
bool getDequantizeDescriptor(size_t key, infiniopDequantizeDescriptor_t &desc) {
return dequantize_cache.get(key, desc);
}
class CacheManager {
public:
DECLARE_OP_CACHE(Add)
DECLARE_OP_CACHE(RMSNorm)
DECLARE_OP_CACHE(Gemm)
DECLARE_OP_CACHE(RoPE)
DECLARE_OP_CACHE(RoPEv2)
DECLARE_OP_CACHE(Rearrange)
DECLARE_OP_CACHE(CausalSoftmax)
DECLARE_OP_CACHE(Topkrouter)
DECLARE_OP_CACHE(SwiGLU)
DECLARE_OP_CACHE(RandomSample)
DECLARE_OP_CACHE(Dequantize)
void putDequantizeDescriptor(size_t key, const infiniopDequantizeDescriptor_t &desc) {
dequantize_cache.put(key, desc);
}
CacheManager(size_t capacity = 100)
: Add_cache(capacity, DESTROY_FUNC(Add)),
RMSNorm_cache(capacity, DESTROY_FUNC(RMSNorm)),
Gemm_cache(capacity, DESTROY_FUNC(Gemm)),
RoPE_cache(capacity, DESTROY_FUNC(RoPE)),
RoPEv2_cache(capacity, DESTROY_FUNC(RoPEv2)),
Rearrange_cache(capacity, DESTROY_FUNC(Rearrange)),
CausalSoftmax_cache(capacity, DESTROY_FUNC(CausalSoftmax)),
Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)),
SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)),
RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)),
Dequantize_cache(capacity, DESTROY_FUNC(Dequantize)) {}
template <typename... Tensors>
static size_t createDescriptorKey(Tensors... tensors) {
......@@ -273,4 +185,7 @@ public:
}
};
#undef DESTROY_FUNC
#undef DECLARE_OP_CACHE
#endif // CACHE_MANAGER_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment