Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
c2e87202
Commit
c2e87202
authored
Jun 04, 2025
by
Catheriany
Browse files
Merge remote-tracking branch 'origin/main' into issue/142
parents
41818f84
c203635b
Changes
175
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
814 additions
and
134 deletions
+814
-134
src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc
...finiop/ops/causal_softmax/ascend/causal_softmax_ascend.cc
+140
-0
src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.h
...nfiniop/ops/causal_softmax/ascend/causal_softmax_ascend.h
+7
-0
src/infiniop/ops/causal_softmax/causal_softmax.h
src/infiniop/ops/causal_softmax/causal_softmax.h
+4
-2
src/infiniop/ops/causal_softmax/cpu/causal_softmax_cpu.cc
src/infiniop/ops/causal_softmax/cpu/causal_softmax_cpu.cc
+23
-20
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cu
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cu
+72
-0
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cuh
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cuh
+8
-0
src/infiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
...nfiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
+60
-0
src/infiniop/ops/causal_softmax/info.h
src/infiniop/ops/causal_softmax/info.h
+38
-20
src/infiniop/ops/causal_softmax/operator.cc
src/infiniop/ops/causal_softmax/operator.cc
+33
-40
src/infiniop/ops/clip/cpu/clip_cpu.cc
src/infiniop/ops/clip/cpu/clip_cpu.cc
+54
-0
src/infiniop/ops/clip/cpu/clip_cpu.h
src/infiniop/ops/clip/cpu/clip_cpu.h
+23
-0
src/infiniop/ops/clip/cuda/clip_cuda.cu
src/infiniop/ops/clip/cuda/clip_cuda.cu
+59
-0
src/infiniop/ops/clip/cuda/clip_cuda.cuh
src/infiniop/ops/clip/cuda/clip_cuda.cuh
+9
-0
src/infiniop/ops/clip/cuda/clip_cuda_internal.cuh
src/infiniop/ops/clip/cuda/clip_cuda_internal.cuh
+30
-0
src/infiniop/ops/clip/operator.cc
src/infiniop/ops/clip/operator.cc
+118
-0
src/infiniop/ops/gemm/ascend/gemm_ascend.cc
src/infiniop/ops/gemm/ascend/gemm_ascend.cc
+55
-22
src/infiniop/ops/gemm/cpu/gemm_cpu.cc
src/infiniop/ops/gemm/cpu/gemm_cpu.cc
+26
-23
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
+1
-1
src/infiniop/ops/gemm/musa/gemm_musa.mu
src/infiniop/ops/gemm/musa/gemm_musa.mu
+3
-6
src/infiniop/ops/mul/cpu/mul_cpu.cc
src/infiniop/ops/mul/cpu/mul_cpu.cc
+51
-0
No files found.
src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc
0 → 100644
View file @
c2e87202
#include "causal_softmax_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_masked_fill_tensor.h>
#include <aclnnop/aclnn_softmax.h>
namespace
op
::
causal_softmax
::
ascend
{
// Backend-private state for the Ascend causal-softmax descriptor.
// Owns the aclnn tensor descriptors, two small device-side buffers
// (the boolean causal mask and the single fill value), and the softmax
// executor cached by Descriptor::create().
struct Descriptor::Opaque {
    aclnnTensorDescriptor_t x;     // input tensor layout
    aclnnTensorDescriptor_t mask;  // (seq_len, total_seq_len) boolean causal mask
    aclnnTensorDescriptor_t y;     // output tensor layout
    aclnnTensorDescriptor_t value; // scalar tensor describing the -inf fill value
    void *mask_addr;               // device buffer backing `mask`
    void *value_addr;              // device buffer backing `value`
    uint64_t workspacesize;        // workspace bytes required by the softmax op
    aclOpExecutor *executor;       // reusable softmax executor (marked repeatable in create())
    ~Opaque() {
        delete x;
        delete mask;
        delete y;
        delete value;
        aclrtFree(mask_addr);
        aclrtFree(value_addr);
        // Delete useless executor
        aclDestroyAclOpExecutor(executor);
    }
};
// Release all backend-private state (tensor descriptors, device buffers,
// cached executor) via Opaque's destructor.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Build a causal-softmax descriptor for the Ascend backend.
//
// Validates (y_desc, x_desc) via CausalSoftmaxInfo, uploads a boolean
// causal mask and a single -inf fill value to device memory, queries the
// workspace sizes of both aclnn ops, and caches a reusable softmax
// executor. The masked-fill executor obtained here is only used for
// workspace sizing; calculate() re-creates one per call.
//
// NOTE(review): if a CHECK_ACL/CHECK_RESULT macro returns early, the
// descriptors and device buffers allocated up to that point are leaked;
// `mask_executor` is also never destroyed on the success path — confirm
// against aclnn executor ownership rules.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
    auto result = CausalSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(result);
    CausalSoftmaxInfo info = result.take();
    aclOpExecutor *executor = nullptr;
    aclOpExecutor *mask_executor = nullptr;
    aclnnTensorDescriptor_t y = nullptr;
    aclnnTensorDescriptor_t mask = nullptr;
    aclnnTensorDescriptor_t x = nullptr;
    aclnnTensorDescriptor_t value = nullptr;
    void *mask_addr = nullptr;
    void *value_addr = nullptr;
    size_t workspacesize_softmax = 0;
    size_t workspacesize_mask = 0;
    // Create Aclnn tensor descriptors for input, mask and output.
    std::vector<int64_t> shape = {static_cast<int64_t>(info.batch_size),
                                  static_cast<int64_t>(info.seq_len),
                                  static_cast<int64_t>(info.total_seq_len)};
    std::vector<int64_t> x_strides = {static_cast<int64_t>(info.x_stride_b),
                                      static_cast<int64_t>(info.x_stride_i),
                                      static_cast<int64_t>(info.x_stride_j)};
    std::vector<int64_t> y_strides = {static_cast<int64_t>(info.y_stride_b),
                                      static_cast<int64_t>(info.y_stride_i),
                                      static_cast<int64_t>(info.y_stride_j)};
    y = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, y_strides);
    x = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, x_strides);
    mask = new aclnnTensorDescriptor(aclDataType::ACL_BOOL,
                                     {static_cast<int64_t>(info.seq_len),
                                      static_cast<int64_t>(info.total_seq_len)},
                                     {static_cast<int64_t>(info.total_seq_len), 1});
    // Initialize the value tensor with -inf (raw IEEE-754 bit patterns).
    if (info.dtype == INFINI_DTYPE_F16) {
        uint16_t mask_value = 0xfc00; // fp16 -infinity
        auto size = aclDataTypeSize(aclDataType::ACL_FLOAT16);
        CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
        CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE));
        value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT16, {}, {});
    } else {
        uint32_t mask_value = 0xff800000; // fp32 -infinity
        auto size = aclDataTypeSize(aclDataType::ACL_FLOAT);
        CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
        CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE));
        value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT, {}, {});
    }
    // Fill the mask tensor: mark positions strictly in the future of row i
    // (columns j > total_seq_len - seq_len + i) so they get -inf filled.
    std::vector<char> mask_matrix(mask->numel(), 0);
    for (size_t i = 0; i < info.seq_len; ++i) {
        for (size_t j = info.total_seq_len - info.seq_len + i + 1; j < info.total_seq_len; ++j) {
            size_t index = i * info.total_seq_len + j;
            mask_matrix[index] = 1;
        }
    }
    auto size = mask->numel() * aclDataTypeSize(aclDataType::ACL_BOOL);
    CHECK_ACL(aclrtMalloc(&mask_addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMemcpy(mask_addr, size, mask_matrix.data(), size, ACL_MEMCPY_HOST_TO_DEVICE));
    // Get the workspace size for the ops.
    aclTensor *tx = x->tensor;
    aclTensor *ty = y->tensor;
    aclTensor *tmask = mask->tensor;
    aclTensor *tvalue = value->tensor;
    CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor));
    int64_t dim = 2; // softmax over the last (total_seq_len) axis
    CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor));
    // Set executor reusable so calculate() can rebind addresses each call.
    aclSetAclOpExecutorRepeatable(executor);
    // Create the descriptor; the workspace must cover both ops.
    size_t all_workspacesize = std::max(workspacesize_softmax, workspacesize_mask);
    *desc_ptr = new Descriptor(
        new Opaque{x, mask, y, value, mask_addr, value_addr, workspacesize_softmax, executor},
        std::move(info),
        all_workspacesize,
        handle_ascend->device,
        handle_ascend->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Run causal softmax on the Ascend backend: first -inf-fill the masked
// positions of x in place, then softmax over the last axis into y.
//
// Bug fix vs. the original: AclSetTensorAddr was called on `mask_executor`
// while it was still nullptr (before aclnnInplaceMaskedFillTensorGetWorkspaceSize
// created it), so the masked-fill op never had its device addresses bound.
// The executor is now created first, then the addresses are attached.
//
// NOTE(review): the masked fill writes through `tx`, whose address is the
// const input `x` — the input buffer is mutated; confirm callers expect this.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    void *stream) const {
    if (workspace_size < workspaceSize()) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    auto tx = _opaque->x->tensor;
    auto ty = _opaque->y->tensor;
    auto tmask = _opaque->mask->tensor;
    auto tvalue = _opaque->value->tensor;
    // The masked-fill executor is not repeatable, so it is re-created on
    // every call; this also produces the workspace size it needs.
    aclOpExecutor *mask_executor = nullptr;
    size_t workspacesize_mask = 0;
    CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(
        tx, tmask, tvalue, &workspacesize_mask, &mask_executor));
    // Bind the device addresses AFTER the executor exists (see bug fix note).
    AclSetTensorAddr(mask_executor, 0, tx, (void *)x);
    AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr);
    AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr);
    CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, workspacesize_mask, mask_executor, stream));
    // Rebind addresses on the cached (repeatable) softmax executor and run it.
    AclSetTensorAddr(_opaque->executor, 0, tx, (void *)x);
    AclSetTensorAddr(_opaque->executor, 1, ty, y);
    CHECK_ACL(aclnnSoftmax(workspace, _opaque->workspacesize, _opaque->executor, stream));
    return INFINI_STATUS_SUCCESS;
}
}
// namespace op::causal_softmax::ascend
src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.h
0 → 100644
View file @
c2e87202
#ifndef __CAUSAL_SOFTMAX_ASCEND_H__
#define __CAUSAL_SOFTMAX_ASCEND_H__

#include "../causal_softmax.h"

// Declare the Ascend-backend causal-softmax Descriptor class via the
// backend-parameterized macro shared by all causal-softmax backends.
DESCRIPTOR(ascend)

#endif
src/infiniop/ops/causal_softmax/causal_softmax.h
View file @
c2e87202
...
@@ -32,11 +32,13 @@
...
@@ -32,11 +32,13 @@
static infiniStatus_t create( \
static infiniStatus_t create( \
infiniopHandle_t handle, \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc); \
\
\
infiniStatus_t calculate( \
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *workspace, size_t workspace_size, \
void *data, \
void *y, \
const void *x, \
void *stream) const; \
void *stream) const; \
}; \
}; \
}
}
...
...
src/infiniop/ops/causal_softmax/cpu/causal_softmax_cpu.cc
View file @
c2e87202
...
@@ -9,44 +9,46 @@ Descriptor::~Descriptor() {}
...
@@ -9,44 +9,46 @@ Descriptor::~Descriptor() {}
infiniStatus_t
Descriptor
::
create
(
infiniStatus_t
Descriptor
::
create
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
Descriptor
**
desc_ptr
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
)
{
infiniopTensorDescriptor_t
y_desc
,
auto
result
=
CausalSoftmaxInfo
::
create
(
y_desc
);
infiniopTensorDescriptor_t
x_desc
)
{
auto
result
=
CausalSoftmaxInfo
::
create
(
y_desc
,
x_desc
);
CHECK_RESULT
(
result
);
CHECK_RESULT
(
result
);
*
desc_ptr
=
new
Descriptor
(
nullptr
,
result
.
take
(),
0
,
handle
->
device
,
handle
->
device_id
);
*
desc_ptr
=
new
Descriptor
(
nullptr
,
result
.
take
(),
0
,
handle
->
device
,
handle
->
device_id
);
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
template
<
typename
T
>
template
<
typename
T
>
infiniStatus_t
causal_softmax
(
const
CausalSoftmaxInfo
*
info
,
T
*
data
)
{
infiniStatus_t
causal_softmax
(
const
CausalSoftmaxInfo
*
info
,
T
*
y
,
const
T
*
x
)
{
#pragma omp parallel for
#pragma omp parallel for
for
(
ptrdiff_t
index
=
0
;
index
<
ptrdiff_t
(
info
->
batch_size
*
info
->
seq_len
);
index
++
)
{
for
(
ptrdiff_t
index
=
0
;
index
<
ptrdiff_t
(
info
->
batch_size
*
info
->
seq_len
);
index
++
)
{
size_t
ind
=
index
;
size_t
batch
=
index
/
info
->
seq_len
;
size_t
offset
=
0
;
size_t
i
=
(
index
%
info
->
seq_len
);
size_t
i
=
(
ind
%
info
->
seq_len
);
ptrdiff_t
y_offset
=
batch
*
info
->
y_stride_b
+
i
*
info
->
y_stride_i
;
offset
+=
(
ind
%
info
->
seq_len
)
*
info
->
stride_i
;
ptrdiff_t
x_offset
=
batch
*
info
->
x_stride_b
+
i
*
info
->
x_stride_i
;
ind
/=
info
->
seq_len
;
T
*
y_
=
y
+
y_offset
;
offset
+=
(
ind
%
info
->
batch_size
)
*
info
->
stride_b
;
const
T
*
x_
=
x
+
x_offset
;
for
(
size_t
j
=
info
->
total_seq_len
-
info
->
seq_len
+
i
+
1
;
j
<
info
->
total_seq_len
;
j
++
)
{
for
(
size_t
j
=
info
->
total_seq_len
-
info
->
seq_len
+
i
+
1
;
j
<
info
->
total_seq_len
;
j
++
)
{
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
{
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
{
data
[
offset
+
j
*
info
->
stride_j
]
=
utils
::
cast
<
fp16_t
>
(
0.0
f
);
y_
[
j
*
info
->
y_
stride_j
]
=
utils
::
cast
<
fp16_t
>
(
0.0
f
);
}
else
{
}
else
{
data
[
offset
+
j
*
info
->
stride_j
]
=
0.0
f
;
y_
[
j
*
info
->
y_
stride_j
]
=
0.0
f
;
}
}
}
}
float
val
=
op
::
common_cpu
::
reduce_op
::
max
(
&
data
[
offset
]
,
info
->
total_seq_len
-
info
->
seq_len
+
i
+
1
,
info
->
stride_j
);
float
val
=
op
::
common_cpu
::
reduce_op
::
max
(
x_
,
info
->
total_seq_len
-
info
->
seq_len
+
i
+
1
,
info
->
x_
stride_j
);
for
(
size_t
j
=
0
;
j
<=
info
->
total_seq_len
-
info
->
seq_len
+
i
;
j
++
)
{
for
(
size_t
j
=
0
;
j
<=
info
->
total_seq_len
-
info
->
seq_len
+
i
;
j
++
)
{
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
{
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
{
data
[
offset
+
j
*
info
->
stride_j
]
=
utils
::
cast
<
fp16_t
>
(
std
::
exp
(
utils
::
cast
<
float
>
(
data
[
offset
+
j
*
info
->
stride_j
])
-
val
));
y_
[
j
*
info
->
y_
stride_j
]
=
utils
::
cast
<
fp16_t
>
(
std
::
exp
(
utils
::
cast
<
float
>
(
x_
[
j
*
info
->
x_
stride_j
])
-
val
));
}
else
{
}
else
{
data
[
offset
+
j
*
info
->
stride_j
]
=
std
::
exp
(
data
[
offset
+
j
*
info
->
stride_j
]
-
val
);
y_
[
j
*
info
->
y_
stride_j
]
=
std
::
exp
(
x_
[
j
*
info
->
x_
stride_j
]
-
val
);
}
}
}
}
float
sum
=
op
::
common_cpu
::
reduce_op
::
sum
(
&
data
[
offset
]
,
info
->
total_seq_len
-
info
->
seq_len
+
i
+
1
,
info
->
stride_j
);
float
sum
=
op
::
common_cpu
::
reduce_op
::
sum
(
y_
,
info
->
total_seq_len
-
info
->
seq_len
+
i
+
1
,
info
->
y_
stride_j
);
for
(
size_t
j
=
0
;
j
<=
info
->
total_seq_len
-
info
->
seq_len
+
i
;
j
++
)
{
for
(
size_t
j
=
0
;
j
<=
info
->
total_seq_len
-
info
->
seq_len
+
i
;
j
++
)
{
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
{
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
{
data
[
offset
+
j
*
info
->
stride_j
]
=
utils
::
cast
<
fp16_t
>
(
utils
::
cast
<
float
>
(
data
[
offset
+
j
*
info
->
stride_j
])
/
sum
);
y_
[
j
*
info
->
y_
stride_j
]
=
utils
::
cast
<
fp16_t
>
(
utils
::
cast
<
float
>
(
y_
[
j
*
info
->
y_
stride_j
])
/
sum
);
}
else
{
}
else
{
data
[
offset
+
j
*
info
->
stride_j
]
=
data
[
offset
+
j
*
info
->
stride_j
]
/
sum
;
y_
[
j
*
info
->
y_
stride_j
]
=
y_
[
j
*
info
->
y_
stride_j
]
/
sum
;
}
}
}
}
}
}
...
@@ -56,13 +58,14 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
...
@@ -56,13 +58,14 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
infiniStatus_t
Descriptor
::
calculate
(
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
workspace
,
size_t
workspace_size
,
void
*
data
,
void
*
y
,
const
void
*
x
,
void
*
stream
)
const
{
void
*
stream
)
const
{
if
(
_info
.
dtype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
dtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
causal_softmax
<
fp16_t
>
(
&
_info
,
(
fp16_t
*
)
data
));
CHECK_STATUS
(
causal_softmax
<
fp16_t
>
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
x
));
}
else
if
(
_info
.
dtype
==
INFINI_DTYPE_F32
)
{
}
else
if
(
_info
.
dtype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
causal_softmax
<
float
>
(
&
_info
,
(
float
*
)
data
));
CHECK_STATUS
(
causal_softmax
<
float
>
(
&
_info
,
(
float
*
)
y
,
(
const
float
*
)
x
));
}
else
{
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
...
...
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cu
0 → 100644
View file @
c2e87202
#include "../../../devices/cuda/cuda_common.cuh"
#include "causal_softmax_cuda.cuh"
#include "causal_softmax_kernel.cuh"
namespace
op
::
causal_softmax
::
cuda
{
// CUDA-backend private state: keeps the device handle internals alive so
// launch limits (max threads per block) can be queried in calculate().
struct Descriptor::Opaque {
    std::shared_ptr<device::cuda::Handle::Internal> internal;
};

// Release backend-private state.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validate the (y, x) tensor pair, build the op metadata, and allocate a
// CUDA causal-softmax descriptor. No workspace is required (size 0).
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto cuda_handle = reinterpret_cast<device::cuda::Handle *>(handle);
    auto result = CausalSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(result);
    auto opaque = new Opaque{cuda_handle->internal()};
    *desc_ptr = new Descriptor(
        opaque,
        result.take(),
        0,
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Launch the causalSoftmax kernel with one block per (row, batch) cell,
// selecting the kernel instantiation that matches the element dtype.
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    void *y,
    const void *x,
    infiniDtype_t dtype,
    size_t batch_size,
    size_t seq_len,
    size_t total_seq_len,
    ptrdiff_t y_stride_b,
    ptrdiff_t y_stride_i,
    ptrdiff_t x_stride_b,
    ptrdiff_t x_stride_i,
    cudaStream_t stream) {
    // grid.x indexes the row within a sequence, grid.y indexes the batch.
    dim3 grid(uint32_t(seq_len), uint32_t(batch_size), 1);
    switch (dtype) {
    case INFINI_DTYPE_F16:
        causalSoftmax<BLOCK_SIZE, half, float><<<grid, BLOCK_SIZE, 0, stream>>>(
            (half *)y, (const half *)x,
            batch_size, seq_len, total_seq_len,
            y_stride_b, y_stride_i, x_stride_b, x_stride_i);
        return INFINI_STATUS_SUCCESS;
    case INFINI_DTYPE_F32:
        causalSoftmax<BLOCK_SIZE, float, float><<<grid, BLOCK_SIZE, 0, stream>>>(
            (float *)y, (const float *)x,
            batch_size, seq_len, total_seq_len,
            y_stride_b, y_stride_i, x_stride_b, x_stride_i);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
// Run causal softmax on CUDA, picking the block size supported by the
// current device. workspace/workspace_size are unused (no workspace needed).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    void *stream_) const {
    auto stream = static_cast<cudaStream_t>(stream_);
    const auto max_threads = _opaque->internal->maxThreadsPerBlock();
    if (max_threads == CUDA_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
            y, x, _info.dtype,
            _info.batch_size, _info.seq_len, _info.total_seq_len,
            _info.y_stride_b, _info.y_stride_i,
            _info.x_stride_b, _info.x_stride_i,
            stream));
    } else if (max_threads == CUDA_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
            y, x, _info.dtype,
            _info.batch_size, _info.seq_len, _info.total_seq_len,
            _info.y_stride_b, _info.y_stride_i,
            _info.x_stride_b, _info.x_stride_i,
            stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}
}
// namespace op::causal_softmax::cuda
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cuh
0 → 100644
View file @
c2e87202
#ifndef __CAUSAL_SOFTMAX_CUDA_H__
#define __CAUSAL_SOFTMAX_CUDA_H__
#include "../causal_softmax.h"
DESCRIPTOR
(
cuda
)
#endif
src/infiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
0 → 100644
View file @
c2e87202
#ifndef __CAUSAL_SOFTMAX_KERNEL_CUH__
#define __CAUSAL_SOFTMAX_KERNEL_CUH__
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
// One block per (batch = blockIdx.y, row = blockIdx.x): computes causal
// softmax over a row of `width` columns. Row r may attend to the first
// (width - height + 1 + r) columns; the remaining (future) columns are 0.
//
// Bug fix vs. the original: the causal test used `threadIdx.x` instead of
// `col`. The two only coincide on the first pass of the column loop, so the
// mask was wrong whenever width > BLOCK_SIZE and the loop wrapped.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_CUDA_KERNEL causalSoftmax(
    Tdata *y_,
    const Tdata *x_,
    size_t batch,
    size_t height,
    size_t width,
    ptrdiff_t y_stride_b,
    ptrdiff_t y_stride_h,
    ptrdiff_t x_stride_b,
    ptrdiff_t x_stride_h) {
    Tdata *y = y_                       // threadIdx.x for col_id
             + blockIdx.y * y_stride_b  // gridDim.y for batch_id
             + blockIdx.x * y_stride_h; // gridDim.x for row_id
    const Tdata *x = x_ + blockIdx.y * x_stride_b + blockIdx.x * x_stride_h;

    // [Reduce] Max over the unmasked prefix of the row, shared across the block.
    __shared__ Tdata max_;
    Tdata max_0 = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(x, width - height + 1 + blockIdx.x);
    if (threadIdx.x == 0) {
        max_ = max_0;
    }
    __syncthreads();

    // [Elementwise] exp(x - max) on unmasked columns; 0 on masked (future) ones.
    for (size_t col = threadIdx.x; col < width; col += BLOCK_SIZE) {
        // row_id ↓ |<- width ->|
        // 0 | * * * ... * |
        // 1 | * * * ... * * |
        // 2 | * * * ... * * * |
        // height: 3 col_id->
        // Column `col` is visible to row blockIdx.x iff
        // col <= width - height + blockIdx.x (fixed: was threadIdx.x).
        if (width + blockIdx.x >= col + height) {
#ifdef ENABLE_CUDA_API
            y[col] = exp_(x[col] - max_);
#else
            y[col] = exp(x[col] - max_);
#endif
        } else {
            y[col] = Tdata(0);
        }
    }
    __syncthreads();

    // [Reduce] Row sum; masked entries contribute 0 so summing full width is safe.
    __shared__ Tcompute sum_;
    Tcompute sum_0 = op::common_cuda::reduce_op::sum<BLOCK_SIZE, Tdata, Tcompute>(y, width);
    if (threadIdx.x == 0) {
        sum_ = sum_0;
    }
    __syncthreads();

    // [Elementwise] Normalize each element by the row sum.
    for (size_t col = threadIdx.x; col < width; col += BLOCK_SIZE) {
        y[col] /= Tdata(sum_);
    }
}
#endif // __CAUSAL_SOFTMAX_KERNEL_CUH__
src/infiniop/ops/causal_softmax/info.h
View file @
c2e87202
...
@@ -13,45 +13,63 @@ class CausalSoftmaxInfo {
...
@@ -13,45 +13,63 @@ class CausalSoftmaxInfo {
public:
public:
infiniDtype_t
dtype
;
infiniDtype_t
dtype
;
size_t
batch_size
;
size_t
batch_size
;
ptrdiff_t
stride_b
;
size_t
seq_len
;
size_t
seq_len
;
ptrdiff_t
stride_i
;
size_t
total_seq_len
;
size_t
total_seq_len
;
ptrdiff_t
stride_j
;
static
utils
::
Result
<
CausalSoftmaxInfo
>
create
(
infiniopTensorDescriptor_t
y_desc
)
{
ptrdiff_t
y_stride_b
;
ptrdiff_t
y_stride_i
;
ptrdiff_t
y_stride_j
;
ptrdiff_t
x_stride_b
;
ptrdiff_t
x_stride_i
;
ptrdiff_t
x_stride_j
;
static
utils
::
Result
<
CausalSoftmaxInfo
>
create
(
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
x_desc
)
{
auto
dtype
=
y_desc
->
dtype
();
auto
dtype
=
y_desc
->
dtype
();
if
(
y_desc
->
dtype
()
!=
INFINI_DTYPE_F16
&&
y_desc
->
dtype
()
!=
INFINI_DTYPE_F32
)
{
if
(
dtype
!=
x_desc
->
dtype
()
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
CHECK_DTYPE
(
dtype
,
INFINI_DTYPE_F16
,
INFINI_DTYPE_F32
);
if
(
y_desc
->
ndim
()
!=
2
&&
y_desc
->
ndim
()
!=
3
)
{
auto
shape
=
y_desc
->
shape
();
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
CHECK_SAME_SHAPE
(
shape
,
x_desc
->
shape
());
auto
ndim
=
y_desc
->
ndim
();
if
(
ndim
!=
2
&&
ndim
!=
3
)
{
CHECK_STATUS
(
INFINI_STATUS_BAD_TENSOR_SHAPE
);
}
}
if
(
y_desc
->
shape
()[
y_desc
->
ndim
()
-
1
]
<
y_desc
->
shape
()[
y_desc
->
ndim
()
-
2
])
{
if
(
shape
[
ndim
-
1
]
<
shape
[
ndim
-
2
])
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
CHECK_STATUS
(
INFINI_STATUS_BAD_TENSOR_SHAPE
)
;
}
}
size_t
batch_size
=
1
;
size_t
batch_size
=
1
;
ptrdiff_t
stride_b
=
0
;
size_t
seq_len
=
shape
[
ndim
-
2
];
size_t
seq_len
=
y_desc
->
shape
()[
y_desc
->
ndim
()
-
2
];
size_t
total_seq_len
=
shape
[
ndim
-
1
];
ptrdiff_t
stride_i
=
y_desc
->
strides
()[
y_desc
->
ndim
()
-
2
];
ptrdiff_t
y_stride_b
=
0
,
size_t
total_seq_len
=
y_desc
->
shape
()[
y_desc
->
ndim
()
-
1
];
y_stride_i
=
y_desc
->
stride
(
ndim
-
2
),
ptrdiff_t
stride_j
=
y_desc
->
strides
()[
y_desc
->
ndim
()
-
1
];
y_stride_j
=
y_desc
->
stride
(
ndim
-
1
);
if
(
y_desc
->
ndim
()
==
3
)
{
ptrdiff_t
x_stride_b
=
0
,
stride_b
=
y_desc
->
strides
()[
0
];
x_stride_i
=
x_desc
->
stride
(
ndim
-
2
),
batch_size
=
y_desc
->
shape
()[
0
];
x_stride_j
=
x_desc
->
stride
(
ndim
-
1
);
if
(
ndim
==
3
)
{
y_stride_b
=
y_desc
->
stride
(
0
);
x_stride_b
=
x_desc
->
stride
(
0
);
batch_size
=
shape
[
0
];
}
}
return
utils
::
Result
<
CausalSoftmaxInfo
>
(
CausalSoftmaxInfo
{
return
utils
::
Result
<
CausalSoftmaxInfo
>
(
CausalSoftmaxInfo
{
dtype
,
dtype
,
batch_size
,
batch_size
,
stride_b
,
seq_len
,
seq_len
,
stride_i
,
total_seq_len
,
total_seq_len
,
stride_j
});
y_stride_b
,
y_stride_i
,
y_stride_j
,
x_stride_b
,
x_stride_i
,
x_stride_j
});
}
}
};
};
...
...
src/infiniop/ops/causal_softmax/operator.cc
View file @
c2e87202
...
@@ -5,28 +5,33 @@
...
@@ -5,28 +5,33 @@
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
#include "cpu/causal_softmax_cpu.h"
#include "cpu/causal_softmax_cpu.h"
#endif
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/causal_softmax_cuda.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/causal_softmax_ascend.h"
#endif
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
__C
infiniStatus_t
infiniopCreateCausalSoftmaxDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
infiniopCausalSoftmaxDescriptor_t
*
desc_ptr
,
infiniopCausalSoftmaxDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
)
{
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
x_desc
)
{
#define CREATE(CASE, NAMESPACE) \
#define CREATE(CASE, NAMESPACE) \
case CASE: \
case CASE: \
return op::causal_softmax::NAMESPACE::Descriptor::create( \
return op::causal_softmax::NAMESPACE::Descriptor::create( \
handle, \
handle, \
reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor **>(desc_ptr), \
reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc);
y_desc, \
x_desc);
switch
(
handle
->
device
)
{
switch
(
handle
->
device
)
{
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
CREATE
(
INFINI_DEVICE_CPU
,
cpu
)
CREATE
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaCreateCausalSoftmaxDescriptor
((
CudaHandle_t
)
handle
,
(
CausalSoftmaxCudaDescriptor_t
*
)
desc_ptr
,
y_desc
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -34,10 +39,8 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
...
@@ -34,10 +39,8 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
// return cnnlCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxCnnlDescriptor_t *) desc_ptr, y_desc);
// return cnnlCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxCnnlDescriptor_t *) desc_ptr, y_desc);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
{
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
)
return
aclnnCreateCausalSoftmaxDescriptor
((
AscendHandle_t
)
handle
,
(
CausalSoftmaxAclnnDescriptor_t
*
)
desc_ptr
,
y_desc
);
}
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
@@ -64,11 +67,8 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
...
@@ -64,11 +67,8 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
GET
(
INFINI_DEVICE_CPU
,
cpu
)
GET
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
GET
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaGetCausalSoftmaxWorkspaceSize
((
CausalSoftmaxCudaDescriptor_t
)
desc
,
size
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -77,10 +77,8 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
...
@@ -77,10 +77,8 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
{
GET
(
INFINI_DEVICE_ASCEND
,
ascend
)
return
aclnnGetCausalSoftmaxWorkspaceSize
((
CausalSoftmaxAclnnDescriptor_t
)
desc
,
size
);
}
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
@@ -96,22 +94,24 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
...
@@ -96,22 +94,24 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopCausalSoftmax
(
infiniopCausalSoftmaxDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
void
*
data
,
void
*
stream
)
{
__C
infiniStatus_t
infiniopCausalSoftmax
(
infiniopCausalSoftmaxDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
const
void
*
x
,
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
case CASE: \
return reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor *>(desc)->calculate( \
return reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size,
data
, stream);
workspace, workspace_size,
y, x
, stream);
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
)
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaCausalSoftmax
((
CausalSoftmaxCudaDescriptor_t
)
desc
,
workspace
,
workspace_size
,
data
,
stream
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -119,10 +119,8 @@ __C infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc,
...
@@ -119,10 +119,8 @@ __C infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc,
// return cnnlCausalSoftmax((CausalSoftmaxCnnlDescriptor_t) desc, workspace, workspace_size, data, stream);
// return cnnlCausalSoftmax((CausalSoftmaxCnnlDescriptor_t) desc, workspace, workspace_size, data, stream);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
{
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
)
return
aclnnCausalSoftmax
((
CausalSoftmaxAclnnDescriptor_t
)
desc
,
workspace
,
workspace_size
,
data
,
stream
);
}
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
@@ -149,11 +147,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
...
@@ -149,11 +147,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
DESTROY
(
INFINI_DEVICE_CPU
,
cpu
)
DESTROY
(
INFINI_DEVICE_CPU
,
cpu
)
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
{
DESTROY
(
INFINI_DEVICE_NVIDIA
,
cuda
)
return
cudaDestroyCausalSoftmaxDescriptor
((
CausalSoftmaxCudaDescriptor_t
)
desc
);
}
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
case
DevCambriconMlu
:
{
...
@@ -161,10 +156,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
...
@@ -161,10 +156,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
// return cnnlDestroyCausalSoftmaxDescriptor((CausalSoftmaxCnnlDescriptor_t) desc);
// return cnnlDestroyCausalSoftmaxDescriptor((CausalSoftmaxCnnlDescriptor_t) desc);
}
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
{
DESTROY
(
INFINI_DEVICE_ASCEND
,
ascend
)
return
aclnnDestroyCausalSoftmaxDescriptor
((
CausalSoftmaxAclnnDescriptor_t
)
desc
);
}
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
case
DevMetaxGpu
:
{
...
...
src/infiniop/ops/clip/cpu/clip_cpu.cc
0 → 100644
View file @
c2e87202
#include "clip_cpu.h"
namespace
op
::
clip
::
cpu
{
// Defaulted: the CPU clip descriptor owns no backend-specific resources.
Descriptor::~Descriptor() = default;
// Validate dtypes and shapes for clip(out; x, min, max) on CPU and build
// the shared elementwise descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto cpu_handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto out_dtype = out_desc->dtype();

    // clip takes exactly three inputs: the data tensor and its two bounds.
    const auto &data_desc = input_desc_vec.at(0);
    const auto &lower_desc = input_desc_vec.at(1);
    const auto &upper_desc = input_desc_vec.at(2);

    const auto &result_shape = out_desc->shape();
    const auto &data_shape = data_desc->shape();
    const auto &lower_shape = lower_desc->shape();
    const auto &upper_shape = upper_desc->shape();

    // Only floating-point element types are supported.
    CHECK_DTYPE(out_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // All four tensors must share one shape (no broadcasting here).
    CHECK_SAME_SHAPE(result_shape, data_shape);
    CHECK_SAME_SHAPE(result_shape, lower_shape);
    CHECK_SAME_SHAPE(result_shape, upper_shape);

    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(cpu_handle, out_dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}
// Dispatch the elementwise clip computation according to the element type
// recorded at descriptor-creation time.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    if (_dtype == INFINI_DTYPE_F16) {
        return _device_info->calculate<ClipOp, fp16_t>(_info, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_F32) {
        return _device_info->calculate<ClipOp, float>(_info, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_F64) {
        return _device_info->calculate<ClipOp, double>(_info, output, inputs, stream);
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
// namespace op::clip::cpu
src/infiniop/ops/clip/cpu/clip_cpu.h
0 → 100644
View file @
c2e87202
#ifndef __CLIP_CPU_H__
#define __CLIP_CPU_H__
#include "../../../elementwise/cpu/elementwise_cpu.h"
#include "infiniop/ops/clip.h"
ELEMENTWISE_DESCRIPTOR
(
clip
,
cpu
)
namespace op::clip::cpu {

/// Element-wise clip functor used by the generic element-wise CPU engine.
/// Computes max(min(x, max_val), min_val); note the nesting order means a
/// min_val greater than max_val yields min_val.
/// (Replaces the C-style `typedef struct ... ClipOp;` — in C++ a plain
/// struct definition is sufficient, and struct members are public by default.)
struct ClipOp {
    // Clip consumes three input tensors: x, min_val, max_val.
    static constexpr size_t num_inputs = 3;

    template <typename T>
    T operator()(const T &x, const T &min_val, const T &max_val) const {
        return std::max(std::min(x, max_val), min_val);
    }
};
} // namespace op::clip::cpu
#endif // __CLIP_CPU_H__
src/infiniop/ops/clip/cuda/clip_cuda.cu
0 → 100644
View file @
c2e87202
#include "clip_cuda.cuh"
#include "clip_cuda_internal.cuh"
namespace op::clip::cuda {

Descriptor::~Descriptor() = default;

/// Creates a CUDA clip descriptor.
/// Checks that the output dtype is one of F16/F32/F64 and that the input,
/// min and max tensors all have the output's shape, then builds the shared
/// element-wise CUDA descriptor (which also determines _workspace_size).
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    // Clip takes exactly three inputs: x, min_val, max_val.
    const auto &in_desc = input_desc_vec.at(0);
    const auto &min_desc = input_desc_vec.at(1);
    const auto &max_desc = input_desc_vec.at(2);
    const auto &out_shape = out_desc->shape();
    const auto &in_shape = in_desc->shape();
    const auto &min_shape = min_desc->shape();
    const auto &max_shape = max_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // All operands must be element-wise compatible with the output.
    CHECK_SAME_SHAPE(out_shape, in_shape);
    CHECK_SAME_SHAPE(out_shape, min_shape);
    CHECK_SAME_SHAPE(out_shape, max_shape);

    CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}

/// Launches clip on the given stream:
/// output[i] = max(min(x[i], max_val[i]), min_val[i]).
/// The leading 256 template argument is a compile-time launch parameter of
/// the element-wise engine (presumably the thread-block size — confirm in
/// elementwise_cuda.cuh).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    // The caller must supply at least the workspace reported at creation time.
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, ClipOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, ClipOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, ClipOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Every switch branch returns, so the previous trailing
    // `return INFINI_STATUS_SUCCESS;` was unreachable and has been removed.
}
} // namespace op::clip::cuda
src/infiniop/ops/clip/cuda/clip_cuda.cuh
0 → 100644
View file @
c2e87202
#ifndef __CLIP_CUDA_API_H__
#define __CLIP_CUDA_API_H__

#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
#include "infiniop/ops/clip.h"

// Declares op::clip::cuda::Descriptor via the shared element-wise
// descriptor macro; the member definitions live in clip_cuda.cu.
ELEMENTWISE_DESCRIPTOR(clip, cuda)

#endif // __CLIP_CUDA_API_H__
src/infiniop/ops/clip/cuda/clip_cuda_internal.cuh
0 → 100644
View file @
c2e87202
#ifndef __CLIP_CUDA_H__
#define __CLIP_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_fp16.h>
namespace op::clip::cuda {

/// Device-side clip functor for the generic element-wise CUDA engine.
/// Computes max(min(x, max_val), min_val), selecting the matching CUDA
/// intrinsic for T at compile time via `if constexpr`.
/// (Replaces the C-style `typedef struct ... ClipOp;` — in C++ a plain
/// struct definition is sufficient, and struct members are public by default.)
struct ClipOp {
    // Clip consumes three input tensors: x, min_val, max_val.
    static constexpr size_t num_inputs = 3;

    template <typename T>
    __device__ __forceinline__ T operator()(const T &x, const T &min_val, const T &max_val) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Packed half2: clamp both lanes in one instruction pair.
            return __hmax2(__hmin2(x, max_val), min_val);
        } else if constexpr (std::is_same_v<T, half>) {
            return __hmax(__hmin(x, max_val), min_val);
        } else if constexpr (std::is_same_v<T, float>) {
            return fmaxf(fminf(x, max_val), min_val);
        } else if constexpr (std::is_same_v<T, double>) {
            return fmax(fmin(x, max_val), min_val);
        } else {
            // Generic fallback for any other comparable element type.
            return std::max(std::min(x, max_val), min_val);
        }
    }
};
} // namespace op::clip::cuda
#endif // __CLIP_CUDA_H__
src/infiniop/ops/clip/operator.cc
0 → 100644
View file @
c2e87202
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/clip.h"
#ifdef ENABLE_CPU_API
#include "cpu/clip_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/clip_cuda.cuh"
#endif
/// Creates a clip descriptor on whichever backend matches the handle's
/// device, forwarding {x, min_val, max_val} as the input tensor list.
/// The dispatch macro of the original has been expanded into explicit
/// case labels for readability; behavior is unchanged.
__C infiniStatus_t infiniopCreateClipDescriptor(
    infiniopHandle_t handle,
    infiniopClipDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y,
    infiniopTensorDescriptor_t x,
    infiniopTensorDescriptor_t min_val,
    infiniopTensorDescriptor_t max_val) {

    switch (handle->device) {
#ifdef ENABLE_CPU_API
    case INFINI_DEVICE_CPU:
        return op::clip::cpu::Descriptor::create(
            handle,
            reinterpret_cast<op::clip::cpu::Descriptor **>(desc_ptr),
            y,
            {x, min_val, max_val});
#endif
#ifdef ENABLE_CUDA_API
    case INFINI_DEVICE_NVIDIA:
        return op::clip::cuda::Descriptor::create(
            handle,
            reinterpret_cast<op::clip::cuda::Descriptor **>(desc_ptr),
            y,
            {x, min_val, max_val});
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
/// Reports the workspace size required by a clip descriptor.
/// Consistency fix: the not-supported return now lives in a `default:`
/// label inside the switch, matching the other clip entry points
/// (create / calculate / destroy); behavior is identical, since the
/// original fell through the empty/unmatched switch to the same return.
__C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, size_t *size) {

#define GET(CASE, NAMESPACE)                                                                \
    case CASE:                                                                              \
        *size = reinterpret_cast<op::clip::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
        GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET
}
/// Executes clip: y[i] = max(min(x[i], min of max_val[i]), min_val[i]) on the
/// backend recorded in the descriptor, passing the caller's workspace and stream.
/// The dispatch macro of the original has been expanded into explicit
/// case labels for readability; behavior is unchanged.
__C infiniStatus_t infiniopClip(
    infiniopClipDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *min_val,
    const void *max_val,
    void *stream) {

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
    case INFINI_DEVICE_CPU:
        return reinterpret_cast<const op::clip::cpu::Descriptor *>(desc)
            ->calculate(workspace, workspace_size, y, {x, min_val, max_val}, stream);
#endif
#ifdef ENABLE_CUDA_API
    case INFINI_DEVICE_NVIDIA:
        return reinterpret_cast<const op::clip::cuda::Descriptor *>(desc)
            ->calculate(workspace, workspace_size, y, {x, min_val, max_val}, stream);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
/// Destroys a clip descriptor by deleting the backend-specific object the
/// opaque handle points to.
/// The dispatch macro of the original has been expanded into explicit
/// case labels for readability; behavior is unchanged.
__C infiniStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc) {
    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
    case INFINI_DEVICE_CPU:
        delete reinterpret_cast<const op::clip::cpu::Descriptor *>(desc);
        return INFINI_STATUS_SUCCESS;
#endif
#ifdef ENABLE_CUDA_API
    case INFINI_DEVICE_NVIDIA:
        delete reinterpret_cast<const op::clip::cuda::Descriptor *>(desc);
        return INFINI_STATUS_SUCCESS;
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
src/infiniop/ops/gemm/ascend/gemm_ascend.cc
View file @
c2e87202
This diff is collapsed.
Click to expand it.
src/infiniop/ops/gemm/cpu/gemm_cpu.cc
View file @
c2e87202
...
@@ -40,31 +40,34 @@ void calculate(
...
@@ -40,31 +40,34 @@ void calculate(
std
::
swap
(
a
,
b
);
std
::
swap
(
a
,
b
);
}
}
for
(
size_t
i
=
0
;
i
<
info
.
batch
;
++
i
)
{
#pragma omp parallel for
for
(
size_t
m_
=
0
;
m_
<
info
.
m
;
++
m_
)
{
for
(
ptrdiff_t
index
=
0
;
index
<
ptrdiff_t
(
info
.
batch
*
info
.
m
*
info
.
n
);
++
index
)
{
for
(
size_t
n_
=
0
;
n_
<
info
.
n
;
++
n_
)
{
size_t
ind
=
index
;
auto
c_
=
reinterpret_cast
<
Tdata
*>
(
c
)
+
i
*
info
.
c_matrix
.
stride
+
m_
*
info
.
c_matrix
.
row_stride
+
n_
*
info
.
c_matrix
.
col_stride
;
size_t
n_
=
ind
%
info
.
n
;
float
sum
=
0
;
ind
/=
info
.
n
;
for
(
size_t
k_
=
0
;
k_
<
info
.
k
;
++
k_
)
{
size_t
m_
=
ind
%
info
.
m
;
auto
a_
=
reinterpret_cast
<
const
Tdata
*>
(
a
)
+
i
*
info
.
a_matrix
.
stride
+
m_
*
info
.
a_matrix
.
row_stride
+
k_
*
info
.
a_matrix
.
col_stride
;
ind
/=
info
.
m
;
auto
b_
=
reinterpret_cast
<
const
Tdata
*>
(
b
)
+
i
*
info
.
b_matrix
.
stride
+
n_
*
info
.
b_matrix
.
col_stride
+
k_
*
info
.
b_matrix
.
row_stride
;
size_t
i
=
ind
;
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
auto
c_
=
reinterpret_cast
<
Tdata
*>
(
c
)
+
i
*
info
.
c_matrix
.
stride
+
m_
*
info
.
c_matrix
.
row_stride
+
n_
*
info
.
c_matrix
.
col_stride
;
sum
+=
utils
::
cast
<
float
>
(
*
a_
)
*
utils
::
cast
<
float
>
(
*
b_
);
float
sum
=
0
;
}
else
{
for
(
int
k_
=
0
;
k_
<
static_cast
<
int
>
(
info
.
k
);
++
k_
)
{
sum
+=
*
a_
*
(
*
b_
);
auto
a_
=
reinterpret_cast
<
const
Tdata
*>
(
a
)
+
i
*
info
.
a_matrix
.
stride
+
m_
*
info
.
a_matrix
.
row_stride
+
k_
*
info
.
a_matrix
.
col_stride
;
}
auto
b_
=
reinterpret_cast
<
const
Tdata
*>
(
b
)
+
i
*
info
.
b_matrix
.
stride
+
n_
*
info
.
b_matrix
.
col_stride
+
k_
*
info
.
b_matrix
.
row_stride
;
}
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
sum
+=
utils
::
cast
<
float
>
(
*
a_
)
*
utils
::
cast
<
float
>
(
*
b_
);
if
(
beta
==
0
)
{
}
else
{
*
c_
=
utils
::
cast
<
fp16_t
>
(
alpha
*
sum
);
sum
+=
*
a_
*
(
*
b_
);
}
else
{
*
c_
=
utils
::
cast
<
fp16_t
>
(
beta
*
utils
::
cast
<
float
>
(
*
c_
)
+
alpha
*
sum
);
}
}
else
{
*
c_
=
beta
*
(
*
c_
)
+
alpha
*
sum
;
}
}
}
}
}
if
constexpr
(
std
::
is_same
<
Tdata
,
fp16_t
>::
value
)
{
if
(
beta
==
0
)
{
*
c_
=
utils
::
cast
<
fp16_t
>
(
alpha
*
sum
);
}
else
{
*
c_
=
utils
::
cast
<
fp16_t
>
(
beta
*
utils
::
cast
<
float
>
(
*
c_
)
+
alpha
*
sum
);
}
}
else
{
*
c_
=
beta
*
(
*
c_
)
+
alpha
*
sum
;
}
}
}
}
}
...
...
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc
View file @
c2e87202
...
@@ -62,7 +62,7 @@ infiniStatus_t calculate(
...
@@ -62,7 +62,7 @@ infiniStatus_t calculate(
(
kunlunStream_t
)
stream
,
(
kunlunStream_t
)
stream
,
[
&
](
xdnnHandle_t
handle
)
{
[
&
](
xdnnHandle_t
handle
)
{
for
(
size_t
i
=
0
;
i
<
info
.
batch
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
info
.
batch
;
i
++
)
{
CHECK_
XDN
N
((
xdnn
::
fc_fusion
<
Tdata
,
Tdata
,
Tdata
,
int16_t
>
(
CHECK_
KUNLU
N
((
xdnn
::
fc_fusion
<
Tdata
,
Tdata
,
Tdata
,
int16_t
>
(
handle
,
handle
,
(
Tdata
*
)((
char
*
)
a
+
i
*
info
.
a_matrix
.
stride
*
unit
),
(
Tdata
*
)((
char
*
)
a
+
i
*
info
.
a_matrix
.
stride
*
unit
),
(
Tdata
*
)((
char
*
)
b
+
i
*
info
.
b_matrix
.
stride
*
unit
),
(
Tdata
*
)((
char
*
)
b
+
i
*
info
.
b_matrix
.
stride
*
unit
),
...
...
src/infiniop/ops/gemm/musa/gemm_musa.mu
View file @
c2e87202
...
@@ -23,14 +23,11 @@ infiniStatus_t Descriptor::create(
...
@@ -23,14 +23,11 @@ infiniStatus_t Descriptor::create(
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
infiniStatus_t status;
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
*desc_ptr = new Descriptor(
*desc_ptr = new Descriptor(
dtype,
info
, 0,
dtype,
result.take()
, 0,
new Opaque{handle->internal()},
new Opaque{handle->internal()},
handle->device, handle->device_id);
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
...
...
src/infiniop/ops/mul/cpu/mul_cpu.cc
0 → 100644
View file @
c2e87202
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment