Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
e184c7e4
Commit
e184c7e4
authored
Jul 25, 2025
by
tianyuxbear
Committed by
zhuyue
Oct 28, 2025
Browse files
issue/456/feat: add silu operator
parent
d4f72bdd
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
725 additions
and
1 deletion
+725
-1
include/infiniop.h
include/infiniop.h
+1
-0
include/infiniop/ops/silu.h
include/infiniop/ops/silu.h
+24
-0
src/infiniop-test/include/ops.hpp
src/infiniop-test/include/ops.hpp
+2
-0
src/infiniop-test/src/ops/silu.cpp
src/infiniop-test/src/ops/silu.cpp
+101
-0
src/infiniop/ops/silu/cpu/silu_cpu.cc
src/infiniop/ops/silu/cpu/silu_cpu.cc
+52
-0
src/infiniop/ops/silu/cpu/silu_cpu.h
src/infiniop/ops/silu/cpu/silu_cpu.h
+23
-0
src/infiniop/ops/silu/cuda/kernel.cuh
src/infiniop/ops/silu/cuda/kernel.cuh
+37
-0
src/infiniop/ops/silu/metax/silu_metax.h
src/infiniop/ops/silu/metax/silu_metax.h
+8
-0
src/infiniop/ops/silu/metax/silu_metax.maca
src/infiniop/ops/silu/metax/silu_metax.maca
+60
-0
src/infiniop/ops/silu/nvidia/silu_nvidia.cu
src/infiniop/ops/silu/nvidia/silu_nvidia.cu
+59
-0
src/infiniop/ops/silu/nvidia/silu_nvidia.cuh
src/infiniop/ops/silu/nvidia/silu_nvidia.cuh
+8
-0
src/infiniop/ops/silu/operator.cc
src/infiniop/ops/silu/operator.cc
+142
-0
test/infiniop/libinfiniop/op_register.py
test/infiniop/libinfiniop/op_register.py
+32
-0
test/infiniop/libinfiniop/utils.py
test/infiniop/libinfiniop/utils.py
+4
-1
test/infiniop/silu.py
test/infiniop/silu.py
+172
-0
No files found.
include/infiniop.h
View file @
e184c7e4
...
...
@@ -16,6 +16,7 @@
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rope.h"
#include "infiniop/ops/silu.h"
#include "infiniop/ops/softplus.h"
#include "infiniop/ops/sub.h"
#include "infiniop/ops/swiglu.h"
...
...
include/infiniop/ops/silu.h
0 → 100644
View file @
e184c7e4
#ifndef __INFINIOP_SILU_API_H__
#define __INFINIOP_SILU_API_H__

#include "../operator_descriptor.h"

/// Opaque descriptor handle for the SiLU operator.
typedef struct InfiniopDescriptor *infiniopSiluDescriptor_t;

/// Create a SiLU descriptor for the given tensor pair.
/// `output` and `input` must have identical shapes (validated by the backend).
__C __export infiniStatus_t infiniopCreateSiluDescriptor(
    infiniopHandle_t handle,
    infiniopSiluDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t output,
    infiniopTensorDescriptor_t input); // fixed typo: was "intput"

/// Query the workspace size in bytes required by infiniopSilu.
__C __export infiniStatus_t infiniopGetSiluWorkspaceSize(
    infiniopSiluDescriptor_t desc,
    size_t *size);

/// Compute output = silu(input) = input / (1 + exp(-input)).
/// `workspace` must be at least the size reported by
/// infiniopGetSiluWorkspaceSize; `stream` may be null for the default stream.
__C __export infiniStatus_t infiniopSilu(
    infiniopSiluDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *output,
    const void *input, // fixed typo: was "intput"
    void *stream);

/// Destroy a descriptor created by infiniopCreateSiluDescriptor.
__C __export infiniStatus_t infiniopDestroySiluDescriptor(
    infiniopSiluDescriptor_t desc);

#endif
src/infiniop-test/include/ops.hpp
View file @
e184c7e4
...
...
@@ -15,6 +15,7 @@ DECLARE_INFINIOP_TEST(swiglu)
DECLARE_INFINIOP_TEST
(
add
)
DECLARE_INFINIOP_TEST
(
causal_softmax
)
DECLARE_INFINIOP_TEST
(
rearrange
)
DECLARE_INFINIOP_TEST
(
silu
)
DECLARE_INFINIOP_TEST
(
sub
)
DECLARE_INFINIOP_TEST
(
zeros
)
DECLARE_INFINIOP_TEST
(
ones
)
...
...
@@ -53,6 +54,7 @@ DECLARE_INFINIOP_TEST(topksoftmax)
REGISTER_INFINIOP_TEST(sigmoid) \
REGISTER_INFINIOP_TEST(topkrouter) \
REGISTER_INFINIOP_TEST(topksoftmax) \
REGISTER_INFINIOP_TEST(silu) \
}
namespace
infiniop_test
{
...
...
src/infiniop-test/src/ops/silu.cpp
0 → 100644
View file @
e184c7e4
#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>
namespace infiniop_test::silu {

// Tensors attached to one SiLU test case: the input, the output buffer,
// and the precomputed reference answer.
struct Test::Attributes {
    std::shared_ptr<Tensor> input;
    std::shared_ptr<Tensor> output;
    std::shared_ptr<Tensor> ans;
};

// Build a Test from the deserialized attribute/tensor maps.
// Throws std::runtime_error when any of the three required tensors is missing.
std::shared_ptr<Test> Test::build(
    std::unordered_map<std::string, std::vector<uint8_t>> attributes,
    std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
    double rtol,
    double atol) {
    auto test = std::shared_ptr<Test>(new Test(rtol, atol));
    test->_attributes = new Attributes();
    if (tensors.find("input") == tensors.end()
        || tensors.find("output") == tensors.end()
        || tensors.find("ans") == tensors.end()) {
        throw std::runtime_error("Invalid Test");
    }
    test->_attributes->input = tensors["input"];
    test->_attributes->output = tensors["output"];
    test->_attributes->ans = tensors["ans"];
    return test;
}

// Run the operator on `device`, compare against the reference answer,
// then benchmark it for `iterations` timed runs after `warm_ups` warmups.
std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id,
    size_t warm_ups, size_t iterations) {
    infiniopSiluDescriptor_t op_desc;
    auto input = _attributes->input->to(device, device_id);
    auto output = _attributes->output->to(device, device_id);
    CHECK_OR(infiniopCreateSiluDescriptor(handle, &op_desc, output->desc(), input->desc()),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));
    size_t workspace_size;
    CHECK_OR(infiniopGetSiluWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
    void *workspace;
    CHECK_OR(infinirtMalloc(&workspace, workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
    CHECK_OR(infiniopSilu(op_desc, workspace, workspace_size,
                          output->data(), input->data(), nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));
    try {
        allClose(output, _attributes->ans, _rtol, _atol);
    } catch (const std::exception &e) {
        // Release device resources before reporting the mismatch
        // (the original leaked both on this path).
        // NOTE(review): assumes infinirtFree is the infinirt counterpart of
        // infinirtMalloc — confirm against infinirt.h.
        infinirtFree(workspace);
        infiniopDestroySiluDescriptor(op_desc);
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }
    double elapsed_time = 0.;
    elapsed_time = benchmark(
        [=]() {
            infiniopSilu(op_desc, workspace, workspace_size,
                         output->data(), input->data(), nullptr);
        },
        warm_ups, iterations);
    // Fix: free the workspace and destroy the descriptor; the original
    // leaked both on every successful run.
    infinirtFree(workspace);
    infiniopDestroySiluDescriptor(op_desc);
    return TEST_PASSED(elapsed_time);
}

std::vector<std::string> Test::attribute_names() {
    return {};
}

std::vector<std::string> Test::tensor_names() {
    return {"input", "output", "ans"};
}

std::vector<std::string> Test::output_names() {
    return {"output"};
}

// Human-readable summary: op name, tensor infos, and tolerances.
std::string Test::toString() const {
    std::ostringstream oss;
    oss << op_name() << std::endl;
    oss << "- input: " << _attributes->input->info() << std::endl;
    oss << "- output: " << _attributes->output->info() << std::endl;
    oss << std::scientific << std::setprecision(2);
    oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
    return oss.str();
}

Test::~Test() {
    delete _attributes;
}

} // namespace infiniop_test::silu
src/infiniop/ops/silu/cpu/silu_cpu.cc
0 → 100644
View file @
e184c7e4
#include "silu_cpu.h"
namespace op::silu::cpu {

Descriptor::~Descriptor() = default;

// Validate dtype and shapes, then build the CPU elementwise descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto dtype = out_desc->dtype();
    const auto &input_desc = input_desc_vec.at(0);
    const auto &output_shape = out_desc->shape();
    const auto &input_shape = input_desc->shape();
    CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(output_shape, input_shape);
    // create CPU elementwise descriptor
    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
    return INFINI_STATUS_SUCCESS;
}

// Dispatch on the stored dtype and run the elementwise SiLU kernel.
// The CPU path needs no scratch memory, so workspace/workspace_size are unused.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    switch (_dtype) {
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<SiluOp, bf16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<SiluOp, fp16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<SiluOp, float>(_info, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<SiluOp, double>(_info, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // (removed: unreachable `return INFINI_STATUS_SUCCESS;` — every case and
    // the default already return)
}

} // namespace op::silu::cpu
src/infiniop/ops/silu/cpu/silu_cpu.h
0 → 100644
View file @
e184c7e4
#ifndef __SILU_CPU_H__
#define __SILU_CPU_H__
#include "../../../elementwise/cpu/elementwise_cpu.h"
ELEMENTWISE_DESCRIPTOR
(
silu
,
cpu
)
#include <cmath>
namespace op::silu::cpu {

// SiLU (swish) activation functor for the CPU elementwise framework:
// silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
// Replaced the C-style `typedef struct ... SiluOp;` with a plain struct and
// dropped the redundant `public:` (struct members are public by default).
struct SiluOp {
    // Elementwise framework contract: number of input tensors consumed.
    static constexpr size_t num_inputs = 1;

    template <typename T>
    T operator()(const T &x) const {
        return x / (static_cast<T>(1) + std::exp(-x));
    }
};

} // namespace op::silu::cpu
#endif // __SILU_CPU_H__
src/infiniop/ops/silu/cuda/kernel.cuh
0 → 100644
View file @
e184c7e4
#ifndef __SILU_CUDA_H__
#define __SILU_CUDA_H__
#include <cmath>
namespace op::silu::cuda {

// Device-side SiLU functor shared by the NVIDIA and METAX backends:
// silu(x) = x / (1 + exp(-x)), with type-specific fast paths.
// Replaced `typedef struct ... SiluOp;` with a plain struct, and turned the
// final `if constexpr (double)` branch into a plain `else`: previously an
// unlisted T made control fall off the end of a non-void function (UB).
struct SiluOp {
public:
    static constexpr size_t num_inputs = 1;

    template <typename T>
    __device__ __forceinline__ T operator()(const T &x) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Vectorized half2 path: x * (1 / (1 + exp(-x))) via intrinsics.
            return __hmul2(x, __h2div(__float2half2_rn(1.0f),
                                      __hadd2(__float2half2_rn(1.0f),
                                              h2exp(__hneg2(x)))));
        } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
            // BF16: compute in float for accuracy, then round back.
            const float x_f = __bfloat162float(x);
            return __float2bfloat16(x_f / (1.0f + __expf(-x_f)));
        } else if constexpr (std::is_same_v<T, half>) {
            // FP16: compute in float for accuracy, then round back.
            const float x_f = __half2float(x);
            return __float2half(x_f / (1.0f + __expf(-x_f)));
        } else if constexpr (std::is_same_v<T, float>) {
            // FP32 with the fast exponential intrinsic.
            return x * (1.0f / (1.0f + __expf(-x)));
        } else {
            // FP64 (and any other arithmetic type): plain double-precision formula.
            return x / (1.0 + exp(-x));
        }
    }
};

} // namespace op::silu::cuda
#endif // __SILU_CUDA_H__
src/infiniop/ops/silu/metax/silu_metax.h
0 → 100644
View file @
e184c7e4
#ifndef __SILU_METAX_API_H__
#define __SILU_METAX_API_H__
#include "../../../elementwise/metax/elementwise_metax_api.h"
ELEMENTWISE_DESCRIPTOR
(
silu
,
metax
)
#endif // __SILU_METAX_API_H__
src/infiniop/ops/silu/metax/silu_metax.maca
0 → 100644
View file @
e184c7e4
#include "silu_metax.h"
#include "../../../elementwise/metax/elementwise_metax.h"
#include "../cuda/kernel.cuh"
namespace op::silu::metax {

Descriptor::~Descriptor() = default;

// Validate dtype and shapes, then build the METAX elementwise descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
    auto dtype = out_desc->dtype();
    const auto &input_desc = input_desc_vec.at(0);
    const auto &output_shape = out_desc->shape();
    const auto &input_shape = input_desc->shape();
    CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(output_shape, input_shape);
    // create METAX elementwise descriptor
    CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
    return INFINI_STATUS_SUCCESS;
}

// Dispatch on dtype and launch with 256-thread blocks, reusing the shared
// CUDA SiluOp functor.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    switch (_dtype) {
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::SiluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::SiluOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::SiluOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::SiluOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // (removed: unreachable `return INFINI_STATUS_SUCCESS;` — every case and
    // the default already return)
}

} // namespace op::silu::metax
src/infiniop/ops/silu/nvidia/silu_nvidia.cu
0 → 100644
View file @
e184c7e4
#include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
#include "../cuda/kernel.cuh"
#include "silu_nvidia.cuh"
namespace op::silu::nvidia {

Descriptor::~Descriptor() = default;

// Validate dtype and shapes, then build the CUDA elementwise descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
    auto dtype = out_desc->dtype();
    const auto &input_desc = input_desc_vec.at(0);
    const auto &output_shape = out_desc->shape();
    const auto &input_shape = input_desc->shape();
    CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(output_shape, input_shape);
    // create CUDA elementwise descriptor
    CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
    return INFINI_STATUS_SUCCESS;
}

// Dispatch on dtype and launch with 256-thread blocks using the shared
// CUDA SiluOp functor.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    switch (_dtype) {
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::SiluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::SiluOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::SiluOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::SiluOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // (removed: unreachable `return INFINI_STATUS_SUCCESS;` — every case and
    // the default already return)
}

} // namespace op::silu::nvidia
src/infiniop/ops/silu/nvidia/silu_nvidia.cuh
0 → 100644
View file @
e184c7e4
#ifndef __SILU_CUDA_API_H__
#define __SILU_CUDA_API_H__
#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"
ELEMENTWISE_DESCRIPTOR
(
silu
,
nvidia
)
#endif // __SILU_CUDA_API_H__
src/infiniop/ops/silu/operator.cc
0 → 100644
View file @
e184c7e4
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/silu.h"
#ifdef ENABLE_CPU_API
#include "cpu/silu_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/silu_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/silu_metax.h"
#endif
// Create a backend-specific SiLU descriptor for the handle's device type.
// Dispatches to the namespace selected at compile time via ENABLE_*_API.
__C infiniStatus_t infiniopCreateSiluDescriptor(
    infiniopHandle_t handle,
    infiniopSiluDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t output_desc,
    infiniopTensorDescriptor_t input_desc) {

#define CREATE(CASE, NAMESPACE)                                             \
    case CASE:                                                              \
        return op::silu::NAMESPACE::Descriptor::create(                     \
            handle,                                                         \
            reinterpret_cast<op::silu::NAMESPACE::Descriptor **>(desc_ptr), \
            output_desc,                                                    \
            {input_desc})

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // Iluvatar reuses the NVIDIA implementation.
        CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
        CREATE(INFINI_DEVICE_METAX, metax);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}
// Report the workspace size (bytes) required by the descriptor's backend.
__C infiniStatus_t infiniopGetSiluWorkspaceSize(
    infiniopSiluDescriptor_t desc,
    size_t *size) {

#define GET(CASE, NAMESPACE)                                                             \
    case CASE:                                                                           \
        *size = reinterpret_cast<op::silu::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // Iluvatar reuses the NVIDIA implementation.
        GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET
    // (removed: duplicated, unreachable
    // `return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;` after the switch —
    // every case and the default already return)
}
// Execute SiLU: output = silu(input), dispatched to the backend matching
// the descriptor's device type. `stream` may be null for the default stream.
__C infiniStatus_t infiniopSilu(
    infiniopSiluDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *output,
    const void *input,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE)                                     \
    case CASE:                                                         \
        return reinterpret_cast<const op::silu::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, output, {input}, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // Iluvatar reuses the NVIDIA implementation.
        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
        CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}
// Destroy a SiLU descriptor, deleting the backend-specific object it wraps.
__C infiniStatus_t infiniopDestroySiluDescriptor(
    infiniopSiluDescriptor_t desc) {

#define DELETE(CASE, NAMESPACE)                                                 \
    case CASE:                                                                  \
        delete reinterpret_cast<const op::silu::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // Iluvatar reuses the NVIDIA implementation.
        DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
        DELETE(INFINI_DEVICE_METAX, metax);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DELETE
}
test/infiniop/libinfiniop/op_register.py
View file @
e184c7e4
...
...
@@ -736,3 +736,35 @@ def ones_(lib):
lib
.
infiniopDestroyOnesDescriptor
.
argtypes
=
[
infiniopOperatorDescriptor_t
,
]
@OpRegister.operator
def silu_(lib):
    """Register ctypes signatures for the SiLU operator entry points on `lib`."""
    # Descriptor creation: (handle, out desc ptr, output desc, input desc).
    lib.infiniopCreateSiluDescriptor.restype = c_int32
    lib.infiniopCreateSiluDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]
    # Workspace query: (descriptor, out size ptr).
    lib.infiniopGetSiluWorkspaceSize.restype = c_int32
    lib.infiniopGetSiluWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]
    # Execution: (descriptor, workspace, workspace_size, output, input, stream).
    lib.infiniopSilu.restype = c_int32
    lib.infiniopSilu.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]
    # Descriptor destruction: (descriptor,).
    lib.infiniopDestroySiluDescriptor.restype = c_int32
    lib.infiniopDestroySiluDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]
test/infiniop/libinfiniop/utils.py
View file @
e184c7e4
...
...
@@ -143,6 +143,9 @@ class TestTensor(CTensor):
shape_
,
strides_
,
dt
,
device
,
mode
=
"manual"
,
set_tensor
=
torch_tensor
)
def update_torch_tensor(self, new_tensor: torch.Tensor):
    """Replace the reference torch tensor held by this TestTensor.

    Used when a test computes the expected result and wants subsequent
    `torch_tensor()` reads to return it.
    """
    self._torch_tensor = new_tensor
def
to_torch_dtype
(
dt
:
InfiniDtype
,
compatability_mode
=
False
):
if
dt
==
InfiniDtype
.
BOOL
:
...
...
@@ -607,7 +610,7 @@ def profile_operation(desc, func, torch_device, NUM_PRERUN, NUM_ITERATIONS):
# Timed execution
elapsed
=
timed_op
(
lambda
:
func
(),
NUM_ITERATIONS
,
torch_device
)
print
(
f
"
{
desc
}
time:
{
elapsed
*
1000
:
6
f
}
ms"
)
print
(
f
"
{
desc
}
time:
{
elapsed
*
1000
:
6
f
}
ms"
)
def
test_operator
(
device
,
test_func
,
test_cases
,
tensor_dtypes
):
...
...
test/infiniop/silu.py
0 → 100644
View file @
e184c7e4
import
torch
import
ctypes
from
ctypes
import
c_uint64
from
libinfiniop
import
(
LIBINFINIOP
,
TestTensor
,
get_test_devices
,
check_error
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
TestWorkspace
,
InfiniDtype
,
InfiniDtypeNames
,
InfiniDeviceNames
,
infiniopOperatorDescriptor_t
,
)
from
enum
import
Enum
,
auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # shape, input_stride, output_stride
    ((13, 4), None, None),
    ((13, 4), (10, 1), (10, 1)),
    ((13, 4), (0, 1), None),
    ((13, 4, 4), None, None),
    ((13, 4, 4), (20, 4, 1), (20, 4, 1)),
    ((13, 4, 4), (4, 0, 1), None),
    ((16, 5632), None, None),
    ((16, 5632), (13312, 1), (13312, 1)),
    ((4, 4, 5632), None, None),
    ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1)),
]


class Inplace(Enum):
    OUT_OF_PLACE = auto()
    INPLACE = auto()


# Inplace options applied for each test case in _TEST_CASES_
_INPLACE = [
    Inplace.OUT_OF_PLACE,
    Inplace.INPLACE,
]

# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
_TEST_CASES = [case + (mode,) for case in _TEST_CASES_ for mode in _INPLACE]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.F64]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
    InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7},
    InfiniDtype.F64: {"atol": 2.22e-15, "rtol": 2.22e-15},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def test(
    handle,
    device,
    shape,
    input_stride=None,
    output_stride=None,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=torch.float16,
    sync=None,
):
    """Run one SiLU case: build tensors, call the library op, compare to torch."""
    input = TestTensor(shape, input_stride, dtype, device)
    if inplace == Inplace.INPLACE:
        # In-place only makes sense when both views share a layout.
        if input_stride != output_stride:
            return
        output = input
    else:
        output = TestTensor(shape, output_stride, dtype, device, mode="ones")

    if output.is_broadcast():
        return

    print(
        f"Testing Silu on {InfiniDeviceNames[device]} with shape:{shape} input_stride:{input_stride} output_stride:{output_stride} "
        f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
    )

    # Reference result via PyTorch, stored as the expected torch tensor.
    new_output = torch.nn.functional.silu(input.torch_tensor())
    output.update_torch_tensor(new_output)
    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateSiluDescriptor(
            handle,
            ctypes.byref(descriptor),
            output.descriptor,
            input.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [input, output]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetSiluWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, output.device)

    def lib_silu():
        check_error(
            LIBINFINIOP.infiniopSilu(
                descriptor,
                workspace.data(),
                workspace.size(),
                output.data(),
                input.data(),
                None,
            )
        )

    lib_silu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(output.actual_tensor(), output.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(
        output.actual_tensor(), output.torch_tensor(), atol=atol, rtol=rtol
    )

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: torch.nn.functional.silu(input.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_silu(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(LIBINFINIOP.infiniopDestroySiluDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()

    # Configure testing options from the command line.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment