Unverified Commit 6ca0e313 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #235 from InfiniTensor/issue/234

issue/234 昇腾gemm缓存executor
parents 3e5842c3 676a52a7
......@@ -3,6 +3,26 @@
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/level2/aclnn_gemm.h>
#include <cstring>
#include <unordered_map>
// Hash functor for an (alpha, beta) pair of floats; used as the hasher of the
// executor cache (`lookup`) keyed by std::pair<float, float>.
//
// The object representations of both floats are packed into one 64-bit word
// and hashed with std::hash<uint64_t>.
//
// Keys are compared by FloatPairEqual using operator==, under which 0.0f and
// -0.0f are equal even though their bit patterns differ. Unordered containers
// require equal keys to produce equal hashes, so signed zeros are canonicalised
// to +0.0f before their bits are hashed.
struct FloatPairHash {
    size_t operator()(const std::pair<float, float> &p) const {
        static_assert(2 * sizeof(float) <= sizeof(uint64_t),
                      "two floats must fit into the 64-bit packing word");

        // Map -0.0f to +0.0f; every other value is passed through unchanged.
        const float first = (p.first == 0.0f) ? 0.0f : p.first;
        const float second = (p.second == 0.0f) ? 0.0f : p.second;

        // Zero-initialised so any byte not covered by the copies is deterministic.
        uint64_t combined = 0;
        char *bytes = reinterpret_cast<char *>(&combined);
        std::memcpy(bytes, &first, sizeof(float));
        std::memcpy(bytes + sizeof(float), &second, sizeof(float));
        return std::hash<uint64_t>()(combined);
    }
};
// Key-equality functor for (alpha, beta) pairs stored in the executor cache.
// Two keys match only when both components compare equal with operator==.
// Under IEEE 754 semantics this means 0.0f matches -0.0f and a NaN component
// never matches anything, including itself.
struct FloatPairEqual {
    bool operator()(const std::pair<float, float> &lhs, const std::pair<float, float> &rhs) const {
        // Bail out as soon as the alpha components differ.
        if (!(lhs.first == rhs.first)) {
            return false;
        }
        return lhs.second == rhs.second;
    }
};
namespace op::gemm::ascend {
struct Descriptor::Opaque {
......@@ -11,11 +31,17 @@ struct Descriptor::Opaque {
// see doc:
// https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnBatchMatMul.md
int8_t mt;
// alpha&beta hashmap
std::unordered_map<std::pair<float, float>, aclOpExecutor *, FloatPairHash, FloatPairEqual> lookup;
~Opaque() {
delete c;
delete a;
delete b;
for (auto &item : lookup) {
aclDestroyAclOpExecutor(item.second);
}
lookup.clear();
}
};
......@@ -54,15 +80,16 @@ infiniStatus_t Descriptor::create(
ta = a->tensor,
tb = b->tensor;
std::unordered_map<std::pair<float, float>, aclOpExecutor *, FloatPairHash, FloatPairEqual> lookup;
aclOpExecutor *executor = nullptr;
size_t workspace_size = 0;
// aclnnGemm support C = alpha * A @ B + beta * C
// see
// https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha003/apiref/aolapi/context/aclnnGemm.md
// use alpha = 0.5, beta = 0.5 temporarily
int8_t mt = 1;
CHECK_ACL(aclnnGemmGetWorkspaceSize(ta, tb, tc, .5, .5, 0, 0, tc, mt, &workspace_size, &executor));
CHECK_ACL(aclnnGemmGetWorkspaceSize(ta, tb, tc, 1., 0., 0, 0, tc, mt, &workspace_size, &executor));
CHECK_ACL(aclSetAclOpExecutorRepeatable(executor));
lookup[std::make_pair(1.0f, 0.0f)] = executor;
CHECK_ACL(aclnnGemmGetWorkspaceSize(ta, tb, tc, 1., 1., 0, 0, tc, mt, &workspace_size, &executor));
CHECK_ACL(aclSetAclOpExecutorRepeatable(executor));
lookup[std::make_pair(1.0f, 1.0f)] = executor;
*desc_ptr = new Descriptor(
dtype, info, workspace_size,
......@@ -71,11 +98,9 @@ infiniStatus_t Descriptor::create(
a,
b,
mt,
},
std::move(lookup)},
handle->device, handle->device_id);
aclDestroyAclOpExecutor(executor);
return INFINI_STATUS_SUCCESS;
}
......@@ -93,16 +118,22 @@ infiniStatus_t Descriptor::calculate(
ta = _opaque->a->tensor,
tb = _opaque->b->tensor;
size_t workspace_size = 0;
aclOpExecutor *executor = nullptr;
size_t workspace_size = _workspace_size;
aclOpExecutor *executor;
auto key = std::make_pair(alpha, beta);
if (_opaque->lookup.find(key) != _opaque->lookup.end()) {
executor = _opaque->lookup[key];
} else {
CHECK_ACL(aclnnGemmGetWorkspaceSize(
ta, tb, tc, alpha, beta, 0, 0, tc, _opaque->mt,
&workspace_size, &executor));
CHECK_ACL(aclSetAclOpExecutorRepeatable(executor));
_opaque->lookup[key] = executor;
}
if (workspaceSize_ < workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
CHECK_ACL(aclSetAclOpExecutorRepeatable(executor));
auto unit = infiniSizeOf(_dtype);
for (size_t i = 0; i < _info.batch; ++i) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment