Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
6ca0e313
Unverified
Commit
6ca0e313
authored
May 29, 2025
by
PanZezhong1725
Committed by
GitHub
May 29, 2025
Browse files
Merge pull request #235 from InfiniTensor/issue/234
issue/234 昇腾gemm缓存executor
parents
3e5842c3
676a52a7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
15 deletions
+46
-15
src/infiniop/ops/gemm/ascend/gemm_ascend.cc
src/infiniop/ops/gemm/ascend/gemm_ascend.cc
+46
-15
No files found.
src/infiniop/ops/gemm/ascend/gemm_ascend.cc
View file @
6ca0e313
...
@@ -3,6 +3,26 @@
...
@@ -3,6 +3,26 @@
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/level2/aclnn_gemm.h>
#include <aclnnop/level2/aclnn_gemm.h>
#include <cstring>
#include <unordered_map>
// Custom hash function for alpha beta pair<float, float>
struct
FloatPairHash
{
size_t
operator
()(
const
std
::
pair
<
float
,
float
>
&
p
)
const
{
uint64_t
combined
;
std
::
memcpy
(
reinterpret_cast
<
char
*>
(
&
combined
),
&
p
.
first
,
sizeof
(
float
));
std
::
memcpy
(
reinterpret_cast
<
char
*>
(
&
combined
)
+
sizeof
(
float
),
&
p
.
second
,
sizeof
(
float
));
return
std
::
hash
<
uint64_t
>
()(
combined
);
}
};
struct
FloatPairEqual
{
bool
operator
()(
const
std
::
pair
<
float
,
float
>
&
a
,
const
std
::
pair
<
float
,
float
>
&
b
)
const
{
return
a
.
first
==
b
.
first
&&
a
.
second
==
b
.
second
;
}
};
namespace
op
::
gemm
::
ascend
{
namespace
op
::
gemm
::
ascend
{
struct
Descriptor
::
Opaque
{
struct
Descriptor
::
Opaque
{
...
@@ -11,11 +31,17 @@ struct Descriptor::Opaque {
...
@@ -11,11 +31,17 @@ struct Descriptor::Opaque {
// see doc:
// see doc:
// https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnBatchMatMul.md
// https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnBatchMatMul.md
int8_t
mt
;
int8_t
mt
;
// alpha&beta hashmap
std
::
unordered_map
<
std
::
pair
<
float
,
float
>
,
aclOpExecutor
*
,
FloatPairHash
,
FloatPairEqual
>
lookup
;
~
Opaque
()
{
~
Opaque
()
{
delete
c
;
delete
c
;
delete
a
;
delete
a
;
delete
b
;
delete
b
;
for
(
auto
&
item
:
lookup
)
{
aclDestroyAclOpExecutor
(
item
.
second
);
}
lookup
.
clear
();
}
}
};
};
...
@@ -54,15 +80,16 @@ infiniStatus_t Descriptor::create(
...
@@ -54,15 +80,16 @@ infiniStatus_t Descriptor::create(
ta
=
a
->
tensor
,
ta
=
a
->
tensor
,
tb
=
b
->
tensor
;
tb
=
b
->
tensor
;
std
::
unordered_map
<
std
::
pair
<
float
,
float
>
,
aclOpExecutor
*
,
FloatPairHash
,
FloatPairEqual
>
lookup
;
aclOpExecutor
*
executor
=
nullptr
;
aclOpExecutor
*
executor
=
nullptr
;
size_t
workspace_size
=
0
;
size_t
workspace_size
=
0
;
// aclnnGemm support C = alpha * A @ B + beta * C
// see
// https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha003/apiref/aolapi/context/aclnnGemm.md
// use alpha = 0.5, beta = 0.5 temporarily
int8_t
mt
=
1
;
int8_t
mt
=
1
;
CHECK_ACL
(
aclnnGemmGetWorkspaceSize
(
ta
,
tb
,
tc
,
.5
,
.5
,
0
,
0
,
tc
,
mt
,
&
workspace_size
,
&
executor
));
CHECK_ACL
(
aclnnGemmGetWorkspaceSize
(
ta
,
tb
,
tc
,
1.
,
0.
,
0
,
0
,
tc
,
mt
,
&
workspace_size
,
&
executor
));
CHECK_ACL
(
aclSetAclOpExecutorRepeatable
(
executor
));
lookup
[
std
::
make_pair
(
1.0
f
,
0.0
f
)]
=
executor
;
CHECK_ACL
(
aclnnGemmGetWorkspaceSize
(
ta
,
tb
,
tc
,
1.
,
1.
,
0
,
0
,
tc
,
mt
,
&
workspace_size
,
&
executor
));
CHECK_ACL
(
aclSetAclOpExecutorRepeatable
(
executor
));
lookup
[
std
::
make_pair
(
1.0
f
,
1.0
f
)]
=
executor
;
*
desc_ptr
=
new
Descriptor
(
*
desc_ptr
=
new
Descriptor
(
dtype
,
info
,
workspace_size
,
dtype
,
info
,
workspace_size
,
...
@@ -71,11 +98,9 @@ infiniStatus_t Descriptor::create(
...
@@ -71,11 +98,9 @@ infiniStatus_t Descriptor::create(
a
,
a
,
b
,
b
,
mt
,
mt
,
},
std
::
move
(
lookup
)
},
handle
->
device
,
handle
->
device_id
);
handle
->
device
,
handle
->
device_id
);
aclDestroyAclOpExecutor
(
executor
);
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
...
@@ -93,16 +118,22 @@ infiniStatus_t Descriptor::calculate(
...
@@ -93,16 +118,22 @@ infiniStatus_t Descriptor::calculate(
ta
=
_opaque
->
a
->
tensor
,
ta
=
_opaque
->
a
->
tensor
,
tb
=
_opaque
->
b
->
tensor
;
tb
=
_opaque
->
b
->
tensor
;
size_t
workspace_size
=
0
;
size_t
workspace_size
=
_workspace_size
;
aclOpExecutor
*
executor
=
nullptr
;
aclOpExecutor
*
executor
;
auto
key
=
std
::
make_pair
(
alpha
,
beta
);
if
(
_opaque
->
lookup
.
find
(
key
)
!=
_opaque
->
lookup
.
end
())
{
executor
=
_opaque
->
lookup
[
key
];
}
else
{
CHECK_ACL
(
aclnnGemmGetWorkspaceSize
(
CHECK_ACL
(
aclnnGemmGetWorkspaceSize
(
ta
,
tb
,
tc
,
alpha
,
beta
,
0
,
0
,
tc
,
_opaque
->
mt
,
ta
,
tb
,
tc
,
alpha
,
beta
,
0
,
0
,
tc
,
_opaque
->
mt
,
&
workspace_size
,
&
executor
));
&
workspace_size
,
&
executor
));
CHECK_ACL
(
aclSetAclOpExecutorRepeatable
(
executor
));
_opaque
->
lookup
[
key
]
=
executor
;
}
if
(
workspaceSize_
<
workspace_size
)
{
if
(
workspaceSize_
<
workspace_size
)
{
return
INFINI_STATUS_INSUFFICIENT_WORKSPACE
;
return
INFINI_STATUS_INSUFFICIENT_WORKSPACE
;
}
}
CHECK_ACL
(
aclSetAclOpExecutorRepeatable
(
executor
));
auto
unit
=
infiniSizeOf
(
_dtype
);
auto
unit
=
infiniSizeOf
(
_dtype
);
for
(
size_t
i
=
0
;
i
<
_info
.
batch
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_info
.
batch
;
++
i
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment