Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
a51e1d56
Unverified
Commit
a51e1d56
authored
Mar 17, 2025
by
PanZezhong1725
Committed by
GitHub
Mar 17, 2025
Browse files
Merge pull request #96 from PanZezhong1725/issue/87/ascend
issur/87/feat:修改昇腾平台的handle
parents
89ebdac8
018e2546
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
167 additions
and
250 deletions
+167
-250
src/infiniop/devices/ascend/ascend_handle.cc
src/infiniop/devices/ascend/ascend_handle.cc
+8
-9
src/infiniop/devices/ascend/ascend_handle.h
src/infiniop/devices/ascend/ascend_handle.h
+8
-5
src/infiniop/devices/ascend/common_ascend.cc
src/infiniop/devices/ascend/common_ascend.cc
+102
-10
src/infiniop/devices/ascend/common_ascend.h
src/infiniop/devices/ascend/common_ascend.h
+28
-7
src/infiniop/devices/ascend/tensor_aclnn.cc
src/infiniop/devices/ascend/tensor_aclnn.cc
+0
-136
src/infiniop/devices/ascend/tensor_aclnn.h
src/infiniop/devices/ascend/tensor_aclnn.h
+0
-39
src/infiniop/devices/handle.cc
src/infiniop/devices/handle.cc
+2
-6
src/infiniop/ops/matmul/ascend/matmul_ascend.cc
src/infiniop/ops/matmul/ascend/matmul_ascend.cc
+19
-38
No files found.
src/infiniop/devices/ascend/ascend_handle.cc
View file @
a51e1d56
#include "
common_ascend
.h"
#include "
ascend_handle
.h"
infiniStatus_t
createAscendHandle
(
infiniopAscendHandle_t
*
handle_ptr
)
{
namespace
device
::
ascend
{
int
device_id
=
0
;
CHECK_ACL
(
aclrtGetDevice
(
&
device_id
));
*
handle_ptr
=
new
InfiniopAscendHandle
{
INFINI_DEVICE_ASCEND
,
device_id
};
Handle
::
Handle
(
int
device_id
)
:
InfiniopHandle
{
INFINI_DEVICE_ASCEND
,
device_id
}
{}
return
INFINI_STATUS_SUCCESS
;
infiniStatus_t
Handle
::
create
(
InfiniopHandle
**
Handle_ptr
,
int
device_id
)
{
}
*
Handle_ptr
=
new
Handle
(
device_id
);
infiniStatus_t
destroyAscendHandle
(
infiniopAscendHandle_t
handle_ptr
)
{
delete
handle_ptr
;
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
}
// namespace device::ascend
src/infiniop/devices/ascend/ascend_handle.h
View file @
a51e1d56
...
@@ -2,13 +2,16 @@
...
@@ -2,13 +2,16 @@
#define __INFINIOP_ASCEND_HANDLE_H__
#define __INFINIOP_ASCEND_HANDLE_H__
#include "../../handle.h"
#include "../../handle.h"
#include "infinicore.h"
struct
InfiniopAscendHandle
;
namespace
device
::
ascend
{
typedef
struct
InfiniopAscendHandle
*
infiniopAscendHandle_t
;
infiniStatus_t
createAscendHandle
(
infiniopAscendHandle_t
*
handle_ptr
);
class
Handle
:
public
InfiniopHandle
{
infiniStatus_t
destroyAscendHandle
(
infiniopAscendHandle_t
handle_ptr
);
Handle
(
int
device_id
);
public:
static
infiniStatus_t
create
(
InfiniopHandle
**
handle_ptr
,
int
device_id
);
};
}
// namespace device::ascend
#endif
#endif
src/infiniop/devices/ascend/common_ascend.cc
View file @
a51e1d56
#include "common_ascend.h"
#include "common_ascend.h"
infiniStatus_t
mallocWorkspace
(
void
**
workspaceAddr
,
size_t
workspaceSize
)
{
std
::
vector
<
int64_t
>
inferStorageShape
(
std
::
vector
<
int64_t
>
shape
,
std
::
vector
<
int64_t
>
strides
)
{
*
workspaceAddr
=
nullptr
;
auto
index
=
std
::
max_element
(
strides
.
begin
(),
strides
.
end
());
if
(
workspaceSize
>
0
)
{
uint64_t
max_stride_index
=
std
::
distance
(
strides
.
begin
(),
index
);
CHECK_ACL
(
aclrtMalloc
(
workspaceAddr
,
workspaceSize
,
auto
storageShape
=
std
::
vector
<
int64_t
>
({
shape
[
max_stride_index
]
*
strides
[
max_stride_index
]});
ACL_MEM_MALLOC_HUGE_FIRST
));
return
storageShape
;
}
aclnnTensorDescriptor
::
aclnnTensorDescriptor
(
infiniopTensorDescriptor_t
desc
,
void
*
data
)
{
this
->
ndim
=
desc
->
ndim
();
this
->
shape
=
std
::
vector
<
int64_t
>
(
ndim
);
this
->
strides
=
std
::
vector
<
int64_t
>
(
ndim
);
for
(
uint64_t
i
=
0
;
i
<
ndim
;
++
i
)
{
this
->
shape
[
i
]
=
static_cast
<
int64_t
>
(
desc
->
dim
(
i
));
this
->
strides
[
i
]
=
desc
->
stride
(
i
);
}
}
return
INFINI_STATUS_SUCCESS
;
this
->
storageShape
=
inferStorageShape
(
this
->
shape
,
this
->
strides
);
this
->
dataType
=
toAclDataType
(
desc
->
dtype
());
// TODO: support other formats
this
->
format
=
aclFormat
::
ACL_FORMAT_ND
;
this
->
tensor
=
aclCreateTensor
(
this
->
shape
.
data
(),
this
->
ndim
,
this
->
dataType
,
this
->
strides
.
data
(),
this
->
offset
,
this
->
format
,
this
->
storageShape
.
data
(),
this
->
storageNdim
,
data
);
}
/// Builds a descriptor from an explicit ACL dtype, shape and strides, then creates the
/// matching aclTensor over `data`. The format is always ACL_FORMAT_ND; `offset` and
/// `storageNdim` keep their in-class default values.
/// @param dtype   ACL element type stored in `dataType`.
/// @param shape   Extent per dimension; its size becomes `ndim`.
/// @param strides Stride per dimension, copied as given.
/// @param data    Pointer handed to aclCreateTensor unchanged.
aclnnTensorDescriptor::aclnnTensorDescriptor(aclDataType dtype,
                                             const std::vector<int64_t> &shape,
                                             const std::vector<int64_t> &strides,
                                             void *data)
    : ndim(shape.size()),
      shape(shape),
      strides(strides),
      dataType(dtype),
      format(aclFormat::ACL_FORMAT_ND),
      storageShape(inferStorageShape(shape, strides)),
      tensor(nullptr) {
    // Every member read below is already initialised: the list above runs in
    // declaration order and `tensor` is declared last.
    this->tensor = aclCreateTensor(this->shape.data(),
                                   this->ndim,
                                   this->dataType,
                                   this->strides.data(),
                                   this->offset,
                                   this->format,
                                   this->storageShape.data(),
                                   this->storageNdim,
                                   data);
}
}
infiniStatus_t
freeWorkspace
(
void
*
workspaceAddr
)
{
aclnnTensorDescriptor
::~
aclnnTensorDescriptor
()
{
if
(
workspaceAddr
!=
nullptr
)
{
if
(
this
->
tensor
)
{
CHECK_ACL
(
aclrtFree
(
workspaceAddr
));
aclDestroyTensor
(
this
->
tensor
);
this
->
tensor
=
nullptr
;
}
}
return
INFINI_STATUS_SUCCESS
;
}
}
aclDataType
toAclDataType
(
infiniDtype_t
dt
)
{
aclDataType
toAclDataType
(
infiniDtype_t
dt
)
{
...
@@ -129,3 +169,55 @@ const char *formatToString(aclFormat format) {
...
@@ -129,3 +169,55 @@ const char *formatToString(aclFormat format) {
return
"UNKNOWN"
;
return
"UNKNOWN"
;
}
}
}
}
/// Renders every field of the descriptor as text, one "name: value" line per field,
/// in the order ndim, shape, stride, offset, dataType, format, storageShape, storageNdim.
/// @return The assembled multi-line string.
std::string aclnnTensorDescriptor::toString() {
    std::ostringstream text;

    // Emits `opening`, then the first `count` entries of `values` separated by ", ",
    // then "]" and a newline. `count` keeps the caller's integer type so the loop
    // bound behaves exactly like the member it came from.
    auto writeList = [&text](const char *opening, const std::vector<int64_t> &values, auto count) {
        text << opening;
        for (decltype(count) pos = 0; pos < count; ++pos) {
            if (pos != 0) {
                text << ", ";
            }
            text << values[pos];
        }
        text << "]\n";
    };

    // ndim
    text << "ndim: " << this->ndim << "\n";

    // shape
    writeList("shape: [", this->shape, this->ndim);

    // stride
    writeList("stride: [", this->strides, this->ndim);

    // offset
    text << "offset: " << this->offset << "\n";

    // dataType
    text << "dataType: " << dataTypeToString(this->dataType) << "\n";

    // format
    text << "format: " << formatToString(this->format) << "\n";

    // storageShape
    writeList("storageShape: [", this->storageShape, this->storageNdim);

    // storageNdim
    text << "storageNdim: " << this->storageNdim << "\n";

    // Hand back the assembled string.
    return text.str();
}
src/infiniop/devices/ascend/common_ascend.h
View file @
a51e1d56
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define __INFINIOP_COMMON_ASCEND_H__
#define __INFINIOP_COMMON_ASCEND_H__
#include "../../../utils.h"
#include "../../../utils.h"
#include "../../tensor.h"
#include "ascend_handle.h"
#include "ascend_handle.h"
#include <acl/acl.h>
#include <acl/acl.h>
#include <acl/acl_base.h>
#include <acl/acl_base.h>
...
@@ -11,6 +12,7 @@
...
@@ -11,6 +12,7 @@
#include <functional>
#include <functional>
#include <inttypes.h>
#include <inttypes.h>
#include <numeric>
#include <numeric>
#include <sstream>
#include <vector>
#include <vector>
#ifdef __cplusplus
#ifdef __cplusplus
...
@@ -21,15 +23,34 @@ extern "C" {
...
@@ -21,15 +23,34 @@ extern "C" {
};
};
#endif
#endif
struct
InfiniopAscendHandle
{
struct
aclnnTensorDescriptor
{
infiniDevice_t
device
;
uint64_t
ndim
;
int
device_id
;
std
::
vector
<
int64_t
>
shape
;
std
::
vector
<
int64_t
>
strides
;
int64_t
offset
=
0
;
aclDataType
dataType
;
aclFormat
format
;
std
::
vector
<
int64_t
>
storageShape
;
int64_t
storageNdim
=
1
;
aclTensor
*
tensor
;
// aclnnGemmGetWorkspaceSize only supports 2D matrix multiplication, so a 3D tensor needs to be converted to a 2D tensor
aclnnTensorDescriptor
(
aclDataType
dtype
,
const
std
::
vector
<
int64_t
>
&
shape
,
const
std
::
vector
<
int64_t
>
&
strides
,
void
*
data
=
nullptr
);
aclnnTensorDescriptor
(
infiniopTensorDescriptor_t
y_desc
,
void
*
data
=
nullptr
);
~
aclnnTensorDescriptor
();
std
::
string
toString
();
};
};
typedef
aclnnTensorDescriptor
*
aclnnTensorDescriptor_t
;
const
char
*
dataTypeToString
(
aclDataType
dtype
);
const
char
*
formatToString
(
aclFormat
format
);
infiniStatus_t
mallocWorkspace
(
void
**
workspaceAddr
,
size_t
workspaceSize
);
infiniStatus_t
freeWorkspace
(
void
*
workspaceAddr
);
aclDataType
toAclDataType
(
infiniDtype_t
dt
);
aclDataType
toAclDataType
(
infiniDtype_t
dt
);
// Fetches the most recent error message via aclGetRecentErrMsg() and, when one is
// available (non-NULL), prints it with printf. Prints nothing otherwise.
// NOTE(review): the macro expands to a bare brace block, so `GetRecentErrMsg();`
// used as the body of an unbraced if/else leaves a stray `;` — confirm call sites.
#define GetRecentErrMsg() \
{ \
auto tmp_err_msg = aclGetRecentErrMsg(); \
if (tmp_err_msg != NULL) { \
printf(" ERROR Message : %s \n ", tmp_err_msg); \
} \
}
#endif
#endif
src/infiniop/devices/ascend/tensor_aclnn.cc
deleted
100644 → 0
View file @
89ebdac8
#include "tensor_aclnn.h"
#include "../../../utils.h"
#include "../../tensor.h"
#include <algorithm>
/// Fills the descriptor from an ACL dtype plus shape and strides, then infers the storage shape.
/// @param dtype   ACL element type to record.
/// @param shape   Extent per dimension; its size becomes `ndim`.
/// @param strides Stride per dimension; must have the same length as `shape`.
/// @return INFINI_STATUS_BAD_TENSOR_STRIDES when the two vectors differ in length,
///         otherwise INFINI_STATUS_SUCCESS (CHECK_STATUS presumably forwards a failure
///         from inferStorageShape — TODO confirm against the macro definition).
infiniStatus_t aclnnTensorDescriptor::setDescriptor(aclDataType dtype,
                                                    const std::vector<int64_t> &shape,
                                                    const std::vector<int64_t> &strides) {
    const bool sameRank = (shape.size() == strides.size());
    if (!sameRank) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    this->ndim = shape.size();
    this->shape = shape;
    this->strides = strides;
    this->dataType = dtype;

    // Only the ND layout is handled for now.
    // TODO: Support other formats
    this->format = aclFormat::ACL_FORMAT_ND;

    CHECK_STATUS(this->inferStorageShape());
    return INFINI_STATUS_SUCCESS;
}
/// @brief Infer the storage shape. For now this returns a 1D shape holding the total tensor storage size.
/// We don't see why a higher-dimensional storage shape is ever needed. To change if necessary.
/// @return INFINI_STATUS_SUCCESS in all cases.
infiniStatus_t aclnnTensorDescriptor::inferStorageShape() {
    this->storageNdim = 1;

    // With no dimensions there is no stride to inspect: std::max_element would return
    // end() and the indexing below would read element 0 of empty vectors.
    // NOTE(review): a single-element storage is assumed for this case — confirm it is
    // what aclCreateTensor expects for rank-0 tensors.
    if (this->strides.empty() || this->shape.empty()) {
        this->storageShape = std::vector<int64_t>(1, 1);
        return INFINI_STATUS_SUCCESS;
    }

    // The storage extent is taken from the dimension with the largest stride.
    auto index = std::max_element(this->strides.begin(), this->strides.end());
    uint64_t max_stride_index = std::distance(this->strides.begin(), index);
    this->storageShape = std::vector<int64_t>(
        {this->shape[max_stride_index] * this->strides[max_stride_index]});
    return INFINI_STATUS_SUCCESS;
}
/// @brief Populate this aclnnTensorDescriptor from an infiniopTensorDescriptor.
/// @param y Source descriptor; its ndim(), dim(i), stride(i) and dtype() are read.
/// @return The status returned by setDescriptor.
infiniStatus_t aclnnTensorDescriptor::fromInfiniOpTensorDescriptor(infiniopTensorDescriptor_t y) {
    const uint64_t rank = y->ndim();

    // Collect extents and strides as int64_t, the element type setDescriptor takes.
    std::vector<int64_t> extents;
    std::vector<int64_t> steps;
    extents.reserve(rank);
    steps.reserve(rank);
    for (uint64_t axis = 0; axis < rank; ++axis) {
        extents.push_back(static_cast<int64_t>(y->dim(axis)));
        steps.push_back(y->stride(axis));
    }

    return setDescriptor(toAclDataType(y->dtype()), extents, steps);
}
/// @brief Wrapper of aclCreateTensor. Creates the aclTensor for this descriptor unless one already exists.
/// See https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha001/apiref/appdevgapi/aclcppdevg_03_0168.html
/// @param data Data ptr on device global mem.
/// @return INFINI_STATUS_SUCCESS in all cases; the result of aclCreateTensor is not inspected.
infiniStatus_t aclnnTensorDescriptor::createTensor(void *data) {
    // Keep an existing tensor; only build one when the slot is still empty.
    if (this->t == nullptr) {
        this->t = aclCreateTensor(this->shape.data(),
                                  this->ndim,
                                  this->dataType,
                                  this->strides.data(),
                                  this->offset,
                                  this->format,
                                  this->storageShape.data(),
                                  this->storageNdim,
                                  data);
    }
    return INFINI_STATUS_SUCCESS;
}
/// Destroys the aclTensor held in `t` through aclDestroyTensor and clears the pointer.
/// @return INFINI_STATUS_SUCCESS once the pointer has been cleared (CHECK_ACL
///         presumably returns early when the ACL call fails — TODO confirm).
infiniStatus_t aclnnTensorDescriptor::destroyTensor() {
    CHECK_ACL(aclDestroyTensor(this->t));
    // Forget the handle so a later createTensor() builds a fresh tensor.
    this->t = nullptr;
    return INFINI_STATUS_SUCCESS;
}
/// Releases the aclTensor, if one was created. The status returned by destroyTensor is ignored.
aclnnTensorDescriptor::~aclnnTensorDescriptor() {
    if (this->t != nullptr) {
        destroyTensor();
    }
}
/// @brief Builds a text dump of this descriptor: ndim, shape, stride, offset, dataType,
///        format, storageShape and storageNdim, one per line.
/// @return NUL-terminated text in a buffer obtained from malloc (ownership passes to the
///         caller), or nullptr when the allocation fails.
char *aclnnTensorDescriptor::toString() {
    // A decimal int64_t needs at most 20 characters (sign included); add 2 for ", ".
    const size_t maxEntryChars = 22;
    const size_t storageEntries = this->storageNdim > 0 ? static_cast<size_t>(this->storageNdim) : 0;
    // shape and strides each print ndim entries; the fixed 1024 bytes cover the labels
    // and the dtype/format names.
    const size_t bufferSize = 1024 + (2 * static_cast<size_t>(this->ndim) + storageEntries) * maxEntryChars;

    char *buffer = static_cast<char *>(malloc(bufferSize));
    if (!buffer) {
        return nullptr;
    }
    buffer[0] = '\0';

    // Every write is bounded by the space left, so an unexpectedly long name truncates
    // the dump instead of overrunning the allocation.
    size_t used = 0;
    auto advance = [&used, bufferSize](int written) {
        if (written < 0) {
            return; // formatting error: nothing usable was appended
        }
        const size_t room = bufferSize - used - 1; // bytes left before the terminating NUL
        const size_t produced = static_cast<size_t>(written);
        used += (produced < room) ? produced : room;
    };

    // Appends "<label>[v0, v1, ...]\n" for the first `count` entries of `values`.
    auto appendList = [&](const char *label, const std::vector<int64_t> &values, size_t count) {
        advance(snprintf(buffer + used, bufferSize - used, "%s[", label));
        for (size_t i = 0; i < count; ++i) {
            advance(snprintf(buffer + used, bufferSize - used, "%" PRId64, values[i]));
            if (i + 1 < count) {
                advance(snprintf(buffer + used, bufferSize - used, ", "));
            }
        }
        advance(snprintf(buffer + used, bufferSize - used, "]\n"));
    };

    // ndim is unsigned, so it takes the unsigned conversion specifier.
    advance(snprintf(buffer + used, bufferSize - used, "ndim: %" PRIu64 "\n", this->ndim));
    appendList("shape: ", this->shape, static_cast<size_t>(this->ndim));
    appendList("stride: ", this->strides, static_cast<size_t>(this->ndim));
    advance(snprintf(buffer + used, bufferSize - used, "offset: %" PRId64 "\n", this->offset));
    advance(snprintf(buffer + used, bufferSize - used, "dataType: %s\n", dataTypeToString(this->dataType)));
    advance(snprintf(buffer + used, bufferSize - used, "format: %s\n", formatToString(this->format)));
    appendList("storageShape: ", this->storageShape, storageEntries);
    advance(snprintf(buffer + used, bufferSize - used, "storageNdim: %" PRId64 "\n", this->storageNdim));

    buffer[used] = '\0';
    return buffer;
}
src/infiniop/devices/ascend/tensor_aclnn.h
deleted
100644 → 0
View file @
89ebdac8
#ifndef __ACLNN_TENSOR__
#define __ACLNN_TENSOR__
#include "../../operator.h"
#include "common_ascend.h"
#include <acl/acl.h>
#include <acl/acl_base.h>
#include <aclnn/acl_meta.h>
#include <vector>
// Aclnn tensor descriptor,
// used to build aclTensor.
// NOTE(review): the struct holds a raw aclTensor* that its destructor destroys, so
// copying an instance would destroy the same tensor twice — confirm no caller copies it.
struct aclnnTensorDescriptor {
    // Number of dimensions; assigned by setDescriptor.
    uint64_t ndim = 0;
    // Extent and stride per dimension; assigned by setDescriptor.
    std::vector<int64_t> shape;
    std::vector<int64_t> strides;
    // Passed to aclCreateTensor. No member function assigns it, so it must start
    // at a defined value.
    int64_t offset = 0;
    // Element type and layout; assigned by setDescriptor.
    aclDataType dataType;
    aclFormat format;

    // Storage shape and rank; assigned by inferStorageShape. The rank starts at 0
    // to match the empty storageShape.
    std::vector<int64_t> storageShape;
    int64_t storageNdim = 0;

    // Created by createTensor and destroyed by destroyTensor. It must start null
    // because the destructor and createTensor both test it.
    aclTensor *t = nullptr;

    // Set dtype, shape and strides, then infer the storage shape
    infiniStatus_t setDescriptor(aclDataType dtype, const std::vector<int64_t> &shape, const std::vector<int64_t> &strides);
    infiniStatus_t inferStorageShape();
    // Convert from InfiniOpTensorDescriptor
    infiniStatus_t fromInfiniOpTensorDescriptor(infiniopTensorDescriptor_t y_desc);
    infiniStatus_t createTensor(void *data = nullptr);
    infiniStatus_t destroyTensor();
    ~aclnnTensorDescriptor();

    // Returns a malloc-allocated text dump of the descriptor
    char *toString();
};
// Pointer alias for aclnnTensorDescriptor.
using aclnnTensorDescriptor_t = aclnnTensorDescriptor *;
#endif
src/infiniop/devices/handle.cc
View file @
a51e1d56
...
@@ -45,9 +45,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
...
@@ -45,9 +45,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE
(
INFINI_DEVICE_CAMBRICON
,
bang
::
cambricon
);
CREATE
(
INFINI_DEVICE_CAMBRICON
,
bang
::
cambricon
);
#endif
#endif
#ifdef ENABLE_ASCEND_API
#ifdef ENABLE_ASCEND_API
case
INFINI_DEVICE_ASCEND
:
{
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
createAscendHandle
((
infiniopAscendHandle_t
*
)
handle_ptr
);
}
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
case
INFINI_DEVICE_KUNLUN
:
{
case
INFINI_DEVICE_KUNLUN
:
{
...
@@ -83,9 +81,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
...
@@ -83,9 +81,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE
(
INFINI_DEVICE_CAMBRICON
,
bang
::
cambricon
);
DELETE
(
INFINI_DEVICE_CAMBRICON
,
bang
::
cambricon
);
#endif
#endif
#ifdef ENABLE_ASCEND_API
#ifdef ENABLE_ASCEND_API
case
INFINI_DEVICE_ASCEND
:
{
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
destroyAscendHandle
((
infiniopAscendHandle_t
)
handle
);
}
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
case
INFINI_DEVICE_KUNLUN
:
{
case
INFINI_DEVICE_KUNLUN
:
{
...
...
src/infiniop/ops/matmul/ascend/matmul_ascend.cc
View file @
a51e1d56
#include "matmul_ascend.h"
#include "matmul_ascend.h"
#include "../../../devices/ascend/ascend_handle.h"
#include "../../../devices/ascend/common_ascend.h"
#include "../../../devices/ascend/tensor_aclnn.h"
#include <acl/acl_base.h>
#include <aclnn/acl_meta.h>
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/level2/aclnn_gemm.h>
#include <aclnnop/level2/aclnn_gemm.h>
...
@@ -34,7 +31,7 @@ infiniStatus_t Descriptor::create(
...
@@ -34,7 +31,7 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
{
infiniopTensorDescriptor_t
b_desc
)
{
auto
handle
=
reinterpret_cast
<
infiniopA
scendHandle
_t
>
(
handle_
);
auto
handle
=
reinterpret_cast
<
device
::
a
scend
::
Handle
*
>
(
handle_
);
auto
dtype
=
c_desc
->
dtype
();
auto
dtype
=
c_desc
->
dtype
();
if
(
dtype
!=
INFINI_DTYPE_F16
&&
dtype
!=
INFINI_DTYPE_F32
)
{
if
(
dtype
!=
INFINI_DTYPE_F16
&&
dtype
!=
INFINI_DTYPE_F32
)
{
...
@@ -47,35 +44,20 @@ infiniStatus_t Descriptor::create(
...
@@ -47,35 +44,20 @@ infiniStatus_t Descriptor::create(
return
status
;
return
status
;
}
}
auto
c
=
new
aclnnTensorDescriptor
(),
auto
c
=
new
aclnnTensorDescriptor
(
toAclDataType
(
c_desc
->
dtype
()),
a
=
new
aclnnTensorDescriptor
(),
{
static_cast
<
int64_t
>
(
info
.
c_matrix
.
rows
),
static_cast
<
int64_t
>
(
info
.
c_matrix
.
cols
)},
b
=
new
aclnnTensorDescriptor
();
{
info
.
c_matrix
.
row_stride
,
info
.
c_matrix
.
col_stride
});
auto
a
=
new
aclnnTensorDescriptor
(
toAclDataType
(
a_desc
->
dtype
()),
// Treat A, B, C as 2D matrix, reuse aclnnTensorDescriptor for batched
{
static_cast
<
int64_t
>
(
info
.
a_matrix
.
rows
),
static_cast
<
int64_t
>
(
info
.
a_matrix
.
cols
)},
// operation
{
info
.
a_matrix
.
row_stride
,
info
.
a_matrix
.
col_stride
});
CHECK_STATUS
(
c
->
setDescriptor
(
auto
b
=
new
aclnnTensorDescriptor
(
toAclDataType
(
b_desc
->
dtype
()),
toAclDataType
(
c_desc
->
dtype
()),
{
static_cast
<
int64_t
>
(
info
.
b_matrix
.
rows
),
static_cast
<
int64_t
>
(
info
.
b_matrix
.
cols
)},
{
static_cast
<
int64_t
>
(
info
.
c_matrix
.
rows
),
{
info
.
b_matrix
.
row_stride
,
info
.
b_matrix
.
col_stride
});
static_cast
<
int64_t
>
(
info
.
c_matrix
.
cols
)},
{
info
.
c_matrix
.
row_stride
,
info
.
c_matrix
.
col_stride
}));
auto
tc
=
c
->
tensor
,
CHECK_STATUS
(
a
->
setDescriptor
(
ta
=
a
->
tensor
,
toAclDataType
(
a_desc
->
dtype
()),
tb
=
b
->
tensor
;
{
static_cast
<
int64_t
>
(
info
.
a_matrix
.
rows
),
static_cast
<
int64_t
>
(
info
.
a_matrix
.
cols
)},
{
info
.
a_matrix
.
row_stride
,
info
.
a_matrix
.
col_stride
}));
CHECK_STATUS
(
b
->
setDescriptor
(
toAclDataType
(
b_desc
->
dtype
()),
{
static_cast
<
int64_t
>
(
info
.
b_matrix
.
rows
),
static_cast
<
int64_t
>
(
info
.
b_matrix
.
cols
)},
{
info
.
b_matrix
.
row_stride
,
info
.
b_matrix
.
col_stride
}));
CHECK_STATUS
(
c
->
createTensor
());
CHECK_STATUS
(
a
->
createTensor
());
CHECK_STATUS
(
b
->
createTensor
());
auto
tc
=
c
->
t
,
ta
=
a
->
t
,
tb
=
b
->
t
;
aclOpExecutor
*
executor
;
aclOpExecutor
*
executor
;
size_t
workspace_size
;
size_t
workspace_size
;
// aclnnGemm supports C = alpha * A @ B + beta * C
// aclnnGemm supports C = alpha * A @ B + beta * C
...
@@ -85,7 +67,6 @@ infiniStatus_t Descriptor::create(
...
@@ -85,7 +67,6 @@ infiniStatus_t Descriptor::create(
int8_t
mt
=
1
;
int8_t
mt
=
1
;
CHECK_ACL
(
aclnnGemmGetWorkspaceSize
(
ta
,
tb
,
tc
,
.5
,
.5
,
0
,
0
,
tc
,
mt
,
&
workspace_size
,
&
executor
));
CHECK_ACL
(
aclnnGemmGetWorkspaceSize
(
ta
,
tb
,
tc
,
.5
,
.5
,
0
,
0
,
tc
,
mt
,
&
workspace_size
,
&
executor
));
aclSetAclOpExecutorRepeatable
(
executor
);
*
desc_ptr
=
new
Descriptor
(
*
desc_ptr
=
new
Descriptor
(
dtype
,
info
,
workspace_size
,
dtype
,
info
,
workspace_size
,
...
@@ -110,9 +91,9 @@ infiniStatus_t Descriptor::calculate(
...
@@ -110,9 +91,9 @@ infiniStatus_t Descriptor::calculate(
float
alpha
,
float
alpha
,
void
*
stream
)
const
{
void
*
stream
)
const
{
auto
tc
=
_opaque
->
c
->
t
,
auto
tc
=
_opaque
->
c
->
t
ensor
,
ta
=
_opaque
->
a
->
t
,
ta
=
_opaque
->
a
->
t
ensor
,
tb
=
_opaque
->
b
->
t
;
tb
=
_opaque
->
b
->
t
ensor
;
size_t
workspace_size
;
size_t
workspace_size
;
CHECK_ACL
(
aclnnGemmGetWorkspaceSize
(
CHECK_ACL
(
aclnnGemmGetWorkspaceSize
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment