Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
fafb22db
Commit
fafb22db
authored
May 15, 2025
by
zhangyue
Browse files
issue/9: 根据review 修改
parent
120b4348
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
83 additions
and
84 deletions
+83
-84
src/infiniop/devices/ascend/CMakeLists.txt
src/infiniop/devices/ascend/CMakeLists.txt
+1
-1
src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc
...finiop/ops/causal_softmax/ascend/causal_softmax_ascend.cc
+1
-1
src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.h
...nfiniop/ops/causal_softmax/ascend/causal_softmax_ascend.h
+0
-0
src/infiniop/ops/causal_softmax/operator.cc
src/infiniop/ops/causal_softmax/operator.cc
+1
-1
src/infiniop/ops/swiglu/ascend/swiglu_ascend.cc
src/infiniop/ops/swiglu/ascend/swiglu_ascend.cc
+1
-1
src/infiniop/ops/swiglu/ascend/swiglu_ascend.h
src/infiniop/ops/swiglu/ascend/swiglu_ascend.h
+14
-15
src/infiniop/ops/swiglu/ascend/swiglu_ascend_kernel.cpp
src/infiniop/ops/swiglu/ascend/swiglu_ascend_kernel.cpp
+64
-64
src/infiniop/ops/swiglu/operator.cc
src/infiniop/ops/swiglu/operator.cc
+1
-1
No files found.
src/infiniop/devices/ascend/CMakeLists.txt
View file @
fafb22db
...
@@ -25,7 +25,7 @@ include_directories(
...
@@ -25,7 +25,7 @@ include_directories(
ascendc_library
(
ascend_kernels STATIC
ascendc_library
(
ascend_kernels STATIC
../../ops/swiglu/ascend/swiglu_kernel.cpp
../../ops/swiglu/ascend/swiglu_
ascend_
kernel.cpp
# ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
# ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
# ../../ops/random_sample/ascend/random_sample_kernel.cpp
# ../../ops/random_sample/ascend/random_sample_kernel.cpp
)
)
...
...
src/infiniop/ops/causal_softmax/ascend/causal_softmax_a
clnn
.cc
→
src/infiniop/ops/causal_softmax/ascend/causal_softmax_a
scend
.cc
View file @
fafb22db
#include "causal_softmax_a
clnn
.h"
#include "causal_softmax_a
scend
.h"
#include "../../../devices/ascend/common_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_masked_fill_tensor.h>
#include <aclnnop/aclnn_masked_fill_tensor.h>
#include <aclnnop/aclnn_softmax.h>
#include <aclnnop/aclnn_softmax.h>
...
...
src/infiniop/ops/causal_softmax/ascend/causal_softmax_a
clnn
.h
→
src/infiniop/ops/causal_softmax/ascend/causal_softmax_a
scend
.h
View file @
fafb22db
File moved
src/infiniop/ops/causal_softmax/operator.cc
View file @
fafb22db
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include "cuda/causal_softmax_cuda.cuh"
#include "cuda/causal_softmax_cuda.cuh"
#endif
#endif
#ifdef ENABLE_ASCEND_API
#ifdef ENABLE_ASCEND_API
#include "ascend/causal_softmax_a
clnn
.h"
#include "ascend/causal_softmax_a
scend
.h"
#endif
#endif
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
...
...
src/infiniop/ops/swiglu/ascend/swiglu_a
clnn
.cc
→
src/infiniop/ops/swiglu/ascend/swiglu_a
scend
.cc
View file @
fafb22db
#include "swiglu_a
clnn
.h"
#include "swiglu_a
scend
.h"
#include "../../../devices/ascend/common_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
namespace
op
::
swiglu
::
ascend
{
namespace
op
::
swiglu
::
ascend
{
...
...
src/infiniop/ops/swiglu/ascend/swiglu_a
clnn
.h
→
src/infiniop/ops/swiglu/ascend/swiglu_a
scend
.h
View file @
fafb22db
...
@@ -20,23 +20,22 @@ public:
...
@@ -20,23 +20,22 @@ public:
std
::
vector
<
ptrdiff_t
>
b_strides
;
std
::
vector
<
ptrdiff_t
>
b_strides
;
static
utils
::
Result
<
SwigluInfo
>
create
(
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
{
static
utils
::
Result
<
SwigluInfo
>
create
(
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
{
if
(
!
c_desc
||
!
a_desc
||
!
b_desc
)
{
CHECK_OR_RETURN
(
c_desc
&&
a_desc
&&
b_desc
,
INFINI_STATUS_BAD_PARAM
);
return
INFINI_STATUS_BAD_PARAM
;
CHECK_OR_RETURN
(
!
c_desc
->
hasBroadcastDim
(),
INFINI_STATUS_BAD_TENSOR_STRIDES
);
}
CHECK_OR_RETURN
(
c_desc
->
ndim
()
==
a_desc
->
ndim
()
if
(
c_desc
->
hasBroadcastDim
())
{
&&
c_desc
->
ndim
()
==
b_desc
->
ndim
()
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
&&
(
c_desc
->
ndim
()
==
2
||
c_desc
->
ndim
()
==
3
),
}
INFINI_STATUS_BAD_TENSOR_SHAPE
);
if
(
c_desc
->
ndim
()
!=
a_desc
->
ndim
()
||
c_desc
->
ndim
()
!=
b_desc
->
ndim
()
||
(
c_desc
->
ndim
()
!=
2
&&
c_desc
->
ndim
()
!=
3
))
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
CHECK_SAME_SHAPE
(
c_desc
->
shape
(),
a_desc
->
shape
(),
b_desc
->
shape
());
CHECK_SAME_SHAPE
(
c_desc
->
shape
(),
a_desc
->
shape
(),
b_desc
->
shape
());
int32_t
ndim
=
c_desc
->
ndim
();
int32_t
ndim
=
c_desc
->
ndim
();
if
(
c_desc
->
stride
(
ndim
-
1
)
!=
1
||
a_desc
->
stride
(
ndim
-
1
)
!=
1
||
b_desc
->
stride
(
ndim
-
1
)
!=
1
)
{
CHECK_OR_RETURN
(
c_desc
->
stride
(
ndim
-
1
)
==
1
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
&&
a_desc
->
stride
(
ndim
-
1
)
==
1
}
&&
b_desc
->
stride
(
ndim
-
1
)
==
1
,
if
(
c_desc
->
dtype
()
!=
a_desc
->
dtype
()
||
c_desc
->
dtype
()
!=
b_desc
->
dtype
())
{
INFINI_STATUS_BAD_TENSOR_STRIDES
);
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
CHECK_OR_RETURN
(
c_desc
->
dtype
()
==
a_desc
->
dtype
()
}
&&
c_desc
->
dtype
()
==
b_desc
->
dtype
(),
INFINI_STATUS_BAD_TENSOR_DTYPE
);
return
utils
::
Result
<
SwigluInfo
>
(
SwigluInfo
{
return
utils
::
Result
<
SwigluInfo
>
(
SwigluInfo
{
c_desc
->
dtype
(),
c_desc
->
dtype
(),
c_desc
->
shape
(),
c_desc
->
shape
(),
...
...
src/infiniop/ops/swiglu/ascend/swiglu_kernel.cpp
→
src/infiniop/ops/swiglu/ascend/swiglu_
ascend_
kernel.cpp
View file @
fafb22db
...
@@ -6,117 +6,117 @@ template <typename T>
...
@@ -6,117 +6,117 @@ template <typename T>
class
SwigluKernel
{
class
SwigluKernel
{
public:
public:
__aicore__
inline
SwigluKernel
()
{}
__aicore__
inline
SwigluKernel
()
{}
__aicore__
inline
void
I
nit
(
GM_ADDR
c
,
GM_ADDR
a
,
GM_ADDR
b
,
int64_t
batch_
,
int64_t
seq
,
int64_t
hd
,
__aicore__
inline
void
i
nit
(
GM_ADDR
c
,
GM_ADDR
a
,
GM_ADDR
b
,
int64_t
batch_
,
int64_t
seq
,
int64_t
hd
,
int64_t
stride_batch_c
,
int64_t
stride_batch_a
,
int64_t
stride_batch_b
,
int64_t
stride_batch_c
,
int64_t
stride_batch_a
,
int64_t
stride_batch_b
,
int64_t
stride_seq_c
,
int64_t
stride_seq_a
,
int64_t
stride_seq_b
);
int64_t
stride_seq_c
,
int64_t
stride_seq_a
,
int64_t
stride_seq_b
);
__aicore__
inline
void
P
rocess
();
__aicore__
inline
void
p
rocess
();
private:
private:
__aicore__
inline
void
C
opyIn
(
int64_t
i
);
__aicore__
inline
void
c
opyIn
(
int64_t
i
);
__aicore__
inline
void
C
ompute
(
int64_t
i
);
__aicore__
inline
void
c
ompute
(
int64_t
i
);
__aicore__
inline
void
C
opyOut
(
int64_t
i
);
__aicore__
inline
void
c
opyOut
(
int64_t
i
);
private:
private:
GlobalTensor
<
T
>
cGm
,
aGm
,
bG
m
;
GlobalTensor
<
T
>
_c_gm
,
_a_gm
,
_b_g
m
;
TQue
<
QuePosition
::
VECIN
,
BUFFER_NUM
>
in
Q
ueue
A
,
in
Q
ueue
B
;
TQue
<
QuePosition
::
VECIN
,
BUFFER_NUM
>
_
in
_q
ueue
_a
,
_
in
_q
ueue
_b
;
TQue
<
QuePosition
::
VECOUT
,
BUFFER_NUM
>
out
Q
ueue
C
;
TQue
<
QuePosition
::
VECOUT
,
BUFFER_NUM
>
_
out
_q
ueue
_c
;
TPipe
pipe
;
TPipe
_
pipe
;
float
_beta_value
=
1.0
f
;
float
_beta_value
=
1.0
f
;
int64_t
_block_idx
,
_tile_len
,
_copy_len
,
int64_t
_block_idx
,
_tile_len
,
_copy_len
,
batch
,
seq_len
,
hidden_size
,
_
batch
,
_
seq_len
,
_
hidden_size
,
stride
SeqA
,
stride
SeqB
,
stride
SeqC
;
_
stride
_seq_a
,
_
stride
_seq_b
,
_
stride
_seq_c
;
int64_t
stride
B
atch
A
=
1
,
stride
B
atch
B
=
1
,
stride
B
atch
C
=
1
;
int64_t
_
stride
_b
atch
_a
=
1
,
_
stride
_b
atch
_b
=
1
,
_
stride
_b
atch
_c
=
1
;
};
};
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
I
nit
(
GM_ADDR
c
,
GM_ADDR
a
,
GM_ADDR
b
,
int64_t
batch_
,
int64_t
seq
,
int64_t
hd
,
__aicore__
inline
void
SwigluKernel
<
T
>::
i
nit
(
GM_ADDR
c
,
GM_ADDR
a
,
GM_ADDR
b
,
int64_t
batch_
,
int64_t
seq
,
int64_t
hd
,
int64_t
stride_batch_c
,
int64_t
stride_batch_a
,
int64_t
stride_batch_b
,
int64_t
stride_batch_c
,
int64_t
stride_batch_a
,
int64_t
stride_batch_b
,
int64_t
stride_seq_c
,
int64_t
stride_seq_a
,
int64_t
stride_seq_b
)
{
int64_t
stride_seq_c
,
int64_t
stride_seq_a
,
int64_t
stride_seq_b
)
{
// Init Shape & StrideVariables
// Init Shape & StrideVariables
batch
=
batch_
;
_
batch
=
batch_
;
seq_len
=
seq
;
_
seq_len
=
seq
;
hidden_size
=
hd
;
_
hidden_size
=
hd
;
stride
B
atch
A
=
stride_batch_a
;
_
stride
_b
atch
_a
=
stride_batch_a
;
stride
B
atch
B
=
stride_batch_b
;
_
stride
_b
atch
_b
=
stride_batch_b
;
stride
B
atch
C
=
stride_batch_c
;
_
stride
_b
atch
_c
=
stride_batch_c
;
stride
SeqA
=
stride_seq_a
;
_
stride
_seq_a
=
stride_seq_a
;
stride
SeqB
=
stride_seq_b
;
_
stride
_seq_b
=
stride_seq_b
;
stride
SeqC
=
stride_seq_c
;
_
stride
_seq_c
=
stride_seq_c
;
_block_idx
=
GetBlockIdx
();
_block_idx
=
GetBlockIdx
();
_tile_len
=
_block_idx
<
(
hidden_size
%
BLOCK_NUM
)
?
(
hidden_size
/
BLOCK_NUM
)
+
1
:
(
hidden_size
/
BLOCK_NUM
);
_tile_len
=
_block_idx
<
(
_
hidden_size
%
BLOCK_NUM
)
?
(
_
hidden_size
/
BLOCK_NUM
)
+
1
:
(
_
hidden_size
/
BLOCK_NUM
);
_copy_len
=
(
_tile_len
*
sizeof
(
T
))
%
BYTE_ALIGN
==
0
?
_tile_len
:
(
_tile_len
*
sizeof
(
T
)
+
(
BYTE_ALIGN
-
_tile_len
*
sizeof
(
T
)
%
BYTE_ALIGN
))
/
sizeof
(
T
);
_copy_len
=
(
_tile_len
*
sizeof
(
T
))
%
BYTE_ALIGN
==
0
?
_tile_len
:
(
_tile_len
*
sizeof
(
T
)
+
(
BYTE_ALIGN
-
_tile_len
*
sizeof
(
T
)
%
BYTE_ALIGN
))
/
sizeof
(
T
);
// Set global tensor
// Set global tensor
aG
m
.
SetGlobalBuffer
((
__gm__
T
*
)
a
);
_a_g
m
.
SetGlobalBuffer
((
__gm__
T
*
)
a
);
bG
m
.
SetGlobalBuffer
((
__gm__
T
*
)
b
);
_b_g
m
.
SetGlobalBuffer
((
__gm__
T
*
)
b
);
cG
m
.
SetGlobalBuffer
((
__gm__
T
*
)
c
);
_c_g
m
.
SetGlobalBuffer
((
__gm__
T
*
)
c
);
//
P
ipe alloc memory to queue, the unit is bytes
//
_p
ipe alloc memory to queue, the unit is bytes
pipe
.
InitBuffer
(
in
Q
ueue
A
,
BUFFER_NUM
,
_copy_len
*
sizeof
(
T
));
_
pipe
.
InitBuffer
(
_
in
_q
ueue
_a
,
BUFFER_NUM
,
_copy_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
in
Q
ueue
B
,
BUFFER_NUM
,
_copy_len
*
sizeof
(
T
));
_
pipe
.
InitBuffer
(
_
in
_q
ueue
_b
,
BUFFER_NUM
,
_copy_len
*
sizeof
(
T
));
pipe
.
InitBuffer
(
out
Q
ueue
C
,
BUFFER_NUM
,
_copy_len
*
sizeof
(
T
));
_
pipe
.
InitBuffer
(
_
out
_q
ueue
_c
,
BUFFER_NUM
,
_copy_len
*
sizeof
(
T
));
}
}
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
C
opyIn
(
int64_t
i
)
{
__aicore__
inline
void
SwigluKernel
<
T
>::
c
opyIn
(
int64_t
i
)
{
// Alloc tensor from queue memory
// Alloc tensor from queue memory
LocalTensor
<
T
>
aLocal
=
in
Q
ueue
A
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
aLocal
=
_
in
_q
ueue
_a
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
bLocal
=
in
Q
ueue
B
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
bLocal
=
_
in
_q
ueue
_b
.
AllocTensor
<
T
>
();
// Get idx of current tile
// Get idx of current tile
auto
batch
I
dx
=
batch
==
1
?
0
:
i
/
seq_len
;
auto
batch
_i
dx
=
_
batch
==
1
?
0
:
i
/
_
seq_len
;
auto
seq
I
dx
=
batch
==
1
?
i
:
i
%
seq_len
;
auto
seq
_i
dx
=
_
batch
==
1
?
i
:
i
%
_
seq_len
;
int64_t
idxa
=
batch
I
dx
*
stride
B
atch
A
+
seq
I
dx
*
stride
SeqA
+
_block_idx
*
_tile_len
;
int64_t
idxa
=
batch
_i
dx
*
_
stride
_b
atch
_a
+
seq
_i
dx
*
_
stride
_seq_a
+
_block_idx
*
_tile_len
;
int64_t
idxb
=
batch
I
dx
*
stride
B
atch
B
+
seq
I
dx
*
stride
SeqB
+
_block_idx
*
_tile_len
;
int64_t
idxb
=
batch
_i
dx
*
_
stride
_b
atch
_b
+
seq
_i
dx
*
_
stride
_seq_b
+
_block_idx
*
_tile_len
;
// Copy process_th tile from global tensor to local tensor
// Copy process_th tile from global tensor to local tensor
DataCopy
(
aLocal
,
aG
m
[
idxa
],
_copy_len
);
DataCopy
(
aLocal
,
_a_g
m
[
idxa
],
_copy_len
);
DataCopy
(
bLocal
,
bG
m
[
idxb
],
_copy_len
);
DataCopy
(
bLocal
,
_b_g
m
[
idxb
],
_copy_len
);
// Enque input tensor to VECIN queue
// Enque input tensor to VECIN queue
in
Q
ueue
A
.
EnQue
(
aLocal
);
_
in
_q
ueue
_a
.
EnQue
(
aLocal
);
in
Q
ueue
B
.
EnQue
(
bLocal
);
_
in
_q
ueue
_b
.
EnQue
(
bLocal
);
}
}
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
C
ompute
(
int64_t
i
)
{
__aicore__
inline
void
SwigluKernel
<
T
>::
c
ompute
(
int64_t
i
)
{
// Deque input tensors from VECIN queue
// Deque input tensors from VECIN queue
LocalTensor
<
T
>
aLocal
=
in
Q
ueue
A
.
DeQue
<
T
>
();
LocalTensor
<
T
>
aLocal
=
_
in
_q
ueue
_a
.
DeQue
<
T
>
();
LocalTensor
<
T
>
bLocal
=
in
Q
ueue
B
.
DeQue
<
T
>
();
LocalTensor
<
T
>
bLocal
=
_
in
_q
ueue
_b
.
DeQue
<
T
>
();
LocalTensor
<
T
>
cLocal
=
out
Q
ueue
C
.
AllocTensor
<
T
>
();
LocalTensor
<
T
>
cLocal
=
_
out
_q
ueue
_c
.
AllocTensor
<
T
>
();
// Call SwiGLU ascend api
// Call SwiGLU ascend api
SwiGLU
<
T
,
false
>
(
cLocal
,
aLocal
,
bLocal
,
_beta_value
,
_copy_len
);
SwiGLU
<
T
,
false
>
(
cLocal
,
aLocal
,
bLocal
,
_beta_value
,
_copy_len
);
// Enque result and free input
// Enque result and free input
out
Q
ueue
C
.
EnQue
<
T
>
(
cLocal
);
_
out
_q
ueue
_c
.
EnQue
<
T
>
(
cLocal
);
in
Q
ueue
A
.
FreeTensor
(
aLocal
);
_
in
_q
ueue
_a
.
FreeTensor
(
aLocal
);
in
Q
ueue
B
.
FreeTensor
(
bLocal
);
_
in
_q
ueue
_b
.
FreeTensor
(
bLocal
);
}
}
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
C
opyOut
(
int64_t
i
)
{
__aicore__
inline
void
SwigluKernel
<
T
>::
c
opyOut
(
int64_t
i
)
{
// Deque output tensor from VECOUT queue
// Deque output tensor from VECOUT queue
LocalTensor
<
T
>
cLocal
=
out
Q
ueue
C
.
DeQue
<
T
>
();
LocalTensor
<
T
>
cLocal
=
_
out
_q
ueue
_c
.
DeQue
<
T
>
();
auto
batch
I
dx
=
batch
==
1
?
0
:
i
/
seq_len
;
auto
batch
_i
dx
=
_
batch
==
1
?
0
:
i
/
_
seq_len
;
auto
seq
I
dx
=
batch
==
1
?
i
:
i
%
seq_len
;
auto
seq
_i
dx
=
_
batch
==
1
?
i
:
i
%
_
seq_len
;
int64_t
idxc
=
batch
I
dx
*
stride
B
atch
C
+
seq
I
dx
*
stride
SeqC
+
_block_idx
*
_tile_len
;
int64_t
idxc
=
batch
_i
dx
*
_
stride
_b
atch
_c
+
seq
_i
dx
*
_
stride
_seq_c
+
_block_idx
*
_tile_len
;
// Copy progress_th tile from local tensor to global tensor
// Copy progress_th tile from local tensor to global tensor
if
(
_tile_len
*
sizeof
(
T
)
%
BYTE_ALIGN
!=
0
)
{
if
(
_tile_len
*
sizeof
(
T
)
%
BYTE_ALIGN
!=
0
)
{
DataCopyExtParams
dcep
=
{
1
,
static_cast
<
uint32_t
>
(
_tile_len
*
sizeof
(
T
)),
0
,
0
,
0
};
DataCopyExtParams
dcep
=
{
1
,
static_cast
<
uint32_t
>
(
_tile_len
*
sizeof
(
T
)),
0
,
0
,
0
};
DataCopyPad
(
cG
m
[
idxc
],
cLocal
,
dcep
);
DataCopyPad
(
_c_g
m
[
idxc
],
cLocal
,
dcep
);
}
else
{
}
else
{
DataCopy
(
cG
m
[
idxc
],
cLocal
,
_tile_len
);
DataCopy
(
_c_g
m
[
idxc
],
cLocal
,
_tile_len
);
}
}
// Free output Local tensor
// Free output Local tensor
out
Q
ueue
C
.
FreeTensor
(
cLocal
);
_
out
_q
ueue
_c
.
FreeTensor
(
cLocal
);
}
}
template
<
typename
T
>
template
<
typename
T
>
__aicore__
inline
void
SwigluKernel
<
T
>::
P
rocess
()
{
__aicore__
inline
void
SwigluKernel
<
T
>::
p
rocess
()
{
for
(
int64_t
i
=
0
;
i
<
batch
*
seq_len
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
_
batch
*
_
seq_len
;
++
i
)
{
C
opyIn
(
i
);
c
opyIn
(
i
);
C
ompute
(
i
);
c
ompute
(
i
);
C
opyOut
(
i
);
c
opyOut
(
i
);
}
}
}
}
...
@@ -130,11 +130,11 @@ __aicore__ inline void SwigluKernel<T>::Process() {
...
@@ -130,11 +130,11 @@ __aicore__ inline void SwigluKernel<T>::Process() {
int64_t stride_seq_a, \
int64_t stride_seq_a, \
int64_t stride_seq_b) { \
int64_t stride_seq_b) { \
SwigluKernel<TYPE> op; \
SwigluKernel<TYPE> op; \
op.
I
nit(c, a, b, \
op.
i
nit(c, a, b, \
batch, seq, hd, \
batch, seq, hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \
stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \
stride_seq_c, stride_seq_a, stride_seq_b); \
op.
P
rocess(); \
op.
p
rocess(); \
}
}
DEFINE_SWIGLU_KERNEL
(
swiglu_kernel_half
,
half
)
DEFINE_SWIGLU_KERNEL
(
swiglu_kernel_half
,
half
)
...
...
src/infiniop/ops/swiglu/operator.cc
View file @
fafb22db
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
#include "kunlun/swiglu_kunlun.h"
#include "kunlun/swiglu_kunlun.h"
#endif
#endif
#ifdef ENABLE_ASCEND_API
#ifdef ENABLE_ASCEND_API
#include "ascend/swiglu_a
clnn
.h"
#include "ascend/swiglu_a
scend
.h"
#endif
#endif
__C
infiniStatus_t
infiniopCreateSwiGLUDescriptor
(
__C
infiniStatus_t
infiniopCreateSwiGLUDescriptor
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment