Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
8b760951
Commit
8b760951
authored
Aug 20, 2025
by
zhangyue
Browse files
Merge branch 'main' of
https://github.com/InfiniTensor/InfiniCore
into issue-385
parents
eb3972eb
d4b03cf7
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
547 additions
and
36 deletions
+547
-36
src/infiniop/ops/clip/kunlun/kernel.h
src/infiniop/ops/clip/kunlun/kernel.h
+30
-0
src/infiniop/ops/clip/operator.cc
src/infiniop/ops/clip/operator.cc
+15
-0
src/infiniop/ops/gemm/moore/gemm_moore.h
src/infiniop/ops/gemm/moore/gemm_moore.h
+8
-0
src/infiniop/ops/gemm/moore/gemm_moore.mu
src/infiniop/ops/gemm/moore/gemm_moore.mu
+125
-0
src/infiniop/ops/gemm/musa/gemm_musa.h
src/infiniop/ops/gemm/musa/gemm_musa.h
+0
-8
src/infiniop/ops/gemm/operator.cc
src/infiniop/ops/gemm/operator.cc
+5
-5
src/infiniop/ops/mul/kunlun/kernel.h
src/infiniop/ops/mul/kunlun/kernel.h
+25
-0
src/infiniop/ops/mul/kunlun/mul_kunlun.h
src/infiniop/ops/mul/kunlun/mul_kunlun.h
+8
-0
src/infiniop/ops/mul/kunlun/mul_kunlun.xpu
src/infiniop/ops/mul/kunlun/mul_kunlun.xpu
+67
-0
src/infiniop/ops/mul/operator.cc
src/infiniop/ops/mul/operator.cc
+15
-0
src/infiniop/ops/rms_norm/cuda/kernel.cuh
src/infiniop/ops/rms_norm/cuda/kernel.cuh
+1
-1
src/infiniop/ops/rms_norm/moore/rms_norm_moore.h
src/infiniop/ops/rms_norm/moore/rms_norm_moore.h
+8
-0
src/infiniop/ops/rms_norm/moore/rms_norm_moore.mu
src/infiniop/ops/rms_norm/moore/rms_norm_moore.mu
+43
-17
src/infiniop/ops/rms_norm/operator.cc
src/infiniop/ops/rms_norm/operator.cc
+5
-5
src/infiniop/ops/sub/kunlun/kernel.h
src/infiniop/ops/sub/kunlun/kernel.h
+25
-0
src/infiniop/ops/sub/kunlun/sub_kunlun.h
src/infiniop/ops/sub/kunlun/sub_kunlun.h
+8
-0
src/infiniop/ops/sub/kunlun/sub_kunlun.xpu
src/infiniop/ops/sub/kunlun/sub_kunlun.xpu
+67
-0
src/infiniop/ops/sub/operator.cc
src/infiniop/ops/sub/operator.cc
+16
-0
src/infiniop/ops/swiglu/bang/swiglu_bang.h
src/infiniop/ops/swiglu/bang/swiglu_bang.h
+8
-0
src/infiniop/ops/swiglu/bang/swiglu_bang.mlu
src/infiniop/ops/swiglu/bang/swiglu_bang.mlu
+68
-0
No files found.
src/infiniop/ops/clip/kunlun/kernel.h
0 → 100644
View file @
8b760951
#ifndef __CLIP_KUNLUN_KERNEL_H__
#define __CLIP_KUNLUN_KERNEL_H__

#include <xpu/kernel/xtdk_io.h>

namespace op::clip::kunlun {

// Elementwise clip functor: clamps inputs[0] into the closed range
// [inputs[1], inputs[2]] (value, min bound, max bound).
typedef struct ClipOp {
public:
    // Exactly three inputs: x, min_val, max_val.
    static constexpr int num_inputs = 3;

    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        T x = inputs[0];
        T min_val = inputs[1];
        T max_val = inputs[2];
        // fmin first caps at max_val, fmax then raises to min_val.
        return fmax(fmin(x, max_val), min_val);
    }

    // bfloat16 overload: performs the comparison in float precision
    // to avoid bf16 rounding artifacts, then converts back.
    inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
        float x_f = __bfloat162float(inputs[0]);
        float min_val_f = __bfloat162float(inputs[1]);
        float max_val_f = __bfloat162float(inputs[2]);
        float result_f = fmax(fmin(x_f, max_val_f), min_val_f);
        return __float2bfloat16(result_f);
    }
} ClipOp;

} // namespace op::clip::kunlun

#endif // __CLIP_KUNLUN_KERNEL_H__
src/infiniop/ops/clip/operator.cc
View file @
8b760951
...
...
@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API
#include "metax/clip_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/clip_kunlun.h"
#endif
__C
infiniStatus_t
infiniopCreateClipDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateClipDescriptor(
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -69,6 +75,9 @@ __C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, s
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
)
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
}
...
...
@@ -106,6 +115,9 @@ __C infiniStatus_t infiniopClip(
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -136,6 +148,9 @@ infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
src/infiniop/ops/gemm/moore/gemm_moore.h
0 → 100644
View file @
8b760951
#ifndef __GEMM_MOORE_H__
#define __GEMM_MOORE_H__

#include "../gemm.h"

// Declares op::gemm::moore::Descriptor via the shared GEMM macro.
DESCRIPTOR(moore)

#endif // __GEMM_MOORE_H__
src/infiniop/ops/gemm/m
usa
/gemm_m
usa
.mu
→
src/infiniop/ops/gemm/m
oore
/gemm_m
oore
.mu
View file @
8b760951
#include "../../../devices/m
usa/
common
_musa
.h"
#include "../../../devices/m
usa/musa
_handle.h"
#include "gemm_m
usa
.h"
#include "../../../devices/m
oore/moore_
common.h"
#include "../../../devices/m
oore/moore
_handle.h"
#include "gemm_m
oore
.h"
namespace op::gemm::m
usa
{
namespace op::gemm::m
oore
{
struct Descriptor::Opaque {
std::shared_ptr<device::m
usa
::Handle::Internal> internal;
std::shared_ptr<device::m
oore
::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
...
...
@@ -18,10 +18,10 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::m
usa
::Handle *>(handle_);
auto handle = reinterpret_cast<device::m
oore
::Handle *>(handle_);
auto dtype = c_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32
, INFINI_DTYPE_BF16
);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
...
...
@@ -33,41 +33,63 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata>
infiniStatus_t calculate(
const MatmulInfo &info,
std::shared_ptr<device::musa::Handle::Internal> &_internal,
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) {
void *stream)
const
{
musaDataType a_type, b_type, c_type;
mublasComputeType_t compute_type;
Tdata alpha_, beta_;
if constexpr (std::is_same<Tdata, half>::value) {
alpha_ = __float2half(alpha);
beta_ = __float2half(beta);
// MUSA's GEMM operations require that the scalar values alpha and beta have the same data type as the matrices.
// This ensures correct computation during the muBLAS GEMM operation.
// Declare half-precision variables to handle F16 types.
half alpha_h, beta_h;
// Initialize generic void pointers for alpha and beta.
// They point to the original float values
// It will be used directly when the GEMM operation is performed with F32 data.
const void *p_alpha = α
const void *p_beta = β
switch (_dtype) {
case INFINI_DTYPE_F16:
a_type = b_type = c_type = MUSA_R_16F;
compute_type = MUBLAS_COMPUTE_16F;
} else {
alpha_ = alpha;
beta_ = beta;
// Convert alpha/beta to half-precision and update the pointers.
alpha_h = __float2half(alpha);
beta_h = __float2half(beta);
p_alpha = &alpha_h;
p_beta = &beta_h;
break;
case INFINI_DTYPE_BF16:
a_type = b_type = c_type = MUSA_R_16BF;
compute_type = MUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_F32:
a_type = b_type = c_type = MUSA_R_32F;
compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (info.is_transed) {
if (
_
info.is_transed) {
std::swap(a, b);
}
auto op_a = info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_a =
_
info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_b =
_
info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
CHECK_STATUS(_internal->useMublas(
CHECK_STATUS(_
opaque->
internal->useMublas(
(musaStream_t)stream,
[&](mublasHandle_t handle) {
CHECK_MUBLAS(
...
...
@@ -75,24 +97,24 @@ infiniStatus_t calculate(
handle,
op_a,
op_b,
static_cast<int>(info.m),
static_cast<int>(info.n),
static_cast<int>(info.k),
&
alpha
_
,
static_cast<int>(
_
info.m),
static_cast<int>(
_
info.n),
static_cast<int>(
_
info.k),
p_
alpha,
a,
a_type,
static_cast<int>(info.a_matrix.ld()),
info.a_matrix.stride,
static_cast<int>(
_
info.a_matrix.ld()),
_
info.a_matrix.stride,
b,
b_type,
static_cast<int>(info.b_matrix.ld()),
info.b_matrix.stride,
&
beta
_
,
static_cast<int>(
_
info.b_matrix.ld()),
_
info.b_matrix.stride,
p_
beta,
c,
c_type,
static_cast<int>(info.c_matrix.ld()),
info.c_matrix.stride,
static_cast<int>(info.batch),
static_cast<int>(
_
info.c_matrix.ld()),
_
info.c_matrix.stride,
static_cast<int>(
_
info.batch),
compute_type,
MUBLAS_GEMM_DEFAULT));
return INFINI_STATUS_SUCCESS;
...
...
@@ -100,22 +122,4 @@ infiniStatus_t calculate(
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return musa::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
case INFINI_DTYPE_F32:
return musa::calculate<float>(_info,_opaque->internal, c, beta, a, b, alpha, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::gemm::musa
} // namespace op::gemm::moore
src/infiniop/ops/gemm/musa/gemm_musa.h
deleted
100644 → 0
View file @
eb3972eb
#ifndef __GEMM_MUSA_H__
#define __GEMM_MUSA_H__
#include "../gemm.h"
DESCRIPTOR
(
musa
)
#endif // __GEMM_MUSA_H__
src/infiniop/ops/gemm/operator.cc
View file @
8b760951
...
...
@@ -18,7 +18,7 @@
#include "metax/gemm_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "m
usa
/gemm_m
usa
.h"
#include "m
oore
/gemm_m
oore
.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/gemm_kunlun.h"
...
...
@@ -61,7 +61,7 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
m
usa
);
CREATE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#ifdef ENABLE_KUNLUN_API
...
...
@@ -106,7 +106,7 @@ infiniopGetGemmWorkspaceSize(
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
m
usa
);
GET
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
...
...
@@ -158,7 +158,7 @@ __C infiniStatus_t infiniopGemm(
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
m
usa
);
CALCULATE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
...
...
@@ -200,7 +200,7 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
DELETE
(
INFINI_DEVICE_MOORE
,
m
usa
);
DELETE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
...
...
src/infiniop/ops/mul/kunlun/kernel.h
0 → 100644
View file @
8b760951
#ifndef __MUL_KUNLUN_KERNEL_H__
#define __MUL_KUNLUN_KERNEL_H__

namespace op::mul::kunlun {

// Elementwise multiply functor: out = inputs[0] * inputs[1].
typedef struct MulOp {
public:
    // Binary operator: exactly two inputs.
    static constexpr int num_inputs = 2;

    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        T a = inputs[0];
        T b = inputs[1];
        return a * b;
    }

    // bfloat16 overload: multiplies in float precision to limit
    // bf16 rounding error, then converts the product back.
    inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
        float a_f = __bfloat162float(inputs[0]);
        float b_f = __bfloat162float(inputs[1]);
        return __float2bfloat16(a_f * b_f);
    }
} MulOp;

} // namespace op::mul::kunlun

#endif // __MUL_KUNLUN_KERNEL_H__
src/infiniop/ops/mul/kunlun/mul_kunlun.h
0 → 100644
View file @
8b760951
#ifndef __MUL_KUNLUN_API_H__
#define __MUL_KUNLUN_API_H__

#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"

// Declares op::mul::kunlun::Descriptor via the shared elementwise macro.
ELEMENTWISE_DESCRIPTOR(mul, kunlun)

#endif // __MUL_KUNLUN_API_H__
src/infiniop/ops/mul/kunlun/mul_kunlun.xpu
0 → 100644
View file @
8b760951
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "kernel.h"
#include "mul_kunlun.h"

namespace op::elementwise::kunlun {
// Explicitly instantiate the generic elementwise kernel for every
// dtype the multiply operator supports.
using MulOp = op::mul::kunlun::MulOp;
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, bfloat16_t);
} // namespace op::elementwise::kunlun

namespace op::mul::kunlun {

Descriptor::~Descriptor() = default;

// Validates dtype and shape agreement of c = a * b, then builds the
// KUNLUN elementwise descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    // Only F16/F32/BF16 are supported; all three tensors must match in shape.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create KUNLUN elementwise descriptor
    CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// Dispatches the elementwise multiply kernel for the descriptor's dtype.
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE if the caller-provided
// workspace is smaller than what create() computed.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Every case returns, so no code after the switch is reachable;
    // the previous trailing `return INFINI_STATUS_SUCCESS;` was removed.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<8, MulOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<8, MulOp, bfloat16_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<8, MulOp, float>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace op::mul::kunlun
src/infiniop/ops/mul/operator.cc
View file @
8b760951
...
...
@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API
#include "metax/mul_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/mul_kunlun.h"
#endif
__C
infiniStatus_t
infiniopCreateMulDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor(
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -70,6 +76,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -107,6 +116,9 @@ __C infiniStatus_t infiniopMul(
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -137,6 +149,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
src/infiniop/ops/rms_norm/cuda/kernel.cuh
View file @
8b760951
...
...
@@ -22,7 +22,7 @@ __device__ void rmsnormBlock(
// Thread_0 computes RMS=1/sqrt(ss/dim+epsilon) and stores in shared memory
__shared__
Tcompute
rms
;
if
(
threadIdx
.
x
==
0
)
{
rms
=
T
data
(
rsqrtf
(
ss
/
Tcompute
(
dim
)
+
epsilon
));
rms
=
T
compute
(
rsqrtf
(
ss
/
Tcompute
(
dim
)
+
epsilon
));
}
__syncthreads
();
...
...
src/infiniop/ops/rms_norm/m
usa
/rms_norm_m
usa.cu
h
→
src/infiniop/ops/rms_norm/m
oore
/rms_norm_m
oore.
h
View file @
8b760951
#ifndef __RMS_NORM_M
USA_CU
H__
#define __RMS_NORM_M
USA_CU
H__
#ifndef __RMS_NORM_M
OORE_
H__
#define __RMS_NORM_M
OORE_
H__
#include "../rms_norm.h"
DESCRIPTOR
(
m
usa
)
DESCRIPTOR
(
m
oore
)
#endif
src/infiniop/ops/rms_norm/m
usa
/rms_norm_m
usa
.mu
→
src/infiniop/ops/rms_norm/m
oore
/rms_norm_m
oore
.mu
View file @
8b760951
#include "../../../devices/musa/common_musa.h"
#include "../cuda/rms_norm_kernel.cuh"
#include "rms_norm_musa.cuh"
#include "../../../devices/moore/moore_common.h"
#include "rms_norm_moore.h"
namespace op::rms_norm::musa {
#include "../../../devices/moore/moore_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
ptrdiff_t stride_y,
const Tdata *__restrict__ x,
ptrdiff_t stride_x,
const Tweight *__restrict__ w,
size_t dim,
float epsilon) {
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
}
namespace op::rms_norm::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::m
usa
::Handle::Internal> internal;
std::shared_ptr<device::m
oore
::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
...
...
@@ -29,7 +47,7 @@ infiniStatus_t Descriptor::create(
}
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::m
usa
::Handle *>(handle)->internal()},
new Opaque{reinterpret_cast<device::m
oore
::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
...
...
@@ -46,20 +64,24 @@ infiniStatus_t launchKernel(
float epsilon,
musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnorm
Block
<BLOCK_SIZE, Tdata, Tweight
, Tcompute
><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \
stride_y, \
reinterpret_cast<const Tdata *>(x), \
stride_x, \
reinterpret_cast<const Tweight *>(w), \
dim, \
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)
\
rmsnorm
Kernel
<BLOCK_SIZE,
Tcompute,
Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y),
\
stride_y,
\
reinterpret_cast<const Tdata *>(x),
\
stride_x,
\
reinterpret_cast<const Tweight *>(w),
\
dim,
\
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__mt_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
...
...
@@ -87,11 +109,15 @@ infiniStatus_t Descriptor::calculate(
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::m
usa
} // namespace op::rms_norm::m
oore
src/infiniop/ops/rms_norm/operator.cc
View file @
8b760951
...
...
@@ -15,7 +15,7 @@
#include "metax/rms_norm_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
#include "m
usa
/rms_norm_m
usa.cu
h"
#include "m
oore
/rms_norm_m
oore.
h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
...
...
@@ -64,7 +64,7 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
m
usa
);
CREATE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
}
...
...
@@ -105,7 +105,7 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
m
usa
);
GET
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
}
...
...
@@ -147,7 +147,7 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
m
usa
);
CALCULATE
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
}
...
...
@@ -188,7 +188,7 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
DESTROY
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
DESTROY
(
INFINI_DEVICE_MOORE
,
m
usa
);
DESTROY
(
INFINI_DEVICE_MOORE
,
m
oore
);
#endif
}
...
...
src/infiniop/ops/sub/kunlun/kernel.h
0 → 100644
View file @
8b760951
#ifndef __SUB_KUNLUN_KERNEL_H__
#define __SUB_KUNLUN_KERNEL_H__

namespace op::sub::kunlun {

// Elementwise subtract functor: out = inputs[0] - inputs[1].
typedef struct SubOp {
public:
    // Binary operator: exactly two inputs.
    static constexpr int num_inputs = 2;

    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        T a = inputs[0];
        T b = inputs[1];
        return a - b;
    }

    // bfloat16 overload: subtracts in float precision to limit
    // bf16 rounding error, then converts the difference back.
    inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
        float a_f = __bfloat162float(inputs[0]);
        float b_f = __bfloat162float(inputs[1]);
        return __float2bfloat16(a_f - b_f);
    }
} SubOp;

} // namespace op::sub::kunlun

#endif // __SUB_KUNLUN_KERNEL_H__
src/infiniop/ops/sub/kunlun/sub_kunlun.h
0 → 100644
View file @
8b760951
#ifndef __SUB_KUNLUN_API_H__
#define __SUB_KUNLUN_API_H__

#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"

// Declares op::sub::kunlun::Descriptor via the shared elementwise macro.
ELEMENTWISE_DESCRIPTOR(sub, kunlun)

#endif // __SUB_KUNLUN_API_H__
src/infiniop/ops/sub/kunlun/sub_kunlun.xpu
0 → 100644
View file @
8b760951
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "kernel.h"
#include "sub_kunlun.h"

namespace op::elementwise::kunlun {
// Explicitly instantiate the generic elementwise kernel for every
// dtype the subtract operator supports.
using SubOp = op::sub::kunlun::SubOp;
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, bfloat16_t);
} // namespace op::elementwise::kunlun

namespace op::sub::kunlun {

Descriptor::~Descriptor() = default;

// Validates dtype and shape agreement of c = a - b, then builds the
// KUNLUN elementwise descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    // Only F16/F32/BF16 are supported; all three tensors must match in shape.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create KUNLUN elementwise descriptor
    CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// Dispatches the elementwise subtract kernel for the descriptor's dtype.
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE if the caller-provided
// workspace is smaller than what create() computed.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Every case returns, so no code after the switch is reachable;
    // the previous trailing `return INFINI_STATUS_SUCCESS;` was removed.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<8, SubOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<8, SubOp, bfloat16_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<8, SubOp, float>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace op::sub::kunlun
src/infiniop/ops/sub/operator.cc
View file @
8b760951
...
...
@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API
#include "metax/sub_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/sub_kunlun.h"
#endif
__C
infiniStatus_t
infiniopCreateSubDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateSubDescriptor(
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -70,6 +76,10 @@ __C infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, siz
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
...
...
@@ -106,6 +116,9 @@ __C infiniStatus_t infiniopSub(
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -136,6 +149,9 @@ infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
src/infiniop/ops/swiglu/bang/swiglu_bang.h
0 → 100644
View file @
8b760951
#ifndef __SWIGLU_BANG_API_H__
#define __SWIGLU_BANG_API_H__

#include "../../../elementwise/bang/elementwise_bang.h"

// Declares op::swiglu::bang::Descriptor via the shared elementwise macro.
ELEMENTWISE_DESCRIPTOR(swiglu, bang)

#endif // __SWIGLU_BANG_API_H__
src/infiniop/ops/swiglu/bang/swiglu_bang.mlu
0 → 100644
View file @
8b760951
#include "swiglu_bang.h"

// Operator Interface Declaration
LAUNCH_ELEMENTWISE_KERNEL(SwiGLU)

namespace op::swiglu::bang {

// Adapter binding the generated SwiGLU kernel launcher to the
// elementwise dispatch machinery.
typedef struct SwiGLUOp {
    // SwiGLU consumes two inputs (up projection and gate).
    static constexpr size_t num_inputs = 2;

    // Forwards launch arguments to the dtype-specialized kernel launcher.
    template <typename Tdata, typename... Args>
    static infiniStatus_t launch(Args... args) {
        launchSwiGLUKernel<Tdata>(args...);
        return INFINI_STATUS_SUCCESS;
    }
} SwiGLUOp;

Descriptor::~Descriptor() = default;

// Validates dtype and shape agreement of out = swiglu(up, gate), then
// builds the Bang elementwise descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &up_desc = input_desc_vec.at(0);
    const auto &gate_desc = input_desc_vec.at(1);
    const auto &out_shape = out_desc->shape();
    const auto &up_shape = up_desc->shape();
    const auto &gate_shape = gate_desc->shape();

    // Only F16/F32/BF16 are supported; all three tensors must match in shape.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);

    // create Bang elementwise descriptor
    CREATE_ELEMENTWISE_BANG_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// Dispatches the SwiGLU kernel for the descriptor's dtype.
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE if the caller-provided
// workspace is smaller than what create() computed.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *queue) const {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Every case returns, so no code after the switch is reachable;
    // the previous trailing `return INFINI_STATUS_SUCCESS;` was removed.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<SwiGLUOp, half>(_info, workspace, output, inputs, queue);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<SwiGLUOp, bfloat16_t>(_info, workspace, output, inputs, queue);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<SwiGLUOp, float>(_info, workspace, output, inputs, queue);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace op::swiglu::bang
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment