Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
f53b8435
Unverified
Commit
f53b8435
authored
Dec 11, 2025
by
zhangyue
Committed by
GitHub
Dec 11, 2025
Browse files
Merge pull request #755 from InfiniTensor/issue/753
issue/753: kunlun gelu kernel
parents
681f4e1e
579bb1bf
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
176 additions
and
129 deletions
+176
-129
src/infiniop/elementwise/kunlun/elementwise_kunlun.h
src/infiniop/elementwise/kunlun/elementwise_kunlun.h
+0
-15
src/infiniop/ops/add/kunlun/add_kunlun.xpu
src/infiniop/ops/add/kunlun/add_kunlun.xpu
+3
-10
src/infiniop/ops/clip/kunlun/clip_kunlun.xpu
src/infiniop/ops/clip/kunlun/clip_kunlun.xpu
+0
-9
src/infiniop/ops/gelu/kunlun/gelu_kunlun.h
src/infiniop/ops/gelu/kunlun/gelu_kunlun.h
+8
-0
src/infiniop/ops/gelu/kunlun/gelu_kunlun.xpu
src/infiniop/ops/gelu/kunlun/gelu_kunlun.xpu
+56
-0
src/infiniop/ops/gelu/kunlun/kernel.h
src/infiniop/ops/gelu/kunlun/kernel.h
+32
-0
src/infiniop/ops/gelu/operator.cc
src/infiniop/ops/gelu/operator.cc
+15
-0
src/infiniop/ops/mul/kunlun/mul_kunlun.xpu
src/infiniop/ops/mul/kunlun/mul_kunlun.xpu
+0
-9
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
...nfiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
+62
-67
src/infiniop/ops/sub/kunlun/sub_kunlun.xpu
src/infiniop/ops/sub/kunlun/sub_kunlun.xpu
+0
-9
src/infiniop/ops/swiglu/kunlun/swiglu_kunlun.xpu
src/infiniop/ops/swiglu/kunlun/swiglu_kunlun.xpu
+0
-10
No files found.
src/infiniop/elementwise/kunlun/elementwise_kunlun.h
View file @
f53b8435
...
...
@@ -312,21 +312,6 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
std
::
forward
<
Args
>
(
args
)...);
}
#define INSTANTIATE_ELEMENTWISE_KERNEL(N, Op, Tdata, ...) \
template __global__ void elementwiseKernel<N, Op, Tdata, ##__VA_ARGS__>( \
int output_size, \
int ndim, \
bool output_contiguous, \
const bool *input_contiguous_gm, \
const bool *input_broadcasted_gm, \
const void *output_shape_gm, \
const void *input_shapes_gm, \
const void *output_strides_gm, \
const void *input_strides_gm, \
Tdata *output, \
const void *const *inputs, \
##__VA_ARGS__);
}
// namespace op::elementwise::kunlun
#endif
src/infiniop/ops/add/kunlun/add_kunlun.xpu
View file @
f53b8435
...
...
@@ -2,15 +2,6 @@
#include "add_kunlun.h"
#include "kernel.h"
namespace op::elementwise::kunlun {
using AddOp = op::add::kunlun::AddOp;
INSTANTIATE_ELEMENTWISE_KERNEL(AddOp::num_inputs, AddOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(AddOp::num_inputs, AddOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(AddOp::num_inputs, AddOp, bfloat16_t);
} // namespace op::elementwise::kunlun
namespace op::add::kunlun {
Descriptor::~Descriptor() = default;
...
...
@@ -30,7 +21,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16
, INFINI_DTYPE_I32
);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
...
...
@@ -58,6 +49,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<8, AddOp, bfloat16_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<8, AddOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_I32:
return _device_info->calculate<8, AddOp, int32_t>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
...
...
src/infiniop/ops/clip/kunlun/clip_kunlun.xpu
View file @
f53b8435
...
...
@@ -2,15 +2,6 @@
#include "clip_kunlun.h"
#include "kernel.h"
namespace op::elementwise::kunlun {
using ClipOp = op::clip::kunlun::ClipOp;
INSTANTIATE_ELEMENTWISE_KERNEL(ClipOp::num_inputs, ClipOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(ClipOp::num_inputs, ClipOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(ClipOp::num_inputs, ClipOp, bfloat16_t);
} // namespace op::elementwise::kunlun
namespace op::clip::kunlun {
Descriptor::~Descriptor() = default;
...
...
src/infiniop/ops/gelu/kunlun/gelu_kunlun.h
0 → 100644
View file @
f53b8435
#ifndef __GELU_KUNLUN_H__
#define __GELU_KUNLUN_H__

#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"

// Declares the op::gelu::kunlun Descriptor via the shared elementwise
// descriptor macro (namespace pair: op name "gelu", device "kunlun").
ELEMENTWISE_DESCRIPTOR(gelu, kunlun)

#endif
src/infiniop/ops/gelu/kunlun/gelu_kunlun.xpu
0 → 100644
View file @
f53b8435
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "gelu_kunlun.h"
#include "kernel.h"

namespace op::gelu::kunlun {

Descriptor::~Descriptor() = default;

// Validates dtype/shape constraints and builds the Kunlun elementwise
// descriptor for the unary GELU op.
//
// @param handle_        opaque infiniop handle; reinterpreted as the Kunlun handle.
// @param desc_ptr       out: receives the new descriptor (filled by the macro below).
// @param out_desc       output tensor descriptor; its dtype drives dispatch.
// @param input_desc_vec exactly one input expected (GELU is unary).
// @return INFINI_STATUS_SUCCESS, or a status from the CHECK_* / create macros.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    // GELU is unary: one input, same shape as the output.
    const auto &input_desc = input_desc_vec.at(0);
    const auto &output_shape = out_desc->shape();
    const auto &input_shape = input_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
    CHECK_SAME_SHAPE(output_shape, input_shape);

    // create Kunlun elementwise descriptor
    CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// Dispatches the GELU elementwise kernel for the descriptor's dtype.
// The <8, ...> template argument mirrors the other Kunlun elementwise ops
// in this tree (e.g. add) — its semantics are defined by DeviceImpl::calculate.
//
// NOTE(review): the original had a trailing `return INFINI_STATUS_SUCCESS;`
// after the switch; every switch path (including default) already returns,
// so that statement was unreachable and has been removed.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<8, GeluOp, bfloat16_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<8, GeluOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<8, GeluOp, float>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace op::gelu::kunlun
src/infiniop/ops/gelu/kunlun/kernel.h
0 → 100644
View file @
f53b8435
#ifndef __GELU_KUNLUN_KERNEL_H__
#define __GELU_KUNLUN_KERNEL_H__

namespace op::gelu::kunlun {

// Exact (erf-based) GELU functor consumed by the generic Kunlun elementwise
// kernel: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
// NOTE(review): `fast_erf` is presumably the Kunlun device erf approximation —
// confirm its precision is acceptable for bf16/f16 paths.
//
// Changes vs. original:
//  - `typedef struct GeluOp {...} GeluOp;` → plain `struct` (the typedef is
//    redundant in C++).
//  - Float-path literals are now 0.5f / 1.0f so the math stays in single
//    precision on device instead of being silently promoted to double.
//  - `#endif` guard comment fixed (said __GELU_KUNLUN_H__, guard is
//    __GELU_KUNLUN_KERNEL_H__).
struct GeluOp {
public:
    // Number of input tensors the elementwise framework feeds this op.
    static constexpr size_t num_inputs = 1;

    // Applies GELU to the single element x[0].
    // bf16/f16 inputs are widened to float for the math, then narrowed back.
    template <typename T>
    inline __device__ T operator()(const T *x) const {
        if constexpr (std::is_same_v<T, bfloat16_t>) {
            float x_f = __bfloat162float(x[0]);
            float result = 0.5f * x_f * (1.0f + fast_erf(x_f / sqrt(2.0f)));
            return __float2bfloat16(result);
        } else if constexpr (std::is_same_v<T, half>) {
            float x_f = __half2float(x[0]);
            float result = 0.5f * x_f * (1.0f + fast_erf(x_f / sqrt(2.0f)));
            return __float2half(result);
        } else if constexpr (std::is_same_v<T, float>) {
            return 0.5f * x[0] * (1.0f + fast_erf(x[0] / sqrt(2.0f)));
        } else {
            // Generic fallback (e.g. double): compute in the wider type.
            return 0.5 * x[0] * (1 + fast_erf(x[0] / sqrt(2.0)));
        }
    }
};

} // namespace op::gelu::kunlun

#endif // __GELU_KUNLUN_KERNEL_H__
src/infiniop/ops/gelu/operator.cc
View file @
f53b8435
...
...
@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API
#include "metax/gelu_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/gelu_kunlun.h"
#endif
__C
infiniStatus_t
infiniopCreateGeluDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -43,6 +46,9 @@ __C infiniStatus_t infiniopCreateGeluDescriptor(
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -73,6 +79,9 @@ __C infiniStatus_t infiniopGetGeluWorkspaceSize(infiniopGeluDescriptor_t desc, s
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -112,6 +121,9 @@ __C infiniStatus_t infiniopGelu(
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -145,6 +157,9 @@ infiniopDestroyGeluDescriptor(infiniopGeluDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
src/infiniop/ops/mul/kunlun/mul_kunlun.xpu
View file @
f53b8435
...
...
@@ -2,15 +2,6 @@
#include "kernel.h"
#include "mul_kunlun.h"
namespace op::elementwise::kunlun {
using MulOp = op::mul::kunlun::MulOp;
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, bfloat16_t);
} // namespace op::elementwise::kunlun
namespace op::mul::kunlun {
Descriptor::~Descriptor() = default;
...
...
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
View file @
f53b8435
#include "random_sample_kunlun.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "random_sample_kunlun.h"
#include "../info.h"
#include "kernel.h"
#include "xpu/kernel/xtdk_io.h"
template <typename Tval, typename Tidx>
void launchKernel(void *workspace,
void *result,
const void *probs,
float random_val,
float topp,
int topk,
float temperature,
int64_t n,
XPUStream stream) {
void launchKernel(void *workspace,
void *result,
const void *probs,
float random_val,
float topp,
int topk,
float temperature,
int64_t n,
XPUStream stream) {
constexpr unsigned int cluster_num = 8;
constexpr unsigned int core_num = 64;
char *workspace_value = reinterpret_cast<char *>(workspace);
int topk_ = topk <= (int)n ? topk : (int)n;
bool dosample = topk_ > 1 && temperature != 0.0f && topp != 0.0f && random_val != 0.0f;
Tval *values = (Tval *)workspace_value;
xpu_memcpy(values, (Tval *)probs, n * sizeof(Tval), XPU_DEVICE_TO_DEVICE);
...
...
@@ -33,32 +29,33 @@ void launchKernel(void *workspace,
char *workspace_index = workspace_sum + cluster_num * sizeof(float);
Tidx *indices = (Tidx *)workspace_index;
Tidx *indices_global = indices + n;
if (dosample){
randomSampleKernel<cluster_num, core_num, Tval, float, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result,
(Tval *)probs,
random_val,
topp,
n,
topk_,
temperature,
indices,
values,
indices_global,
values_global,
sum_global);
if (dosample) {
randomSampleKernel<cluster_num, core_num, Tval, float, Tidx>
<<<cluster_num, core_num, stream>>>((Tidx *)result,
(Tval *)probs,
random_val,
topp,
n,
topk_,
temperature,
indices,
values,
indices_global,
values_global,
sum_global);
}
else {
argmaxKernel<Tval, Tidx>
<<<cluster_num, core_num, stream>>>((Tidx *)result, (Tval *)probs, n,
indices,
values,
indices_global,
values_global);
}
else{
argmaxKernel<Tval, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result, (Tval *)probs, n,
indices,
values,
indices_global,
values_global);
}
}
#define LAUNCH_KERNEL(Tval, Tidx)
\
#define LAUNCH_KERNEL(Tval, Tidx) \
launchKernel<Tval, Tidx>(workspace, result, probs, random_val, topp, topk, temperature, n, reinterpret_cast<kunlunStream_t>(stream));
namespace op::random_sample::kunlun {
...
...
@@ -82,11 +79,11 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result);
auto info = result.take();
int cluster_num = 8;
int core_num = 64;
int n = probs_desc->numel();
size_t workspace_size = (n + cluster_num * core_num * n) * (infiniSizeOf(probs_desc->dtype()) + infiniSizeOf(result_desc->dtype())) + cluster_num * sizeof(float);
*desc_ptr = new Descriptor(
info,
...
...
@@ -116,40 +113,38 @@ Descriptor::calculate(
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
int n = (int)_info.n;
if (_info.dt_i == INFINI_DTYPE_I32){
if (_info.dt_i == INFINI_DTYPE_I32)
{
switch (_info.dt_p) {
case INFINI_DTYPE_F16:
LAUNCH_KERNEL(half, int32_t);
break;
case INFINI_DTYPE_BF16:
LAUNCH_KERNEL(bfloat16_t, int32_t);
break;
case INFINI_DTYPE_F32:
LAUNCH_KERNEL(float, int32_t);
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
case INFINI_DTYPE_F16:
LAUNCH_KERNEL(half, int32_t);
break;
case INFINI_DTYPE_BF16:
LAUNCH_KERNEL(bfloat16_t, int32_t);
break;
case INFINI_DTYPE_F32:
LAUNCH_KERNEL(float, int32_t);
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
else if (_info.dt_i == INFINI_DTYPE_I64){
} else if (_info.dt_i == INFINI_DTYPE_I64) {
switch (_info.dt_p) {
case INFINI_DTYPE_F16:
LAUNCH_KERNEL(half, int64_t);
break;
case INFINI_DTYPE_BF16:
LAUNCH_KERNEL(bfloat16_t, int64_t);
break;
case INFINI_DTYPE_F32:
LAUNCH_KERNEL(float, int64_t);
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
case INFINI_DTYPE_F16:
LAUNCH_KERNEL(half, int64_t);
break;
case INFINI_DTYPE_BF16:
LAUNCH_KERNEL(bfloat16_t, int64_t);
break;
case INFINI_DTYPE_F32:
LAUNCH_KERNEL(float, int64_t);
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
else {
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
...
...
src/infiniop/ops/sub/kunlun/sub_kunlun.xpu
View file @
f53b8435
...
...
@@ -2,15 +2,6 @@
#include "kernel.h"
#include "sub_kunlun.h"
namespace op::elementwise::kunlun {
using SubOp = op::sub::kunlun::SubOp;
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, bfloat16_t);
} // namespace op::elementwise::kunlun
namespace op::sub::kunlun {
Descriptor::~Descriptor() = default;
...
...
src/infiniop/ops/swiglu/kunlun/swiglu_kunlun.xpu
View file @
f53b8435
...
...
@@ -2,16 +2,6 @@
#include "kernel.h"
#include "swiglu_kunlun.h"
namespace op::elementwise::kunlun {
using SwiGLUOp = op::swiglu::kunlun::SwiGLUOp;
// __global__ template function instantiation
INSTANTIATE_ELEMENTWISE_KERNEL(SwiGLUOp::num_inputs, SwiGLUOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(SwiGLUOp::num_inputs, SwiGLUOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(SwiGLUOp::num_inputs, SwiGLUOp, bfloat16_t);
} // namespace op::elementwise::kunlun
namespace op::swiglu::kunlun {
Descriptor::~Descriptor() = default;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment