Commit ebd35e7e authored by wooway777's avatar wooway777
Browse files

issue/21 - Improved Key Generation Logic

parent 79bc0438
......@@ -273,27 +273,10 @@ public:
random_sample_cache.put(key, desc);
}
/// @brief Builds a descriptor-cache key from any number of tensors.
///
/// Each non-null tensor's descriptor hash is folded into the seed in
/// argument order, so the key is order-sensitive: (a, b) != (b, a).
/// Null tensors are skipped and contribute nothing, which keeps keys
/// produced by the old fixed-arity overload (padded with nullptr)
/// identical to keys produced by the shorter variadic calls.
///
/// @tparam Tensors pack of std::shared_ptr<Tensor> arguments
/// @param tensors  tensor descriptors to combine; null entries are ignored
/// @return combined hash usable as a descriptor-cache key
template <typename... Tensors>
static size_t createDescriptorKey(const Tensors&... tensors) {
    size_t seed = 0;
    // NOTE: the pack is taken by const reference — copying a shared_ptr
    // costs an atomic refcount inc/dec per argument, pure overhead on
    // this hot cache-lookup path.
    auto combine = [&seed](const auto& tensor) {
        if (tensor) {
            hash_combine(seed, computeTensorDescHash(tensor));
        }
    };
    // Comma-fold evaluates strictly left to right, preserving the
    // argument-order hashing of the original implementation.
    (combine(tensors), ...);
    return seed;
}
};
......
......@@ -15,8 +15,7 @@ void InferenceContext::ensure_workspace(size_t required_size) {
void InferenceContext::add(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b) {
size_t key = CacheManager::createDescriptorKey(c, a, b,
nullptr, nullptr);
size_t key = CacheManager::createDescriptorKey(c, a, b);
infiniopAddDescriptor_t desc;
if (!cache_manager->getAddDescriptor(key, desc)) {
......@@ -38,7 +37,7 @@ void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w,
float epsilon) {
size_t key = CacheManager::createDescriptorKey(y, x, w, nullptr, nullptr);
size_t key = CacheManager::createDescriptorKey(y, x, w);
infiniopRMSNormDescriptor_t desc;
if (!cache_manager->getRMSNormDescriptor(key, desc)) {
......@@ -61,8 +60,7 @@ void InferenceContext::gemm(std::shared_ptr<Tensor> c,
std::shared_ptr<Tensor> a,
std::shared_ptr<Tensor> b,
float alpha, float beta) {
size_t key = CacheManager::createDescriptorKey(c, a, b,
nullptr, nullptr);
size_t key = CacheManager::createDescriptorKey(c, a, b);
infiniopGemmDescriptor_t desc;
if (!cache_manager->getGemmDescriptor(key, desc)) {
......@@ -82,7 +80,7 @@ void InferenceContext::gemm(std::shared_ptr<Tensor> c,
void InferenceContext::rearrange(std::shared_ptr<Tensor> dst,
std::shared_ptr<Tensor> src) {
size_t key = CacheManager::createDescriptorKey(dst, src, nullptr, nullptr, nullptr);
size_t key = CacheManager::createDescriptorKey(dst, src);
infiniopRearrangeDescriptor_t desc;
if (!cache_manager->getRearrangeDescriptor(key, desc)) {
......@@ -125,7 +123,7 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q,
void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x) {
size_t key = CacheManager::createDescriptorKey(y, x, nullptr, nullptr, nullptr);
size_t key = CacheManager::createDescriptorKey(y, x);
infiniopCausalSoftmaxDescriptor_t desc;
if (!cache_manager->getCausalSoftmaxDescriptor(key, desc)) {
......@@ -146,7 +144,7 @@ void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
void InferenceContext::swiglu(std::shared_ptr<Tensor> out,
std::shared_ptr<Tensor> up,
std::shared_ptr<Tensor> gate) {
size_t key = CacheManager::createDescriptorKey(out, up, gate, nullptr, nullptr);
size_t key = CacheManager::createDescriptorKey(out, up, gate);
infiniopSwiGLUDescriptor_t desc;
if (!cache_manager->getSwiGLUDescriptor(key, desc)) {
......@@ -167,7 +165,7 @@ void InferenceContext::swiglu(std::shared_ptr<Tensor> out,
void InferenceContext::randomSample(std::shared_ptr<Tensor> out,
std::shared_ptr<Tensor> prob,
float random_val, float top_p, uint32_t top_k, float temperature) {
size_t key = CacheManager::createDescriptorKey(out, prob, nullptr, nullptr, nullptr);
size_t key = CacheManager::createDescriptorKey(out, prob);
infiniopRandomSampleDescriptor_t desc;
if (!cache_manager->getRandomSampleDescriptor(key, desc)) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment