Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
6e1491bd
Commit
6e1491bd
authored
Apr 02, 2025
by
zhangyunze
Committed by
zhangyue
May 14, 2025
Browse files
feat: 添加昇腾swiglu算子
parent
bd37042c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
309 additions
and
18 deletions
+309
-18
src/infiniop/devices/ascend/CMakeLists.txt
src/infiniop/devices/ascend/CMakeLists.txt
+7
-3
src/infiniop/ops/swiglu/ascend/swiglu_aclnn.cc
src/infiniop/ops/swiglu/ascend/swiglu_aclnn.cc
+52
-0
src/infiniop/ops/swiglu/ascend/swiglu_aclnn.h
src/infiniop/ops/swiglu/ascend/swiglu_aclnn.h
+74
-0
src/infiniop/ops/swiglu/ascend/swiglu_kernel.cpp
src/infiniop/ops/swiglu/ascend/swiglu_kernel.cpp
+164
-0
src/infiniop/ops/swiglu/operator.cc
src/infiniop/ops/swiglu/operator.cc
+10
-12
xmake/ascend.lua
xmake/ascend.lua
+2
-3
No files found.
src/infiniop/devices/ascend/CMakeLists.txt
View file @
6e1491bd
...
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.16.0)
...
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.16.0)
# project information
# project information
project
(
Ascend_C
)
project
(
Ascend_C
)
set
(
SOC_VERSION
"Ascend910B3"
CACHE STRING
"system on chip type"
)
set
(
SOC_VERSION
"Ascend910B3"
CACHE STRING
"system on chip type"
)
set
(
ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME} CACHE PATH
"ASCEND CANN package installation directory"
)
set
(
ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_
TOOLKIT_
HOME} CACHE PATH
"ASCEND CANN package installation directory"
)
set
(
RUN_MODE
"npu"
CACHE STRING
"run mode: npu"
)
set
(
RUN_MODE
"npu"
CACHE STRING
"run mode: npu"
)
set
(
CMAKE_BUILD_TYPE
"Release"
CACHE STRING
"Build type Release/Debug (default Debug)"
FORCE
)
set
(
CMAKE_BUILD_TYPE
"Release"
CACHE STRING
"Build type Release/Debug (default Debug)"
FORCE
)
set
(
CMAKE_INSTALL_PREFIX
"
${
CMAKE_CURRENT_LIST_DIR
}
/out"
CACHE STRING
"path for install()"
FORCE
)
set
(
CMAKE_INSTALL_PREFIX
"
${
CMAKE_CURRENT_LIST_DIR
}
/out"
CACHE STRING
"path for install()"
FORCE
)
...
@@ -19,10 +19,14 @@ else()
...
@@ -19,10 +19,14 @@ else()
endif
()
endif
()
include
(
${
ASCENDC_CMAKE_DIR
}
/ascendc.cmake
)
include
(
${
ASCENDC_CMAKE_DIR
}
/ascendc.cmake
)
include_directories
(
${
CMAKE_SOURCE_DIR
}
/../../../../include/infiniop/
)
ascendc_library
(
ascend_kernels STATIC
ascendc_library
(
ascend_kernels STATIC
../../ops/swiglu/ascend/swiglu_kernel.cpp
../../ops/swiglu/ascend/swiglu_kernel.cpp
../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
#
../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
../../ops/random_sample/ascend/random_sample_kernel.cpp
#
../../ops/random_sample/ascend/random_sample_kernel.cpp
)
)
src/infiniop/ops/swiglu/ascend/swiglu_aclnn.cc
0 → 100644
View file @
6e1491bd
#include "swiglu_aclnn.h"
#include "../../../devices/ascend/common_ascend.h"

namespace op::swiglu::ascend {

Descriptor::~Descriptor() = default;

/// Create a SwiGLU descriptor for the Ascend backend.
///
/// @param handle       InfiniOp handle; must wrap a device::ascend::Handle.
/// @param desc_ptr     Receives the newly allocated descriptor on success.
/// @param c_desc       Output tensor descriptor (F16 or F32 only).
/// @param input_descs  Input descriptors; entries 0 and 1 are {a, b}.
/// @return INFINI_STATUS_SUCCESS, or an error from the dtype/shape/stride checks.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t c_desc,
    std::vector<infiniopTensorDescriptor_t> input_descs) {
    auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);

    auto dtype = c_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);

    // NOTE(review): assumes input_descs holds at least two entries — TODO
    // confirm the dispatcher guarantees this before indexing.
    const auto &a_desc = input_descs[0];
    const auto &b_desc = input_descs[1];

    auto result = SwigluInfo::create(c_desc, a_desc, b_desc);
    CHECK_RESULT(result);
    SwigluInfo info = result.take();

    // The custom AscendC kernel needs no scratch memory, so the workspace is empty.
    // https://www.hiascend.com/document/detail/zh/canncommercial/800/apiref/ascendcopapi/atlasascendc_api_07_0777.html
    size_t workspace_size = 0;

    *desc_ptr = new Descriptor(
        std::move(info),
        workspace_size,
        handle_ascend->device,
        handle_ascend->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Launcher defined in swiglu_kernel.cpp (compiled by the AscendC toolchain),
// hence the C linkage.
extern "C" infiniStatus_t swiglu_kernel_launch(
    void *c, void *a, void *b,
    int dtype,
    int batch, int seq, int hd,
    int stride_batch_c, int stride_batch_a, int stride_batch_b,
    int stride_seq_c, int stride_seq_a, int stride_seq_b,
    void *stream);

/// Run SwiGLU: c = swiglu(a, b) asynchronously on `stream`.
///
/// Supports 2-D (seq, hidden) and 3-D (batch, seq, hidden) layouts; for the
/// 2-D case batch is fixed to 1 and the batch strides are unused (set to 1).
/// `workspace`/`workspace_size` are accepted for interface uniformity but the
/// kernel requires no workspace.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *c,
    std::vector<const void *> inputs,
    void *stream) const {
    int batch = _info.ndim == 2 ? 1 : _info.shape[0];
    int seq_len = _info.ndim == 2 ? _info.shape[0] : _info.shape[1];
    int hidden_size = _info.shape[_info.ndim - 1];
    int stride_batch_c = _info.ndim == 2 ? 1 : _info.c_strides[0];
    int stride_batch_a = _info.ndim == 2 ? 1 : _info.a_strides[0];
    int stride_batch_b = _info.ndim == 2 ? 1 : _info.b_strides[0];
    int stride_seq_c = _info.ndim == 2 ? _info.c_strides[0] : _info.c_strides[1];
    int stride_seq_a = _info.ndim == 2 ? _info.a_strides[0] : _info.a_strides[1];
    int stride_seq_b = _info.ndim == 2 ? _info.b_strides[0] : _info.b_strides[1];

    // const_cast (not a C-style cast): the C-linkage launcher takes plain
    // void*, but the inputs are only read by the kernel.
    auto status = swiglu_kernel_launch(
        c,
        const_cast<void *>(inputs[0]),
        const_cast<void *>(inputs[1]),
        _info.dtype,
        batch, seq_len, hidden_size,
        stride_batch_c, stride_batch_a, stride_batch_b,
        stride_seq_c, stride_seq_a, stride_seq_b,
        stream);
    return status;
}

} // namespace op::swiglu::ascend
src/infiniop/ops/swiglu/ascend/swiglu_aclnn.h
0 → 100644
View file @
6e1491bd
#ifndef __ACLNN_SWIGLU_H__
#define __ACLNN_SWIGLU_H__

#include "../../../../utils.h"
#include "../../../../utils/check.h"
#include "../../../operator.h"
#include "../../../tensor.h"

namespace op::swiglu::ascend {

/// Validated metadata (dtype, shape, strides) for one SwiGLU invocation.
class SwigluInfo {
    SwigluInfo() = default;

public:
    infiniDtype_t dtype;
    std::vector<size_t> shape;
    int32_t ndim;
    std::vector<ptrdiff_t> c_strides;
    std::vector<ptrdiff_t> a_strides;
    std::vector<ptrdiff_t> b_strides;

    /// Validate the three descriptors and build a SwigluInfo.
    ///
    /// Requirements enforced here:
    ///   - all descriptors are non-null and c has no broadcast dimension;
    ///   - a, b, c share the same ndim, which must be 2 or 3;
    ///   - a, b, c share the same shape and the same dtype;
    ///   - the innermost dimension of each tensor is contiguous (stride 1).
    static utils::Result<SwigluInfo> create(
        infiniopTensorDescriptor_t c_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc) {
        if (!c_desc || !a_desc || !b_desc) {
            return INFINI_STATUS_BAD_PARAM;
        }
        if (c_desc->hasBroadcastDim()) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }
        if (c_desc->ndim() != a_desc->ndim()
            || c_desc->ndim() != b_desc->ndim()
            || (c_desc->ndim() != 2 && c_desc->ndim() != 3)) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        CHECK_SAME_SHAPE(c_desc->shape(), a_desc->shape(), b_desc->shape());
        int32_t ndim = c_desc->ndim();
        // The kernel copies contiguous rows; reject a non-unit innermost stride.
        if (c_desc->stride(ndim - 1) != 1
            || a_desc->stride(ndim - 1) != 1
            || b_desc->stride(ndim - 1) != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }
        if (c_desc->dtype() != a_desc->dtype()
            || c_desc->dtype() != b_desc->dtype()) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        return utils::Result<SwigluInfo>(SwigluInfo{
            c_desc->dtype(),
            std::move(c_desc->shape()),
            ndim,
            std::move(c_desc->strides()),
            std::move(a_desc->strides()),
            std::move(b_desc->strides()),
        });
    }
};

/// Ascend descriptor for the SwiGLU operator. Owns the validated tensor
/// metadata and the (currently zero) workspace requirement.
class Descriptor final : public InfiniopDescriptor {
    SwigluInfo _info;
    size_t _workspace_size;

    Descriptor(
        SwigluInfo info,
        size_t workspace_size,
        infiniDevice_t device_type,
        int device_id)
        : InfiniopDescriptor{device_type, device_id},
          // Sink parameter: move rather than copy the shape/stride vectors.
          _info(std::move(info)),
          _workspace_size(workspace_size) {}

public:
    ~Descriptor();

    /// Validate inputs and allocate a new descriptor (see swiglu_aclnn.cc).
    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc_ptr,
        infiniopTensorDescriptor_t c_desc,
        std::vector<infiniopTensorDescriptor_t> input_descs);

    /// Workspace bytes required by calculate(); always 0 for this backend.
    size_t workspaceSize() const { return _workspace_size; }

    /// Launch the SwiGLU kernel: c = swiglu(inputs[0], inputs[1]).
    infiniStatus_t calculate(
        void *workspace,
        size_t workspace_size,
        void *c,
        std::vector<const void *> inputs,
        void *stream) const;
};

} // namespace op::swiglu::ascend

#endif // __ACLNN_SWIGLU_H__
src/infiniop/ops/swiglu/ascend/swiglu_kernel.cpp
0 → 100644
View file @
6e1491bd
#include "../../../../../include/infinicore.h"
#include "kernel_operator.h"

// Number of AI vector cores the hidden dimension is split across.
constexpr int32_t BLOCK_NUM = 8;
// Double buffering for the input/output queues.
constexpr int32_t BUFFER_NUM = 2;
// Unified-buffer DataCopy lengths must be 32-byte aligned.
constexpr int32_t BYTE_ALIGN = 32;
// ubsize = 196KB

using namespace AscendC;

/// AscendC SwiGLU kernel: c = SwiGLU(a, b) with beta = 1.0.
///
/// Each core owns a 1/BLOCK_NUM slice of the hidden dimension and iterates
/// over every (batch, seq) row sequentially in Process(), pipelining
/// CopyIn -> Compute -> CopyOut through double-buffered queues.
template <typename T>
class SwigluKernel {
public:
    __aicore__ inline SwigluKernel() {}
    __aicore__ inline void Init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
                                int batch_, int seq, int hd,
                                int stride_batch_c, int stride_batch_a, int stride_batch_b,
                                int stride_seq_c, int stride_seq_a, int stride_seq_b);
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyIn(int32_t i);
    __aicore__ inline void Compute(int32_t i);
    __aicore__ inline void CopyOut(int32_t i);

private:
    GlobalTensor<T> cGm, aGm, bGm;
    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueA, inQueueB;
    TQue<QuePosition::VECOUT, BUFFER_NUM> outQueueC;
    TPipe pipe;
    uint32_t _data_size = 0; // NOTE(review): currently unused — confirm before removing
    float _beta_value = 1.0f; // beta passed to the SwiGLU API
    uint32_t _block_idx, _tile_len, _copy_len;
    uint32_t batch, seq_len, hidden_size;
    int32_t strideBatchA = 1, strideBatchB = 1, strideBatchC = 1;
    int32_t strideSeqA, strideSeqB, strideSeqC;
};

template <typename T>
__aicore__ inline void SwigluKernel<T>::Init(GM_ADDR c, GM_ADDR a, GM_ADDR b,
                                             int batch_, int seq, int hd,
                                             int stride_batch_c, int stride_batch_a, int stride_batch_b,
                                             int stride_seq_c, int stride_seq_a, int stride_seq_b) {
    // Init Shape & StrideVariables
    batch = batch_;
    seq_len = seq;
    hidden_size = hd;
    strideBatchA = stride_batch_a;
    strideBatchB = stride_batch_b;
    strideBatchC = stride_batch_c;
    strideSeqA = stride_seq_a;
    strideSeqB = stride_seq_b;
    strideSeqC = stride_seq_c;

    _block_idx = GetBlockIdx();
    // Split hidden_size across BLOCK_NUM cores; the first (hidden_size % BLOCK_NUM)
    // cores take one extra element.
    _tile_len = _block_idx < (hidden_size % BLOCK_NUM)
                    ? (hidden_size / BLOCK_NUM) + 1
                    : (hidden_size / BLOCK_NUM);
    // Round the tile up to a BYTE_ALIGN-byte multiple (in elements) for DataCopy.
    _copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0
                    ? _tile_len
                    : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T);

    // Set global tensor
    aGm.SetGlobalBuffer((__gm__ T *)a);
    bGm.SetGlobalBuffer((__gm__ T *)b);
    cGm.SetGlobalBuffer((__gm__ T *)c);

    // Pipe alloc memory to queue, the unit is bytes
    pipe.InitBuffer(inQueueA, BUFFER_NUM, _copy_len * sizeof(T));
    pipe.InitBuffer(inQueueB, BUFFER_NUM, _copy_len * sizeof(T));
    pipe.InitBuffer(outQueueC, BUFFER_NUM, _copy_len * sizeof(T));
}

template <typename T>
__aicore__ inline void SwigluKernel<T>::CopyIn(int32_t i) {
    // Alloc tensor from queue memory
    LocalTensor<T> aLocal = inQueueA.AllocTensor<T>();
    LocalTensor<T> bLocal = inQueueB.AllocTensor<T>();
    // Get idx of current tile
    auto batchIdx = batch == 1 ? 0 : i / seq_len;
    auto seqIdx = batch == 1 ? i : i % seq_len;
    int32_t idxa = batchIdx * strideBatchA + seqIdx * strideSeqA + _block_idx * _tile_len;
    int32_t idxb = batchIdx * strideBatchB + seqIdx * strideSeqB + _block_idx * _tile_len;
    // Copy process_th tile from global tensor to local tensor
    DataCopy(aLocal, aGm[idxa], _copy_len);
    DataCopy(bLocal, bGm[idxb], _copy_len);
    // Enque input tensor to VECIN queue
    inQueueA.EnQue(aLocal);
    inQueueB.EnQue(bLocal);
}

template <typename T>
__aicore__ inline void SwigluKernel<T>::Compute(int32_t i) {
    // Deque input tensors from VECIN queue
    LocalTensor<T> aLocal = inQueueA.DeQue<T>();
    LocalTensor<T> bLocal = inQueueB.DeQue<T>();
    LocalTensor<T> cLocal = outQueueC.AllocTensor<T>();
    // Call SwiGLU ascend api
    SwiGLU<T, false>(cLocal, aLocal, bLocal, _beta_value, _copy_len);
    // Enque result and free input
    outQueueC.EnQue<T>(cLocal);
    inQueueA.FreeTensor(aLocal);
    inQueueB.FreeTensor(bLocal);
}

template <typename T>
__aicore__ inline void SwigluKernel<T>::CopyOut(int32_t i) {
    // Deque output tensor from VECOUT queue
    LocalTensor<T> cLocal = outQueueC.DeQue<T>();
    auto batchIdx = batch == 1 ? 0 : i / seq_len;
    auto seqIdx = batch == 1 ? i : i % seq_len;
    int32_t idxc = batchIdx * strideBatchC + seqIdx * strideSeqC + _block_idx * _tile_len;
    // Copy progress_th tile from local tensor to global tensor; an unaligned
    // tail must go through DataCopyPad so we never write past _tile_len.
    if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) {
        DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
        DataCopyPad(cGm[idxc], cLocal, dcep);
    } else {
        DataCopy(cGm[idxc], cLocal, _tile_len);
    }
    // Free output Local tensor
    outQueueC.FreeTensor(cLocal);
}

template <typename T>
__aicore__ inline void SwigluKernel<T>::Process() {
    // One iteration per (batch, seq) row.
    for (int32_t i = 0; i < batch * seq_len; ++i) {
        CopyIn(i);
        Compute(i);
        CopyOut(i);
    }
}

__global__ __aicore__ void swiglu_kernel_half(GM_ADDR c, GM_ADDR a, GM_ADDR b,
                                              int batch, int seq, int hd,
                                              int stride_batch_c, int stride_batch_a, int stride_batch_b,
                                              int stride_seq_c, int stride_seq_a, int stride_seq_b) {
    SwigluKernel<half> op;
    op.Init(c, a, b, batch, seq, hd,
            stride_batch_c, stride_batch_a, stride_batch_b,
            stride_seq_c, stride_seq_a, stride_seq_b);
    op.Process();
}

__global__ __aicore__ void swiglu_kernel_float(GM_ADDR c, GM_ADDR a, GM_ADDR b,
                                               int batch, int seq, int hd,
                                               int stride_batch_c, int stride_batch_a, int stride_batch_b,
                                               int stride_seq_c, int stride_seq_a, int stride_seq_b) {
    SwigluKernel<float> op;
    op.Init(c, a, b, batch, seq, hd,
            stride_batch_c, stride_batch_a, stride_batch_b,
            stride_seq_c, stride_seq_a, stride_seq_b);
    op.Process();
}

/// Host-side launcher with C linkage, called from swiglu_aclnn.cc.
/// Dispatches on the InfiniCore dtype and returns BAD_TENSOR_DTYPE for
/// anything other than F16/F32.
extern "C" infiniStatus_t swiglu_kernel_launch(
    void *c, void *a, void *b,
    int dtype,
    int batch, int seq, int hd,
    int stride_batch_c, int stride_batch_a, int stride_batch_b,
    int stride_seq_c, int stride_seq_a, int stride_seq_b,
    void *stream) {
    switch (dtype) {
    case INFINI_DTYPE_F16: // named constant from infinicore.h (was magic 12)
        swiglu_kernel_half<<<BLOCK_NUM, nullptr, stream>>>(
            c, a, b, batch, seq, hd,
            stride_batch_c, stride_batch_a, stride_batch_b,
            stride_seq_c, stride_seq_a, stride_seq_b);
        return INFINI_STATUS_SUCCESS;
    case INFINI_DTYPE_F32: // named constant from infinicore.h (was magic 13)
        swiglu_kernel_float<<<BLOCK_NUM, nullptr, stream>>>(
            c, a, b, batch, seq, hd,
            stride_batch_c, stride_batch_a, stride_batch_b,
            stride_seq_c, stride_seq_a, stride_seq_b);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
src/infiniop/ops/swiglu/operator.cc
View file @
6e1491bd
...
@@ -11,6 +11,9 @@
...
@@ -11,6 +11,9 @@
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
#include "kunlun/swiglu_kunlun.h"
#include "kunlun/swiglu_kunlun.h"
#endif
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/swiglu_aclnn.h"
#endif
__C
infiniStatus_t
infiniopCreateSwiGLUDescriptor
(
__C
infiniStatus_t
infiniopCreateSwiGLUDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -46,11 +49,8 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
...
@@ -46,11 +49,8 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
c_desc
,
a_desc
,
b_desc
);
c_desc
,
a_desc
,
b_desc
);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
ascendCreateSwiGLUDescriptor
(
(
AscendHandle_t
)
handle
,
(
SwiGLUAscendDescriptor_t
*
)
desc_ptr
,
c_desc
,
a_desc
,
b_desc
);
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
@@ -95,7 +95,7 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
...
@@ -95,7 +95,7 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_API
#ifdef ENABLE_ASCEND_API
//
GET(INFINI_DEVICE_ASCEND, ascend)
GET
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
@@ -144,9 +144,8 @@ __C infiniStatus_t infiniopSwiGLU(
...
@@ -144,9 +144,8 @@ __C infiniStatus_t infiniopSwiGLU(
return
bangSwiGLU
((
SwiGLUBangDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
return
bangSwiGLU
((
SwiGLUBangDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
ascendSwiGLU
((
SwiGLUAscendDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
case
DevMetaxGpu
:
...
@@ -188,9 +187,8 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
...
@@ -188,9 +187,8 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
return
bangDestroySwiGLUDescriptor
((
SwiGLUBangDescriptor_t
)
desc
);
return
bangDestroySwiGLUDescriptor
((
SwiGLUBangDescriptor_t
)
desc
);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
)
return
ascendDestroySwiGLUDescriptor
((
SwiGLUAscendDescriptor_t
)
desc
);
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
case
DevMetaxGpu
:
...
...
xmake/ascend.lua
View file @
6e1491bd
...
@@ -50,9 +50,8 @@ target("infiniop-ascend")
...
@@ -50,9 +50,8 @@ target("infiniop-ascend")
add_files
(
"$(projectdir)/src/infiniop/devices/ascend/*.cc"
,
"$(projectdir)/src/infiniop/ops/*/ascend/*.cc"
)
add_files
(
"$(projectdir)/src/infiniop/devices/ascend/*.cc"
,
"$(projectdir)/src/infiniop/ops/*/ascend/*.cc"
)
-- Add operator
-- Add operator
-- TODO: add it back after ascend-kernels is fixed
add_rules
(
"ascend-kernels"
)
-- add_rules("ascend-kernels")
add_links
(
builddir
..
"/libascend_kernels.a"
)
-- add_links(builddir.."/libascend_kernels.a")
target_end
()
target_end
()
target
(
"infinirt-ascend"
)
target
(
"infinirt-ascend"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment