Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
5336978c
Commit
5336978c
authored
Feb 13, 2025
by
PanZezhong
Browse files
fix: 修改函数命名
parent
9405d54e
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
81 additions
and
62 deletions
+81
-62
include/infinicore.h
include/infinicore.h
+50
-31
src/infiniop/devices/cuda/cuda_handle.cu
src/infiniop/devices/cuda/cuda_handle.cu
+1
-1
src/infiniop/devices/cuda/cuda_handle.h
src/infiniop/devices/cuda/cuda_handle.h
+1
-1
src/infiniop/devices/handle.cc
src/infiniop/devices/handle.cc
+1
-1
src/infiniop/ops/matmul/ascend/matmul_aclnn.cc
src/infiniop/ops/matmul/ascend/matmul_aclnn.cc
+4
-4
src/infiniop/ops/matmul/bang/matmul_cnnl.cc
src/infiniop/ops/matmul/bang/matmul_cnnl.cc
+2
-2
src/infiniop/ops/matmul/cpu/matmul_cpu.cc
src/infiniop/ops/matmul/cpu/matmul_cpu.cc
+3
-3
src/infiniop/ops/matmul/cuda/matmul_cuda_kernel.cu
src/infiniop/ops/matmul/cuda/matmul_cuda_kernel.cu
+3
-3
src/infiniop/ops/utils.h
src/infiniop/ops/utils.h
+16
-16
No files found.
include/infinicore.h
View file @
5336978c
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
#define __INFINICORE_EXPORT_C__
#define __INFINICORE_EXPORT_C__
#if defined(_WIN32)
#if defined(_WIN32)
#define __export __declspec(dllexport)
#define __export __declspec(dllexport)
#elif defined(__GNUC__) && ((__GNUC__ >= 4) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 3))
#elif defined(__GNUC__) && \
((__GNUC__ >= 4) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 3))
#define __export __attribute__((visibility("default")))
#define __export __attribute__((visibility("default")))
#else
#else
#define __export
#define __export
...
@@ -19,13 +20,11 @@
...
@@ -19,13 +20,11 @@
#define __C
#define __C
#include <stddef>
#include <stddef>
#endif
#endif
#endif// __INFINICORE_EXPORT_C__
#endif // __INFINICORE_EXPORT_C__
#ifndef __INFINI_DEVICE__
#ifndef __INFINI_DEVICE__
#define __INFINI_DEVICE__
#define __INFINI_DEVICE__
typedef
enum
typedef
enum
{
{
INFINI_DEVICE_CPU
=
0
,
INFINI_DEVICE_CPU
=
0
,
INFINI_DEVICE_NVIDIA
=
1
,
INFINI_DEVICE_NVIDIA
=
1
,
INFINI_DEVICE_CAMBRICON
=
2
,
INFINI_DEVICE_CAMBRICON
=
2
,
...
@@ -36,8 +35,7 @@ typedef enum
...
@@ -36,8 +35,7 @@ typedef enum
INFINI_DEVICE_KUNLUN
=
7
,
INFINI_DEVICE_KUNLUN
=
7
,
INFINI_DEVICE_SUGON
=
8
,
INFINI_DEVICE_SUGON
=
8
,
}
infiniDevice_t
;
}
infiniDevice_t
;
#endif// __INFINI_DEVICE__
#endif // __INFINI_DEVICE__
#ifndef __INFINI_DTYPE__
#ifndef __INFINI_DTYPE__
#define __INFINI_DTYPE__
#define __INFINI_DTYPE__
...
@@ -64,31 +62,52 @@ typedef enum {
...
@@ -64,31 +62,52 @@ typedef enum {
INFINI_DTYPE_BF16
=
19
,
INFINI_DTYPE_BF16
=
19
,
}
infiniDtype_t
;
}
infiniDtype_t
;
inline
size_t
infini
_s
izeof
(
infiniDtype_t
dtype
)
{
inline
size_t
infini
S
izeof
(
infiniDtype_t
dtype
)
{
switch
(
dtype
)
{
switch
(
dtype
)
{
case
INFINI_DTYPE_INVALID
:
return
0
;
case
INFINI_DTYPE_INVALID
:
case
INFINI_DTYPE_BYTE
:
return
1
;
return
0
;
case
INFINI_DTYPE_BOOL
:
return
1
;
case
INFINI_DTYPE_BYTE
:
case
INFINI_DTYPE_I8
:
return
1
;
return
1
;
case
INFINI_DTYPE_I16
:
return
2
;
case
INFINI_DTYPE_BOOL
:
case
INFINI_DTYPE_I32
:
return
4
;
return
1
;
case
INFINI_DTYPE_I64
:
return
8
;
case
INFINI_DTYPE_I8
:
case
INFINI_DTYPE_U8
:
return
1
;
return
1
;
case
INFINI_DTYPE_U16
:
return
2
;
case
INFINI_DTYPE_I16
:
case
INFINI_DTYPE_U32
:
return
4
;
return
2
;
case
INFINI_DTYPE_U64
:
return
8
;
case
INFINI_DTYPE_I32
:
case
INFINI_DTYPE_F8
:
return
1
;
return
4
;
case
INFINI_DTYPE_F16
:
return
2
;
case
INFINI_DTYPE_I64
:
case
INFINI_DTYPE_F32
:
return
4
;
return
8
;
case
INFINI_DTYPE_F64
:
return
8
;
case
INFINI_DTYPE_U8
:
case
INFINI_DTYPE_C8
:
return
2
;
return
1
;
case
INFINI_DTYPE_C16
:
return
4
;
case
INFINI_DTYPE_U16
:
case
INFINI_DTYPE_C32
:
return
8
;
return
2
;
case
INFINI_DTYPE_C64
:
return
16
;
case
INFINI_DTYPE_U32
:
case
INFINI_DTYPE_BF16
:
return
2
;
return
4
;
default:
return
0
;
case
INFINI_DTYPE_U64
:
return
8
;
case
INFINI_DTYPE_F8
:
return
1
;
case
INFINI_DTYPE_F16
:
return
2
;
case
INFINI_DTYPE_F32
:
return
4
;
case
INFINI_DTYPE_F64
:
return
8
;
case
INFINI_DTYPE_C8
:
return
2
;
case
INFINI_DTYPE_C16
:
return
4
;
case
INFINI_DTYPE_C32
:
return
8
;
case
INFINI_DTYPE_C64
:
return
16
;
case
INFINI_DTYPE_BF16
:
return
2
;
default:
return
0
;
}
}
}
}
#endif// __INFINI_DTYPE__
#endif
// __INFINI_DTYPE__
#endif// __INFINICORE_H__
#endif
// __INFINICORE_H__
src/infiniop/devices/cuda/cuda_handle.cu
View file @
5336978c
...
@@ -46,7 +46,7 @@ infiniopStatus_t createCudaHandle(infiniopCudaHandle_t *handle_ptr, int device_i
...
@@ -46,7 +46,7 @@ infiniopStatus_t createCudaHandle(infiniopCudaHandle_t *handle_ptr, int device_i
return
INFINIOP_STATUS_SUCCESS
;
return
INFINIOP_STATUS_SUCCESS
;
}
}
infiniopStatus_t
de
lete
CudaHandle
(
infiniopCudaHandle_t
handle_ptr
)
{
infiniopStatus_t
de
stroy
CudaHandle
(
infiniopCudaHandle_t
handle_ptr
)
{
handle_ptr
->
cublas_handles_t
=
nullptr
;
handle_ptr
->
cublas_handles_t
=
nullptr
;
handle_ptr
->
cudnn_handles_t
=
nullptr
;
handle_ptr
->
cudnn_handles_t
=
nullptr
;
delete
handle_ptr
;
delete
handle_ptr
;
...
...
src/infiniop/devices/cuda/cuda_handle.h
View file @
5336978c
...
@@ -8,6 +8,6 @@ typedef struct InfiniopCudaHandle *infiniopCudaHandle_t;
...
@@ -8,6 +8,6 @@ typedef struct InfiniopCudaHandle *infiniopCudaHandle_t;
infiniopStatus_t
createCudaHandle
(
infiniopCudaHandle_t
*
handle_ptr
,
int
device_id
,
infiniDevice_t
cuda_device_type
);
infiniopStatus_t
createCudaHandle
(
infiniopCudaHandle_t
*
handle_ptr
,
int
device_id
,
infiniDevice_t
cuda_device_type
);
infiniopStatus_t
de
lete
CudaHandle
(
infiniopCudaHandle_t
handle_ptr
);
infiniopStatus_t
de
stroy
CudaHandle
(
infiniopCudaHandle_t
handle_ptr
);
#endif
#endif
src/infiniop/devices/handle.cc
View file @
5336978c
...
@@ -56,7 +56,7 @@ __C infiniopStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
...
@@ -56,7 +56,7 @@ __C infiniopStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
#endif
#endif
#ifdef ENABLE_CUDA_API
#ifdef ENABLE_CUDA_API
case
INFINI_DEVICE_NVIDIA
:
{
case
INFINI_DEVICE_NVIDIA
:
{
return
de
lete
CudaHandle
((
infiniopCudaHandle_t
)
handle
);
return
de
stroy
CudaHandle
((
infiniopCudaHandle_t
)
handle
);
}
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
#ifdef ENABLE_CAMBRICON_API
...
...
src/infiniop/ops/matmul/ascend/matmul_aclnn.cc
View file @
5336978c
...
@@ -124,16 +124,16 @@ infiniopStatus_t aclnnMatmul(MatmulAclnnDescriptor_t desc, void *workspace,
...
@@ -124,16 +124,16 @@ infiniopStatus_t aclnnMatmul(MatmulAclnnDescriptor_t desc, void *workspace,
for
(
size_t
i
=
0
;
i
<
batch
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
batch
;
i
++
)
{
AclSetTensorAddr
(
desc
->
executor
,
0
,
ta
,
AclSetTensorAddr
(
desc
->
executor
,
0
,
ta
,
(
char
*
)(
a
)
+
i
*
desc
->
info
->
a_matrix
.
stride
*
(
char
*
)(
a
)
+
i
*
desc
->
info
->
a_matrix
.
stride
*
infini
_s
izeof
(
desc
->
dtype
));
infini
S
izeof
(
desc
->
dtype
));
AclSetTensorAddr
(
desc
->
executor
,
1
,
tb
,
AclSetTensorAddr
(
desc
->
executor
,
1
,
tb
,
(
char
*
)(
b
)
+
i
*
desc
->
info
->
b_matrix
.
stride
*
(
char
*
)(
b
)
+
i
*
desc
->
info
->
b_matrix
.
stride
*
infini
_s
izeof
(
desc
->
dtype
));
infini
S
izeof
(
desc
->
dtype
));
AclSetTensorAddr
(
desc
->
executor
,
2
,
tc
,
AclSetTensorAddr
(
desc
->
executor
,
2
,
tc
,
(
char
*
)(
c
)
+
i
*
desc
->
info
->
c_matrix
.
stride
*
(
char
*
)(
c
)
+
i
*
desc
->
info
->
c_matrix
.
stride
*
infini
_s
izeof
(
desc
->
dtype
));
infini
S
izeof
(
desc
->
dtype
));
AclSetTensorAddr
(
desc
->
executor
,
3
,
tc
,
AclSetTensorAddr
(
desc
->
executor
,
3
,
tc
,
(
char
*
)(
c
)
+
i
*
desc
->
info
->
c_matrix
.
stride
*
(
char
*
)(
c
)
+
i
*
desc
->
info
->
c_matrix
.
stride
*
infini
_s
izeof
(
desc
->
dtype
));
infini
S
izeof
(
desc
->
dtype
));
ret
=
aclnnGemm
(
workspace
,
workspaceSize
,
desc
->
executor
,
stream
);
ret
=
aclnnGemm
(
workspace
,
workspaceSize
,
desc
->
executor
,
stream
);
CHECK_RET
(
ret
==
ACL_SUCCESS
,
CHECK_RET
(
ret
==
ACL_SUCCESS
,
LOG_PRINT
(
"aclnnGemm failed. ERROR: %d
\n
"
,
ret
);
LOG_PRINT
(
"aclnnGemm failed. ERROR: %d
\n
"
,
ret
);
...
...
src/infiniop/ops/matmul/bang/matmul_cnnl.cc
View file @
5336978c
...
@@ -72,7 +72,7 @@ bangDestroyMatmulDescriptor(infiniopMatmulBangDescriptor_t desc) {
...
@@ -72,7 +72,7 @@ bangDestroyMatmulDescriptor(infiniopMatmulBangDescriptor_t desc) {
return
INFINIOP_STATUS_SUCCESS
;
return
INFINIOP_STATUS_SUCCESS
;
}
}
void
m
atmul
_c
nnl
(
infiniopMatmulBangDescriptor_t
desc
,
void
*
workspace
,
void
*
c
,
void
bangM
atmul
C
nnl
(
infiniopMatmulBangDescriptor_t
desc
,
void
*
workspace
,
void
*
c
,
float
beta
,
void
const
*
a
,
void
const
*
b
,
float
alpha
,
float
beta
,
void
const
*
a
,
void
const
*
b
,
float
alpha
,
void
*
stream
)
{
void
*
stream
)
{
auto
info
=
desc
->
info
;
auto
info
=
desc
->
info
;
...
@@ -92,7 +92,7 @@ infiniopStatus_t bangMatmul(infiniopMatmulBangDescriptor_t desc,
...
@@ -92,7 +92,7 @@ infiniopStatus_t bangMatmul(infiniopMatmulBangDescriptor_t desc,
void
const
*
a
,
void
const
*
b
,
float
alpha
,
void
const
*
a
,
void
const
*
b
,
float
alpha
,
float
beta
,
void
*
stream
)
{
float
beta
,
void
*
stream
)
{
if
(
desc
->
dtype
==
INFINI_DTYPE_F16
||
desc
->
dtype
==
INFINI_DTYPE_F32
)
{
if
(
desc
->
dtype
==
INFINI_DTYPE_F16
||
desc
->
dtype
==
INFINI_DTYPE_F32
)
{
m
atmul
_c
nnl
(
desc
,
workspace
,
c
,
beta
,
a
,
b
,
alpha
,
stream
);
bangM
atmul
C
nnl
(
desc
,
workspace
,
c
,
beta
,
a
,
b
,
alpha
,
stream
);
cnrtQueueSync
((
cnrtQueue_t
)
stream
);
cnrtQueueSync
((
cnrtQueue_t
)
stream
);
return
INFINIOP_STATUS_SUCCESS
;
return
INFINIOP_STATUS_SUCCESS
;
}
}
...
...
src/infiniop/ops/matmul/cpu/matmul_cpu.cc
View file @
5336978c
...
@@ -37,7 +37,7 @@ cpuDestroyMatmulDescriptor(infiniopMatmulCpuDescriptor_t desc) {
...
@@ -37,7 +37,7 @@ cpuDestroyMatmulDescriptor(infiniopMatmulCpuDescriptor_t desc) {
}
}
template
<
typename
Tdata
>
template
<
typename
Tdata
>
infiniopStatus_t
m
atmul
_cpu
(
infiniopMatmulCpuDescriptor_t
desc
,
void
*
c
,
infiniopStatus_t
cpuCalculateM
atmul
(
infiniopMatmulCpuDescriptor_t
desc
,
void
*
c
,
float
beta
,
void
const
*
a
,
void
const
*
b
,
float
beta
,
void
const
*
a
,
void
const
*
b
,
float
alpha
)
{
float
alpha
)
{
auto
info
=
desc
->
info
;
auto
info
=
desc
->
info
;
...
@@ -88,10 +88,10 @@ infiniopStatus_t cpuMatmul(infiniopMatmulCpuDescriptor_t desc, void *workspace,
...
@@ -88,10 +88,10 @@ infiniopStatus_t cpuMatmul(infiniopMatmulCpuDescriptor_t desc, void *workspace,
uint64_t
workspace_size
,
void
*
c
,
void
const
*
a
,
uint64_t
workspace_size
,
void
*
c
,
void
const
*
a
,
void
const
*
b
,
float
alpha
,
float
beta
)
{
void
const
*
b
,
float
alpha
,
float
beta
)
{
if
(
desc
->
dtype
==
INFINI_DTYPE_F16
)
{
if
(
desc
->
dtype
==
INFINI_DTYPE_F16
)
{
return
m
atmul
_cpu
<
uint16_t
>
(
desc
,
c
,
beta
,
a
,
b
,
alpha
);
return
cpuCalculateM
atmul
<
uint16_t
>
(
desc
,
c
,
beta
,
a
,
b
,
alpha
);
}
}
if
(
desc
->
dtype
==
INFINI_DTYPE_F32
)
{
if
(
desc
->
dtype
==
INFINI_DTYPE_F32
)
{
return
m
atmul
_cpu
<
float
>
(
desc
,
c
,
beta
,
a
,
b
,
alpha
);
return
cpuCalculateM
atmul
<
float
>
(
desc
,
c
,
beta
,
a
,
b
,
alpha
);
}
}
return
INFINIOP_STATUS_BAD_TENSOR_DTYPE
;
return
INFINIOP_STATUS_BAD_TENSOR_DTYPE
;
}
}
src/infiniop/ops/matmul/cuda/matmul_cuda_kernel.cu
View file @
5336978c
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "./matmul_cuda.cuh"
#include "./matmul_cuda.cuh"
template
<
typename
Tdata
>
template
<
typename
Tdata
>
infiniopStatus_t
matmul_cuda
(
infiniopMatmulCudaDescriptor_t
desc
,
void
*
c
,
float
beta
,
void
const
*
a
,
void
const
*
b
,
float
alpha
,
void
*
stream
)
{
infiniopStatus_t
cudaMatmulCublas
(
infiniopMatmulCudaDescriptor_t
desc
,
void
*
c
,
float
beta
,
void
const
*
a
,
void
const
*
b
,
float
alpha
,
void
*
stream
)
{
auto
info
=
desc
->
info
;
auto
info
=
desc
->
info
;
if
(
info
.
is_transed
)
{
if
(
info
.
is_transed
)
{
...
@@ -64,10 +64,10 @@ infiniopStatus_t cudaMatmul(infiniopMatmulCudaDescriptor_t desc,
...
@@ -64,10 +64,10 @@ infiniopStatus_t cudaMatmul(infiniopMatmulCudaDescriptor_t desc,
float
beta
,
float
beta
,
void
*
stream
)
{
void
*
stream
)
{
if
(
desc
->
dtype
==
INFINI_DTYPE_F16
)
{
if
(
desc
->
dtype
==
INFINI_DTYPE_F16
)
{
return
matmul_cuda
<
half
>
(
desc
,
c
,
beta
,
a
,
b
,
alpha
,
stream
);
return
cudaMatmulCublas
<
half
>
(
desc
,
c
,
beta
,
a
,
b
,
alpha
,
stream
);
}
}
if
(
desc
->
dtype
==
INFINI_DTYPE_F32
)
{
if
(
desc
->
dtype
==
INFINI_DTYPE_F32
)
{
return
matmul_cuda
<
float
>
(
desc
,
c
,
beta
,
a
,
b
,
alpha
,
stream
);
return
cudaMatmulCublas
<
float
>
(
desc
,
c
,
beta
,
a
,
b
,
alpha
,
stream
);
}
}
return
INFINIOP_STATUS_BAD_TENSOR_DTYPE
;
return
INFINIOP_STATUS_BAD_TENSOR_DTYPE
;
}
}
src/infiniop/ops/utils.h
View file @
5336978c
...
@@ -37,15 +37,22 @@
...
@@ -37,15 +37,22 @@
} \
} \
} while (0)
} while (0)
inline
std
::
vector
<
int64_t
>
get
_b
yte
_s
trides
(
infiniopTensorDescriptor_t
desc
)
{
inline
std
::
vector
<
int64_t
>
get
B
yte
S
trides
(
infiniopTensorDescriptor_t
desc
)
{
std
::
vector
<
int64_t
>
strides
(
desc
->
ndim
);
std
::
vector
<
int64_t
>
strides
(
desc
->
ndim
);
for
(
uint64_t
i
=
0
;
i
<
desc
->
ndim
;
i
++
)
{
for
(
uint64_t
i
=
0
;
i
<
desc
->
ndim
;
i
++
)
{
strides
[
i
]
=
desc
->
strides
[
i
]
*
infini
_s
izeof
(
desc
->
dtype
);
strides
[
i
]
=
desc
->
strides
[
i
]
*
infini
S
izeof
(
desc
->
dtype
);
}
}
return
strides
;
return
strides
;
}
}
inline
size_t
getByteSize
(
infiniopTensorDescriptor_t
desc
)
{
size_t
size
=
1
;
for
(
size_t
i
=
0
;
i
<
desc
->
ndim
;
i
++
)
{
size
*=
desc
->
shape
[
i
];
}
return
size
*
infiniSizeof
(
desc
->
dtype
);
}
// calculate the broadcasted shape for two tensors
// calculate the broadcasted shape for two tensors
inline
bool
getBroadcastShape
(
const
uint64_t
*
shape1
,
uint64_t
ndim1
,
inline
bool
getBroadcastShape
(
const
uint64_t
*
shape1
,
uint64_t
ndim1
,
const
uint64_t
*
shape2
,
uint64_t
ndim2
,
const
uint64_t
*
shape2
,
uint64_t
ndim2
,
...
@@ -119,13 +126,6 @@ inline bool isValidBroadcastShape(infiniopTensorDescriptor_t a,
...
@@ -119,13 +126,6 @@ inline bool isValidBroadcastShape(infiniopTensorDescriptor_t a,
return
isValidBroadcastShape
(
a
,
b
,
c
,
std
::
max
(
a
->
ndim
,
b
->
ndim
));
return
isValidBroadcastShape
(
a
,
b
,
c
,
std
::
max
(
a
->
ndim
,
b
->
ndim
));
}
}
inline
size_t
get_byte_size
(
infiniopTensorDescriptor_t
desc
)
{
size_t
size
=
1
;
for
(
size_t
i
=
0
;
i
<
desc
->
ndim
;
i
++
)
{
size
*=
desc
->
shape
[
i
];
}
return
size
*
infini_sizeof
(
desc
->
dtype
);
}
// permute the dimensions of a tensor descriptor
// permute the dimensions of a tensor descriptor
inline
infiniopTensorDescriptor_t
permute
(
infiniopTensorDescriptor_t
desc
,
inline
infiniopTensorDescriptor_t
permute
(
infiniopTensorDescriptor_t
desc
,
...
@@ -148,7 +148,7 @@ inline infiniopTensorDescriptor_t permute(infiniopTensorDescriptor_t desc,
...
@@ -148,7 +148,7 @@ inline infiniopTensorDescriptor_t permute(infiniopTensorDescriptor_t desc,
// check if the dimensions [dim_start, dim_end] of a tensor descriptor are
// check if the dimensions [dim_start, dim_end] of a tensor descriptor are
// contiguous
// contiguous
inline
bool
is
_c
ontiguous
(
const
infiniopTensorDescriptor_t
&
desc
,
inline
bool
is
C
ontiguous
(
const
infiniopTensorDescriptor_t
&
desc
,
size_t
dim_start
,
size_t
dim_end
)
{
size_t
dim_start
,
size_t
dim_end
)
{
for
(
size_t
i
=
dim_start
+
1
;
i
<=
dim_end
;
i
++
)
{
for
(
size_t
i
=
dim_start
+
1
;
i
<=
dim_end
;
i
++
)
{
if
(
desc
->
strides
[
i
-
1
]
!=
if
(
desc
->
strides
[
i
-
1
]
!=
...
@@ -159,15 +159,15 @@ inline bool is_contiguous(const infiniopTensorDescriptor_t &desc,
...
@@ -159,15 +159,15 @@ inline bool is_contiguous(const infiniopTensorDescriptor_t &desc,
return
true
;
return
true
;
}
}
inline
bool
is
_c
ontiguous
(
const
infiniopTensorDescriptor_t
&
desc
)
{
inline
bool
is
C
ontiguous
(
const
infiniopTensorDescriptor_t
&
desc
)
{
if
(
desc
->
ndim
==
0
)
{
if
(
desc
->
ndim
==
0
)
{
return
true
;
return
true
;
}
}
return
is
_c
ontiguous
(
desc
,
0
,
desc
->
ndim
-
1
);
return
is
C
ontiguous
(
desc
,
0
,
desc
->
ndim
-
1
);
}
}
// merge the dimensions [dim_start, dim_end] of a tensor descriptor
// merge the dimensions [dim_start, dim_end] of a tensor descriptor
inline
infiniopTensorDescriptor_t
dim
_m
erge
(
infiniopTensorDescriptor_t
desc
,
inline
infiniopTensorDescriptor_t
dim
M
erge
(
infiniopTensorDescriptor_t
desc
,
size_t
dim_start
,
size_t
dim_end
)
{
size_t
dim_start
,
size_t
dim_end
)
{
size_t
ndim
=
desc
->
ndim
;
size_t
ndim
=
desc
->
ndim
;
if
(
dim_start
>
dim_end
||
dim_end
>=
ndim
)
{
if
(
dim_start
>
dim_end
||
dim_end
>=
ndim
)
{
...
@@ -183,7 +183,7 @@ inline infiniopTensorDescriptor_t dim_merge(infiniopTensorDescriptor_t desc,
...
@@ -183,7 +183,7 @@ inline infiniopTensorDescriptor_t dim_merge(infiniopTensorDescriptor_t desc,
new_strides
[
index
]
=
desc
->
strides
[
i
];
new_strides
[
index
]
=
desc
->
strides
[
i
];
index
++
;
index
++
;
}
}
if
(
!
is
_c
ontiguous
(
desc
,
dim_start
,
dim_end
))
{
if
(
!
is
C
ontiguous
(
desc
,
dim_start
,
dim_end
))
{
return
nullptr
;
return
nullptr
;
}
}
new_shape
[
index
]
=
1
;
new_shape
[
index
]
=
1
;
...
@@ -202,7 +202,7 @@ inline infiniopTensorDescriptor_t dim_merge(infiniopTensorDescriptor_t desc,
...
@@ -202,7 +202,7 @@ inline infiniopTensorDescriptor_t dim_merge(infiniopTensorDescriptor_t desc,
}
}
// split the dimension dim of a tensor descriptor into multiple dimensions
// split the dimension dim of a tensor descriptor into multiple dimensions
inline
infiniopTensorDescriptor_t
dim
_s
plit
(
infiniopTensorDescriptor_t
desc
,
inline
infiniopTensorDescriptor_t
dim
S
plit
(
infiniopTensorDescriptor_t
desc
,
size_t
dim
,
size_t
dim
,
const
std
::
vector
<
size_t
>
&
dims
)
{
const
std
::
vector
<
size_t
>
&
dims
)
{
size_t
ndim
=
desc
->
ndim
;
size_t
ndim
=
desc
->
ndim
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment