Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
1622c975
Unverified
Commit
1622c975
authored
Apr 21, 2025
by
PanZezhong1725
Committed by
GitHub
Apr 21, 2025
Browse files
Merge pull request #64 from InfiniTensor/issue/8
issue/8: causalsoftmax算子-昇腾
parents
cbd573f5
d5b310de
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
160 additions
and
20 deletions
+160
-20
src/infiniop/devices/ascend/common_ascend.cc
src/infiniop/devices/ascend/common_ascend.cc
+6
-1
src/infiniop/ops/causal_softmax/ascend/causal_softmax_aclnn.cc
...nfiniop/ops/causal_softmax/ascend/causal_softmax_aclnn.cc
+133
-0
src/infiniop/ops/causal_softmax/ascend/causal_softmax_aclnn.h
...infiniop/ops/causal_softmax/ascend/causal_softmax_aclnn.h
+7
-0
src/infiniop/ops/causal_softmax/operator.cc
src/infiniop/ops/causal_softmax/operator.cc
+11
-16
src/infiniop/ops/swiglu/operator.cc
src/infiniop/ops/swiglu/operator.cc
+1
-1
test/infiniop/causal_softmax.py
test/infiniop/causal_softmax.py
+2
-2
No files found.
src/infiniop/devices/ascend/common_ascend.cc
View file @
1622c975
...
...
@@ -41,7 +41,12 @@ aclnnTensorDescriptor::aclnnTensorDescriptor(aclDataType dtype, const std::vecto
this
->
strides
=
strides
;
this
->
dataType
=
dtype
;
this
->
format
=
aclFormat
::
ACL_FORMAT_ND
;
this
->
storageShape
=
inferStorageShape
(
this
->
shape
,
this
->
strides
);
if
(
this
->
ndim
!=
0
)
{
this
->
storageShape
=
inferStorageShape
(
this
->
shape
,
this
->
strides
);
}
else
{
this
->
storageShape
=
shape
;
this
->
storageNdim
=
0
;
}
this
->
tensor
=
aclCreateTensor
(
this
->
shape
.
data
(),
this
->
ndim
,
this
->
dataType
,
...
...
src/infiniop/ops/causal_softmax/ascend/causal_softmax_aclnn.cc
0 → 100644
View file @
1622c975
#include "causal_softmax_aclnn.h"
#include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_masked_fill_tensor.h>
#include <aclnnop/aclnn_softmax.h>
namespace op::causal_softmax::ascend {

// Opaque holds everything the repeatable aclnn executors need to stay valid
// for the lifetime of the descriptor: the tensor descriptors, the device
// buffers they point at, and the per-op workspace sizes.
struct Descriptor::Opaque {
    mutable aclOpExecutor *executor;      // repeatable softmax executor
    mutable aclOpExecutor *mask_executor; // repeatable in-place masked-fill executor
    aclnnTensorDescriptor_t x;            // input tensor descriptor
    aclnnTensorDescriptor_t mask;         // boolean causal mask (seq_len x total_seq_len)
    aclnnTensorDescriptor_t y;            // output tensor descriptor
    // Fill-value tensor (-inf) referenced by mask_executor at launch time.
    // FIX: previously created in create() but never retained, leaking both the
    // descriptor and its device buffer.
    aclnnTensorDescriptor_t value;
    void *mask_addr;  // device buffer holding the mask matrix (aclrtMalloc'd)
    void *value_addr; // device buffer holding the -inf fill value (aclrtMalloc'd)
    size_t workspacesize_softmax;
    size_t workspacesize_mask;

    ~Opaque() {
        delete x;
        delete mask;
        delete y;
        delete value;
        // FIX: free the device buffers allocated in create(); they were
        // previously leaked on descriptor destruction.
        if (mask_addr) {
            aclrtFree(mask_addr);
        }
        if (value_addr) {
            aclrtFree(value_addr);
        }
        aclDestroyAclOpExecutor(executor);
        aclDestroyAclOpExecutor(mask_executor);
    }
};

Descriptor::~Descriptor() {
    delete _opaque;
}

/// Builds a causal-softmax descriptor for the Ascend backend.
/// Precomputes the boolean causal mask and the -inf fill value on device,
/// then captures repeatable aclnn executors for masked-fill and softmax so
/// calculate() only has to rebind data addresses.
/// NOTE(review): a failing CHECK_ACL below early-returns and leaks the
/// descriptors/buffers allocated so far — acceptable only if creation
/// failure is treated as fatal; consider RAII wrappers otherwise.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
    auto result = CausalSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(result);
    CausalSoftmaxInfo info = result.take();

    aclOpExecutor *executor = nullptr;
    aclOpExecutor *mask_executor = nullptr;
    aclnnTensorDescriptor_t y = nullptr;
    aclnnTensorDescriptor_t mask = nullptr;
    aclnnTensorDescriptor_t x = nullptr;
    aclnnTensorDescriptor_t value = nullptr;
    void *mask_addr = nullptr;
    void *value_addr = nullptr;
    size_t workspacesize_softmax = 0;
    size_t workspacesize_mask = 0;

    // Create Aclnn Tensor Descriptors for input, mask and output
    std::vector<int64_t> shape = {static_cast<int64_t>(info.batch_size),
                                  static_cast<int64_t>(info.seq_len),
                                  static_cast<int64_t>(info.total_seq_len)};
    std::vector<int64_t> x_strides = {static_cast<int64_t>(info.x_stride_b),
                                      static_cast<int64_t>(info.x_stride_i),
                                      static_cast<int64_t>(info.x_stride_j)};
    std::vector<int64_t> y_strides = {static_cast<int64_t>(info.y_stride_b),
                                      static_cast<int64_t>(info.y_stride_i),
                                      static_cast<int64_t>(info.y_stride_j)};
    y = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, y_strides);
    x = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, x_strides);
    mask = new aclnnTensorDescriptor(
        aclDataType::ACL_BOOL,
        {static_cast<int64_t>(info.seq_len), static_cast<int64_t>(info.total_seq_len)},
        {static_cast<int64_t>(info.total_seq_len), 1});

    // Initialize the scalar fill-value tensor with -inf
    // (0xfc00 is -inf in IEEE fp16, 0xff800000 is -inf in IEEE fp32).
    if (info.dtype == INFINI_DTYPE_F16) {
        uint16_t mask_value = 0xfc00;
        auto size = aclDataTypeSize(aclDataType::ACL_FLOAT16);
        CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
        CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE));
        value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT16, {}, {}, value_addr);
    } else {
        uint32_t mask_value = 0xff800000;
        auto size = aclDataTypeSize(aclDataType::ACL_FLOAT);
        CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
        CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE));
        value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT, {}, {}, value_addr);
    }

    // Fill the causal mask: row i masks (sets to 1) every column j strictly
    // beyond the diagonal position total_seq_len - seq_len + i.
    std::vector<char> mask_matrix(mask->numel(), 0);
    for (size_t i = 0; i < info.seq_len; ++i) {
        for (size_t j = info.total_seq_len - info.seq_len + i + 1; j < info.total_seq_len; ++j) {
            size_t index = i * info.total_seq_len + j;
            mask_matrix[index] = 1;
        }
    }
    auto size = mask->numel() * aclDataTypeSize(aclDataType::ACL_BOOL);
    CHECK_ACL(aclrtMalloc(&mask_addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMemcpy(mask_addr, size, mask_matrix.data(), size, ACL_MEMCPY_HOST_TO_DEVICE));

    // Get the workspace sizes and capture repeatable executors for both ops.
    aclTensor *tx = x->tensor;
    aclTensor *ty = y->tensor;
    aclTensor *tmask = mask->tensor;
    aclTensor *tvalue = value->tensor;
    CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(
        tx, tmask, tvalue, &workspacesize_mask, &mask_executor));
    aclSetAclOpExecutorRepeatable(mask_executor);
    int64_t dim = 2; // softmax over the total_seq_len axis
    CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor));
    aclSetAclOpExecutorRepeatable(executor);

    // Create the descriptor. value/value_addr are now owned by Opaque (FIX:
    // they used to go out of scope here, leaking memory the mask executor
    // still referenced).
    size_t all_workspacesize = workspacesize_softmax + workspacesize_mask;
    *desc_ptr = new Descriptor(
        new Opaque{executor,
                   mask_executor,
                   x,
                   mask,
                   y,
                   value,
                   mask_addr,
                   value_addr,
                   workspacesize_softmax,
                   workspacesize_mask},
        std::move(info),
        all_workspacesize,
        handle_ascend->device,
        handle_ascend->device_id);
    return INFINI_STATUS_SUCCESS;
}

/// Runs masked-fill(-inf) on x, synchronizes, then softmax(x) -> y, reusing
/// the executors captured in create() with the caller's data addresses.
/// NOTE(review): the const_cast means the masked fill writes into the
/// caller's x buffer even in the out-of-place case — confirm callers accept
/// that the input is clobbered.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    void *stream) const {
    if (workspace_size < workspaceSize()) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    auto tx = _opaque->x->tensor;
    auto ty = _opaque->y->tensor;
    auto tmask = _opaque->mask->tensor;
    auto executor = _opaque->executor;
    auto mask_executor = _opaque->mask_executor;
    auto mask_addr = _opaque->mask_addr;

    // Rebind data addresses on the repeatable executors, then launch.
    AclSetTensorAddr(mask_executor, 0, tx, (void *)x);
    AclSetTensorAddr(mask_executor, 1, tmask, mask_addr);
    CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, _opaque->workspacesize_mask, mask_executor, stream));
    // The softmax launch reuses the same workspace base, so the masked fill
    // must have fully completed first.
    CHECK_ACL(aclrtSynchronizeStream(stream));
    AclSetTensorAddr(executor, 0, tx, (void *)x);
    AclSetTensorAddr(executor, 1, ty, y);
    CHECK_ACL(aclnnSoftmax(workspace, _opaque->workspacesize_softmax, executor, stream));
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::causal_softmax::ascend
src/infiniop/ops/causal_softmax/ascend/causal_softmax_aclnn.h
0 → 100644
View file @
1622c975
#ifndef __CAUSAL_SOFTMAX_ASCEND_H__
#define __CAUSAL_SOFTMAX_ASCEND_H__

#include "../causal_softmax.h"

// Declares op::causal_softmax::ascend::Descriptor via the shared DESCRIPTOR
// macro defined in causal_softmax.h; the implementation lives in
// causal_softmax_aclnn.cc.
DESCRIPTOR(ascend)

#endif
src/infiniop/ops/causal_softmax/operator.cc
View file @
1622c975
...
...
@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#include "cuda/causal_softmax_cuda.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/causal_softmax_aclnn.h"
#endif
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -36,10 +39,8 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
// return cnnlCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxCnnlDescriptor_t *) desc_ptr, y_desc);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case
DevAscendNpu
:
{
return
aclnnCreateCausalSoftmaxDescriptor
((
AscendHandle_t
)
handle
,
(
CausalSoftmaxAclnnDescriptor_t
*
)
desc_ptr
,
y_desc
);
}
#ifdef ENABLE_ASCEND_API
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
...
...
@@ -76,10 +77,8 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
}
#endif
#ifdef ENABLE_ASCEND_NPU
case
DevAscendNpu
:
{
return
aclnnGetCausalSoftmaxWorkspaceSize
((
CausalSoftmaxAclnnDescriptor_t
)
desc
,
size
);
}
#ifdef ENABLE_ASCEND_API
GET
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
...
...
@@ -120,10 +119,8 @@ __C infiniStatus_t infiniopCausalSoftmax(
// return cnnlCausalSoftmax((CausalSoftmaxCnnlDescriptor_t) desc, workspace, workspace_size, data, stream);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case
DevAscendNpu
:
{
return
aclnnCausalSoftmax
((
CausalSoftmaxAclnnDescriptor_t
)
desc
,
workspace
,
workspace_size
,
data
,
stream
);
}
#ifdef ENABLE_ASCEND_API
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
...
...
@@ -159,10 +156,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
// return cnnlDestroyCausalSoftmaxDescriptor((CausalSoftmaxCnnlDescriptor_t) desc);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case
DevAscendNpu
:
{
return
aclnnDestroyCausalSoftmaxDescriptor
((
CausalSoftmaxAclnnDescriptor_t
)
desc
);
}
#ifdef ENABLE_ASCEND_API
DESTROY
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
...
...
src/infiniop/ops/swiglu/operator.cc
View file @
1622c975
...
...
@@ -85,7 +85,7 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
return
bangGetSwiGLUWorkspaceSize
((
SwiGLUBangDescriptor_t
)
desc
,
size
);
}
#endif
#ifdef ENABLE_ASCEND_
API
#ifdef ENABLE_ASCEND_
NPU
GET
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#ifdef ENABLE_METAX_GPU
...
...
test/infiniop/causal_softmax.py
View file @
1622c975
...
...
@@ -29,7 +29,7 @@ _TEST_CASES_ = [
((
32
,
512
),
(
1024
,
1
),
(
1024
,
1
)),
((
32
,
5
,
5
),
None
,
None
),
((
32
,
20
,
512
),
None
,
None
),
((
32
,
20
,
512
),
(
20480
,
512
,
1
),
None
),
# Ascend 暂不支持非连续
((
32
,
20
,
512
),
(
20480
,
512
,
1
),
None
),
]
# Data types used for testing
...
...
@@ -47,8 +47,8 @@ class Inplace(Enum):
_INPLACE
=
[
Inplace
.
OUT_OF_PLACE
,
Inplace
.
INPLACE_X
,
Inplace
.
OUT_OF_PLACE
,
]
_TEST_CASES
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment