Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
21c6af2d
Unverified
Commit
21c6af2d
authored
Mar 11, 2026
by
thatPepe
Committed by
GitHub
Mar 11, 2026
Browse files
Merge pull request #1069 from InfiniTensor/issue/1031_T1_1_15
【算子比赛2025秋】T1-1-15
parents
99a802dd
5f329d7a
Changes
112
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1277 additions
and
0 deletions
+1277
-0
src/infiniop/ops/cdist/cpu/cdist_cpu.h
src/infiniop/ops/cdist/cpu/cdist_cpu.h
+11
-0
src/infiniop/ops/cdist/info.h
src/infiniop/ops/cdist/info.h
+113
-0
src/infiniop/ops/cdist/metax/cdist_metax.h
src/infiniop/ops/cdist/metax/cdist_metax.h
+16
-0
src/infiniop/ops/cdist/metax/cdist_metax.maca
src/infiniop/ops/cdist/metax/cdist_metax.maca
+161
-0
src/infiniop/ops/cdist/moore/cdist_moore.h
src/infiniop/ops/cdist/moore/cdist_moore.h
+16
-0
src/infiniop/ops/cdist/moore/cdist_moore.mu
src/infiniop/ops/cdist/moore/cdist_moore.mu
+145
-0
src/infiniop/ops/cdist/nvidia/cdist_nvidia.cu
src/infiniop/ops/cdist/nvidia/cdist_nvidia.cu
+161
-0
src/infiniop/ops/cdist/nvidia/cdist_nvidia.cuh
src/infiniop/ops/cdist/nvidia/cdist_nvidia.cuh
+16
-0
src/infiniop/ops/cdist/operator.cc
src/infiniop/ops/cdist/operator.cc
+229
-0
src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.cc
src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.cc
+54
-0
src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.h
src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.h
+19
-0
src/infiniop/ops/reciprocal/cuda/kernel.cuh
src/infiniop/ops/reciprocal/cuda/kernel.cuh
+27
-0
src/infiniop/ops/reciprocal/metax/reciprocal_metax.h
src/infiniop/ops/reciprocal/metax/reciprocal_metax.h
+8
-0
src/infiniop/ops/reciprocal/metax/reciprocal_metax.maca
src/infiniop/ops/reciprocal/metax/reciprocal_metax.maca
+61
-0
src/infiniop/ops/reciprocal/metax/reciprocal_metax_kernel.h
src/infiniop/ops/reciprocal/metax/reciprocal_metax_kernel.h
+47
-0
src/infiniop/ops/reciprocal/moore/reciprocal_moore.h
src/infiniop/ops/reciprocal/moore/reciprocal_moore.h
+11
-0
src/infiniop/ops/reciprocal/moore/reciprocal_moore.mu
src/infiniop/ops/reciprocal/moore/reciprocal_moore.mu
+66
-0
src/infiniop/ops/reciprocal/moore/reciprocal_moore_kernel.h
src/infiniop/ops/reciprocal/moore/reciprocal_moore_kernel.h
+47
-0
src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cu
src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cu
+61
-0
src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cuh
src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cuh
+8
-0
No files found.
src/infiniop/ops/cdist/cpu/cdist_cpu.h
0 → 100644
View file @
21c6af2d
#ifndef __CDIST_CPU_H__
#define __CDIST_CPU_H__

#include "../cdist.h"

// Uses the DESCRIPTOR macro defined in cdist.h.
// This generates the Descriptor class in namespace op::cdist::cpu;
// the class holds a CdistInfo and exposes create/calculate interfaces.
DESCRIPTOR(cpu)

#endif // __CDIST_CPU_H__
src/infiniop/ops/cdist/info.h
0 → 100644
View file @
21c6af2d
#ifndef __CDIST_INFO_H__
#define __CDIST_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace
op
::
cdist
{
/**
 * Borrows the BlasMatrix idea to describe the cdist input/output matrices:
 *   x1: (Batch, M, D)
 *   x2: (Batch, N, D)
 *   y:  (Batch, M, N)
 */
struct CdistMatrix {
    size_t ndim;
    size_t batch;
    ptrdiff_t stride;     // stride between consecutive batches
    size_t rows;          // M or N
    size_t cols;          // D (feature dim), or N in the output
    ptrdiff_t row_stride;
    ptrdiff_t col_stride;

    /// Builds a CdistMatrix view from a 2-D or 3-D tensor descriptor.
    /// Any other rank yields INFINI_STATUS_BAD_TENSOR_SHAPE.
    static utils::Result<CdistMatrix> create(infiniopTensorDescriptor_t layout) {
        CdistMatrix mat;
        const auto rank = layout->ndim();
        switch (rank) {
        case 2:
            // Un-batched: treat as a single batch with zero batch stride.
            mat.ndim = 2;
            mat.batch = 1;
            mat.stride = 0;
            mat.rows = layout->dim(0);
            mat.cols = layout->dim(1);
            mat.row_stride = layout->stride(0);
            mat.col_stride = layout->stride(1);
            break;
        case 3:
            mat.ndim = 3;
            mat.batch = layout->dim(0);
            // A batch of 1 broadcasts, so its batch stride is irrelevant.
            mat.stride = mat.batch == 1 ? 0 : layout->stride(0);
            mat.rows = layout->dim(1);
            mat.cols = layout->dim(2);
            mat.row_stride = layout->stride(1);
            mat.col_stride = layout->stride(2);
            break;
        default:
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        return utils::Result<CdistMatrix>(mat);
    }

    /// True when this operand's batch matches _batch, or is 1 (broadcastable).
    bool match_batch(size_t _batch) const {
        return batch == 1 || batch == _batch;
    }
};
/**
 * Validated metadata for one cdist computation:
 *   x1: (batch, m, d), x2: (batch, n, d) -> y: (batch, m, n)
 */
class CdistInfo {
    CdistInfo() = default;

public:
    CdistMatrix x1_matrix;
    CdistMatrix x2_matrix;
    CdistMatrix y_matrix;
    size_t m, n, d, batch;

    /// Parses the three descriptors and validates shape and batch
    /// compatibility; returns INFINI_STATUS_BAD_TENSOR_SHAPE on mismatch.
    static utils::Result<CdistInfo> create(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x1_desc,
        infiniopTensorDescriptor_t x2_desc) {
        auto x1_result = CdistMatrix::create(x1_desc);
        CHECK_RESULT(x1_result);
        auto x2_result = CdistMatrix::create(x2_desc);
        CHECK_RESULT(x2_result);
        auto y_result = CdistMatrix::create(y_desc);
        CHECK_RESULT(y_result);

        auto x1 = x1_result.take();
        auto x2 = x2_result.take();
        auto y = y_result.take();

        // Shape checks: x1 (M, D) with x2 (N, D) must produce y (M, N) —
        // feature dims must agree and the output must be M x N.
        const bool feature_dims_match = (x1.cols == x2.cols);
        const bool output_shape_ok = (y.rows == x1.rows) && (y.cols == x2.rows);
        if (!feature_dims_match || !output_shape_ok) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // Batch check: each input must equal the output batch or be
        // broadcastable (batch == 1).
        size_t batch_size = y.batch;
        if (!x1.match_batch(batch_size) || !x2.match_batch(batch_size)) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        size_t m = x1.rows;
        size_t n = x2.rows;
        size_t d = x1.cols;
        return utils::Result<CdistInfo>(CdistInfo{x1, x2, y, m, n, d, batch_size});
    }
};
}
// namespace op::cdist
#endif // __CDIST_INFO_H__
src/infiniop/ops/cdist/metax/cdist_metax.h
0 → 100644
View file @
21c6af2d
#ifndef __CDIST_METAX_H__
#define __CDIST_METAX_H__

#include "../cdist.h"

/**
 * Uses the DESCRIPTOR macro defined in cdist.h.
 * Generates the Descriptor class for the METAX device in namespace
 * op::cdist::metax.
 *
 * In the METAX implementation the Opaque struct would typically store:
 * - a BLAS handle for p=2.0 matrix-multiply acceleration,
 * - the stream the task runs on,
 * - configuration parameters for custom kernels.
 */
DESCRIPTOR(metax)

// NOTE(review): the guard was renamed from __CDIST_METAX_CUH__ — this is a
// .h file, and the other backend headers (e.g. __CDIST_MOORE_H__) use _H_.
#endif // __CDIST_METAX_H__
src/infiniop/ops/cdist/metax/cdist_metax.maca
0 → 100644
View file @
21c6af2d
#include "../../../devices/metax/metax_handle.h"
#include "cdist_metax.h"
#include <iostream>
namespace op::cdist::metax {
// No device-side state is needed for this backend yet.
struct Descriptor::Opaque {};

Descriptor::~Descriptor() = default;

/**
 * Validates descriptors and builds a cdist Descriptor for the METAX backend.
 * Only F32 is supported at the moment (matching the tests); no workspace
 * is required by the current implementation.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x1_desc,
    infiniopTensorDescriptor_t x2_desc,
    double p) {
    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);

    // The METAX backend currently supports F32 only.
    auto dtype = y_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F32);

    auto info_result = CdistInfo::create(y_desc, x1_desc, x2_desc);
    CHECK_RESULT(info_result);

    // Workspace size is 0: the generic kernel needs no scratch memory.
    *desc_ptr = new Descriptor(
        dtype, info_result.take(), p, 0,
        nullptr,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// --- Kernel 1: L2 Epilogue ---
// Placeholder kept for a future GEMM-accelerated p=2 path; it is not
// launched by the current implementation.
template <typename T>
__global__ void cdist_l2_epilogue_kernel(T *y, const T *x1_norm, const T *x2_norm,
                                         int M, int N, int batch_stride_y) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    int b = blockIdx.z;
    if (i < M && j < N) {
        int idx = b * batch_stride_y + i * N + j;
        // A GEMM is assumed to have already stored -2*x1*x2^T into y.
        float val = (float)x1_norm[b * M + i] + (float)x2_norm[b * N + j] + (float)y[idx];
        // Clamp at 0 before sqrt to guard against small negative rounding error.
        y[idx] = (T)sqrtf(fmaxf(val, 0.0f));
    }
}
// --- Kernel 2: Generic P-Norm (F32, supports arbitrary strides) ---
// One thread computes one output element y[b, i, j] = ||x1[b,i,:] - x2[b,j,:]||_p.
__global__ void cdist_generic_kernel_f32(
    float *y,
    const float *x1,
    const float *x2,
    size_t m,
    size_t n,
    size_t d,
    ptrdiff_t x1_stride,
    ptrdiff_t x1_row_stride,
    ptrdiff_t x1_col_stride,
    ptrdiff_t x2_stride,
    ptrdiff_t x2_row_stride,
    ptrdiff_t x2_col_stride,
    ptrdiff_t y_stride,
    ptrdiff_t y_row_stride,
    ptrdiff_t y_col_stride,
    double p) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    int b = blockIdx.z;
    if (i >= (int)m || j >= (int)n) {
        return;
    }
    // Locate the output element y[b, i, j].
    float *y_ptr = y + b * y_stride + i * y_row_stride + j * y_col_stride;
    // Locate the input vectors x1[b, i, :] and x2[b, j, :].
    const float *x1_vec = x1 + b * x1_stride + i * x1_row_stride;
    const float *x2_vec = x2 + b * x2_stride + j * x2_row_stride;
    // Accumulator is declared double (note: several branches round through float).
    double dist = 0.0;
    for (size_t k = 0; k < d; ++k) {
        float v1 = *(x1_vec + k * x1_col_stride);
        float v2 = *(x2_vec + k * x2_col_stride);
        float diff = fabsf(v1 - v2);
        if (p == 1.0) {
            dist += diff;                        // L1: sum of |diff|
        } else if (p == 2.0) {
            dist += diff * diff;                 // L2: sum of squares
        } else if (isinf(p)) {
            dist = fmaxf((float)dist, diff);     // L-inf: running max
        } else {
            dist += powf((float)diff, (float)p); // general p-norm term
        }
    }
    // Apply the final root where the norm requires one (L1 and L-inf do not).
    if (p == 2.0) {
        dist = sqrtf((float)dist);
    } else if (!isinf(p) && p != 1.0) {
        dist = powf((float)dist, 1.0f / (float)p);
    }
    *y_ptr = (float)dist;
}
/**
 * Launches the generic strided p-norm kernel on the given MACA stream.
 * workspace is unused; only F32 tensors are accepted.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x1,
    const void *x2,
    void *stream) const {
    (void)workspace;
    (void)workspace_size;

    if (_dtype != INFINI_DTYPE_F32) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    auto maca_stream = (mcStream_t)stream;

    // One thread per output element; the z dimension walks the batch.
    constexpr unsigned int kTile = 16;
    dim3 block(kTile, kTile);
    dim3 grid(
        static_cast<unsigned int>((_info.n + kTile - 1) / kTile),
        static_cast<unsigned int>((_info.m + kTile - 1) / kTile),
        static_cast<unsigned int>(_info.batch));

    cdist_generic_kernel_f32<<<grid, block, 0, maca_stream>>>(
        static_cast<float *>(y),
        static_cast<const float *>(x1),
        static_cast<const float *>(x2),
        _info.m, _info.n, _info.d,
        _info.x1_matrix.stride, _info.x1_matrix.row_stride, _info.x1_matrix.col_stride,
        _info.x2_matrix.stride, _info.x2_matrix.row_stride, _info.x2_matrix.col_stride,
        _info.y_matrix.stride, _info.y_matrix.row_stride, _info.y_matrix.col_stride,
        _p);

    if (mcGetLastError() != mcSuccess) {
        return INFINI_STATUS_INTERNAL_ERROR;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::cdist::metax
src/infiniop/ops/cdist/moore/cdist_moore.h
0 → 100644
View file @
21c6af2d
#ifndef __CDIST_MOORE_H__
#define __CDIST_MOORE_H__

#include "../cdist.h"

/**
 * Uses the DESCRIPTOR macro defined in cdist.h.
 * Generates the Descriptor class for the Moore device in namespace
 * op::cdist::moore.
 *
 * In the Moore implementation the Opaque struct would typically store:
 * - mublasHandle_t: for p=2.0 matrix-multiply acceleration (cuBLAS analogue),
 * - musaStream_t: the stream the task runs on,
 * - configuration parameters for custom kernels.
 */
DESCRIPTOR(moore)

#endif // __CDIST_MOORE_H__
src/infiniop/ops/cdist/moore/cdist_moore.mu
0 → 100644
View file @
21c6af2d
#include "../../../devices/moore/moore_handle.h"
#include "cdist_moore.h"
#include <iostream>
#include <musa_runtime.h>
namespace op::cdist::moore {
// No device-side state is needed for this backend yet.
struct Descriptor::Opaque {};

Descriptor::~Descriptor() = default;

/**
 * Validates descriptors and builds a cdist Descriptor for the Moore backend.
 * Only F32 is supported at the moment; no workspace is required.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x1_desc,
    infiniopTensorDescriptor_t x2_desc,
    double p) {
    // Cast the generic handle to the Moore handle.
    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);

    // Same as the reference implementation: F32 only for now.
    auto dtype = y_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F32);

    auto info_result = CdistInfo::create(y_desc, x1_desc, x2_desc);
    CHECK_RESULT(info_result);

    *desc_ptr = new Descriptor(
        dtype, info_result.take(), p, 0,
        nullptr,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// --- Kernel: Generic P-Norm (MUSA F32 implementation) ---
// One thread computes one output element y[b, i, j] = ||x1[b,i,:] - x2[b,j,:]||_p.
__global__ void cdist_generic_kernel_f32(
    float *y,
    const float *x1,
    const float *x2,
    size_t m,
    size_t n,
    size_t d,
    ptrdiff_t x1_stride,
    ptrdiff_t x1_row_stride,
    ptrdiff_t x1_col_stride,
    ptrdiff_t x2_stride,
    ptrdiff_t x2_row_stride,
    ptrdiff_t x2_col_stride,
    ptrdiff_t y_stride,
    ptrdiff_t y_row_stride,
    ptrdiff_t y_col_stride,
    double p) {
    // MUSA supports 3-D thread indexing just like CUDA.
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    int b = blockIdx.z;
    if (i >= (int)m || j >= (int)n) {
        return;
    }
    // Locate the output element y[b, i, j].
    float *y_ptr = y + b * y_stride + i * y_row_stride + j * y_col_stride;
    // Locate the input vectors x1[b, i, :] and x2[b, j, :].
    const float *x1_vec = x1 + b * x1_stride + i * x1_row_stride;
    const float *x2_vec = x2 + b * x2_stride + j * x2_row_stride;
    // Accumulator is declared double (note: some branches round through float).
    double dist = 0.0;
    for (size_t k = 0; k < d; ++k) {
        float v1 = *(x1_vec + k * x1_col_stride);
        float v2 = *(x2_vec + k * x2_col_stride);
        float diff = fabsf(v1 - v2);
        if (p == 1.0) {
            dist += (double)diff;              // L1: sum of |diff|
        } else if (p == 2.0) {
            dist += (double)diff * diff;       // L2: sum of squares
        } else if (isinf(p)) {
            dist = fmaxf((float)dist, diff);   // L-inf: running max
        } else {
            dist += powf(diff, (float)p);      // general p-norm term
        }
    }
    // Apply the final root where the norm requires one (L1 and L-inf do not).
    if (p == 2.0) {
        dist = sqrtf((float)dist);
    } else if (!isinf(p) && p != 1.0) {
        dist = powf((float)dist, 1.0f / (float)p);
    }
    *y_ptr = (float)dist;
}
/**
 * Launches the generic strided p-norm kernel on the given MUSA stream.
 * workspace is unused; only F32 tensors are accepted.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x1,
    const void *x2,
    void *stream) const {
    (void)workspace;
    (void)workspace_size;
    if (_dtype != INFINI_DTYPE_F32) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Switch to the MUSA stream type.
    musaStream_t mustream = reinterpret_cast<musaStream_t>(stream);
    // 16x16 blocks are a generic choice that also works on the MUSA architecture.
    dim3 block(16, 16);
    dim3 grid(
        static_cast<unsigned int>((_info.n + block.x - 1) / block.x),
        static_cast<unsigned int>((_info.m + block.y - 1) / block.y),
        static_cast<unsigned int>(_info.batch));
    cdist_generic_kernel_f32<<<grid, block, 0, mustream>>>(
        static_cast<float *>(y),
        static_cast<const float *>(x1),
        static_cast<const float *>(x2),
        _info.m,
        _info.n,
        _info.d,
        _info.x1_matrix.stride,
        _info.x1_matrix.row_stride,
        _info.x1_matrix.col_stride,
        _info.x2_matrix.stride,
        _info.x2_matrix.row_stride,
        _info.x2_matrix.col_stride,
        _info.y_matrix.stride,
        _info.y_matrix.row_stride,
        _info.y_matrix.col_stride,
        _p);
    // Surface kernel-launch failures, consistent with the NVIDIA and METAX
    // backends (the original returned success unconditionally).
    if (musaGetLastError() != musaSuccess) {
        return INFINI_STATUS_INTERNAL_ERROR;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::cdist::moore
src/infiniop/ops/cdist/nvidia/cdist_nvidia.cu
0 → 100644
View file @
21c6af2d
#include "../../../devices/nvidia/nvidia_handle.cuh"
#include "cdist_nvidia.cuh"
#include <iostream>
namespace
op
::
cdist
::
nvidia
{
// No device-side state is needed for this backend yet.
struct Descriptor::Opaque {};

Descriptor::~Descriptor() = default;

/**
 * Validates descriptors and builds a cdist Descriptor for the NVIDIA backend.
 * Only F32 is supported at the moment (matching the tests); no workspace
 * is required by the current implementation.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x1_desc,
    infiniopTensorDescriptor_t x2_desc,
    double p) {
    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);

    // The NVIDIA backend currently supports F32 only.
    auto dtype = y_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F32);

    auto info_result = CdistInfo::create(y_desc, x1_desc, x2_desc);
    CHECK_RESULT(info_result);

    // Workspace size is 0: the generic kernel needs no scratch memory.
    *desc_ptr = new Descriptor(
        dtype, info_result.take(), p, 0,
        nullptr,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// --- Kernel 1: L2 Epilogue ---
// Placeholder kept for a future GEMM-accelerated p=2 path; it is not
// launched by the current implementation.
template <typename T>
__global__ void cdist_l2_epilogue_kernel(T *y, const T *x1_norm, const T *x2_norm,
                                         int M, int N, int batch_stride_y) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    int b = blockIdx.z;
    if (i < M && j < N) {
        int idx = b * batch_stride_y + i * N + j;
        // A GEMM is assumed to have already stored -2*x1*x2^T into y.
        float val = (float)x1_norm[b * M + i] + (float)x2_norm[b * N + j] + (float)y[idx];
        // Clamp at 0 before sqrt to guard against small negative rounding error.
        y[idx] = (T)sqrtf(fmaxf(val, 0.0f));
    }
}
// --- Kernel 2: Generic P-Norm (F32, supports arbitrary strides) ---
// One thread computes one output element y[b, i, j] = ||x1[b,i,:] - x2[b,j,:]||_p.
__global__ void cdist_generic_kernel_f32(
    float *y,
    const float *x1,
    const float *x2,
    size_t m,
    size_t n,
    size_t d,
    ptrdiff_t x1_stride,
    ptrdiff_t x1_row_stride,
    ptrdiff_t x1_col_stride,
    ptrdiff_t x2_stride,
    ptrdiff_t x2_row_stride,
    ptrdiff_t x2_col_stride,
    ptrdiff_t y_stride,
    ptrdiff_t y_row_stride,
    ptrdiff_t y_col_stride,
    double p) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    int i = blockIdx.y * blockDim.y + threadIdx.y;
    int b = blockIdx.z;
    if (i >= (int)m || j >= (int)n) {
        return;
    }
    // Locate the output element y[b, i, j].
    float *y_ptr = y + b * y_stride + i * y_row_stride + j * y_col_stride;
    // Locate the input vectors x1[b, i, :] and x2[b, j, :].
    const float *x1_vec = x1 + b * x1_stride + i * x1_row_stride;
    const float *x2_vec = x2 + b * x2_stride + j * x2_row_stride;
    // Accumulator is declared double (note: several branches round through float).
    double dist = 0.0;
    for (size_t k = 0; k < d; ++k) {
        float v1 = *(x1_vec + k * x1_col_stride);
        float v2 = *(x2_vec + k * x2_col_stride);
        float diff = fabsf(v1 - v2);
        if (p == 1.0) {
            dist += diff;                        // L1: sum of |diff|
        } else if (p == 2.0) {
            dist += diff * diff;                 // L2: sum of squares
        } else if (isinf(p)) {
            dist = fmaxf((float)dist, diff);     // L-inf: running max
        } else {
            dist += powf((float)diff, (float)p); // general p-norm term
        }
    }
    // Apply the final root where the norm requires one (L1 and L-inf do not).
    if (p == 2.0) {
        dist = sqrtf((float)dist);
    } else if (!isinf(p) && p != 1.0) {
        dist = powf((float)dist, 1.0f / (float)p);
    }
    *y_ptr = (float)dist;
}
/**
 * Launches the generic strided p-norm kernel on the given CUDA stream.
 * workspace is unused; only F32 tensors are accepted.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x1,
    const void *x2,
    void *stream) const {
    (void)workspace;
    (void)workspace_size;

    if (_dtype != INFINI_DTYPE_F32) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    auto cuda_stream = (cudaStream_t)stream;

    // One thread per output element; the z dimension walks the batch.
    constexpr unsigned int kTile = 16;
    dim3 block(kTile, kTile);
    dim3 grid(
        static_cast<unsigned int>((_info.n + kTile - 1) / kTile),
        static_cast<unsigned int>((_info.m + kTile - 1) / kTile),
        static_cast<unsigned int>(_info.batch));

    cdist_generic_kernel_f32<<<grid, block, 0, cuda_stream>>>(
        static_cast<float *>(y),
        static_cast<const float *>(x1),
        static_cast<const float *>(x2),
        _info.m, _info.n, _info.d,
        _info.x1_matrix.stride, _info.x1_matrix.row_stride, _info.x1_matrix.col_stride,
        _info.x2_matrix.stride, _info.x2_matrix.row_stride, _info.x2_matrix.col_stride,
        _info.y_matrix.stride, _info.y_matrix.row_stride, _info.y_matrix.col_stride,
        _p);

    if (cudaGetLastError() != cudaSuccess) {
        return INFINI_STATUS_INTERNAL_ERROR;
    }
    return INFINI_STATUS_SUCCESS;
}
}
// namespace op::cdist::nvidia
src/infiniop/ops/cdist/nvidia/cdist_nvidia.cuh
0 → 100644
View file @
21c6af2d
#ifndef __CDIST_NVIDIA_CUH__
#define __CDIST_NVIDIA_CUH__

#include "../cdist.h"

/**
 * Uses the DESCRIPTOR macro defined in cdist.h.
 * Generates the Descriptor class for the NVIDIA device in namespace
 * op::cdist::nvidia.
 *
 * In the NVIDIA implementation the Opaque struct would typically store:
 * - cublasHandle_t: for p=2.0 matrix-multiply acceleration,
 * - cudaStream_t: the stream the task runs on,
 * - configuration parameters for custom kernels.
 */
DESCRIPTOR(nvidia)

#endif // __CDIST_NVIDIA_CUH__
src/infiniop/ops/cdist/operator.cc
0 → 100644
View file @
21c6af2d
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/cdist.h"
// 引入各硬件后端的 Descriptor 定义
#ifdef ENABLE_CPU_API
#include "cpu/cdist_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/cdist_nvidia.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/cdist_bang.h"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/cdist_ascend.h"
#endif
#ifdef ENABLE_METAX_API
#include "metax/cdist_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/cdist_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/cdist_kunlun.h"
#endif
// -----------------------------------------------------------------------------
// 1. Create descriptor
// -----------------------------------------------------------------------------
// C ABI entry point: builds a backend-specific cdist descriptor for the
// device baked into the handle. p is the order of the p-norm distance.
__INFINI_C infiniStatus_t infiniopCreateCdistDescriptor(
    infiniopHandle_t handle,
    infiniopCdistDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x1_desc,
    infiniopTensorDescriptor_t x2_desc,
    double p) {

// One switch case forwarding to the backend's Descriptor::create.
#define CREATE(CASE, NAMESPACE)                                                                                           \
    case CASE:                                                                                                            \
        return op::cdist::NAMESPACE::Descriptor::create(handle,                                                           \
                                                        reinterpret_cast<op::cdist::NAMESPACE::Descriptor **>(desc_ptr), \
                                                        y_desc, x1_desc, x2_desc, p)

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // Iluvatar, QY and Hygon reuse the CUDA-compatible nvidia backend.
        CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
        CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
        CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
        CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
        CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}
// -----------------------------------------------------------------------------
// 2. Query workspace size
// -----------------------------------------------------------------------------
// Writes the workspace size required by the backend descriptor into *size.
__INFINI_C infiniStatus_t infiniopGetCdistWorkspaceSize(
    infiniopCdistDescriptor_t desc,
    size_t *size) {

#define GET(CASE, NAMESPACE)                                                                       \
    case CASE:                                                                                     \
        *size = reinterpret_cast<const op::cdist::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
        GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        GET(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
        GET(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
        GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET
}
// -----------------------------------------------------------------------------
// 3. Execute (compute pairwise distances)
// -----------------------------------------------------------------------------
// Runs the cdist computation described by desc on the given stream.
__INFINI_C infiniStatus_t infiniopCdist(
    infiniopCdistDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x1,
    const void *x2,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE)                                              \
    case CASE:                                                                  \
        return reinterpret_cast<const op::cdist::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, y, x1, x2, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
        CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
        CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
        CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
        CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}
// -----------------------------------------------------------------------------
// 4. Destroy descriptor
// -----------------------------------------------------------------------------
// Frees the backend-specific descriptor created by infiniopCreateCdistDescriptor.
__INFINI_C infiniStatus_t infiniopDestroyCdistDescriptor(
    infiniopCdistDescriptor_t desc) {

#define DELETE(CASE, NAMESPACE)                                                      \
    case CASE:                                                                       \
        delete reinterpret_cast<const op::cdist::NAMESPACE::Descriptor *>(desc);     \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
        DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
        DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
        DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
        DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DELETE
}
src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.cc
0 → 100644
View file @
21c6af2d
#include "reciprocal_cpu.h"
namespace
op
::
reciprocal
::
cpu
{
Descriptor::~Descriptor() = default;

/**
 * Validates dtype/shape and builds the CPU elementwise descriptor for
 * reciprocal. Expects exactly one input descriptor in input_desc_vec.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &x_desc = input_desc_vec.at(0);
    const auto &y_shape = out_desc->shape();
    const auto &x_shape = x_desc->shape();

    // Reciprocal typically only supports floating point types.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(y_shape, x_shape);

    // Build the CPU elementwise descriptor.
    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}
/**
 * Dispatches the CPU elementwise reciprocal computation by dtype.
 * workspace/workspace_size are unused on the CPU backend.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    (void)workspace;      // the CPU path needs no scratch memory
    (void)workspace_size;
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<ReciprocalOp, fp16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<ReciprocalOp, float>(_info, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<ReciprocalOp, double>(_info, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<ReciprocalOp, bf16_t>(_info, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // The unreachable trailing `return INFINI_STATUS_SUCCESS;` was removed:
    // every switch path already returns.
}
}
// namespace op::reciprocal::cpu
src/infiniop/ops/reciprocal/cpu/reciprocal_cpu.h
0 → 100644
View file @
21c6af2d
#ifndef __RECIPROCAL_CPU_H__
#define __RECIPROCAL_CPU_H__
#include "../../../elementwise/cpu/elementwise_cpu.h"
ELEMENTWISE_DESCRIPTOR
(
reciprocal
,
cpu
)
namespace
op
::
reciprocal
::
cpu
{
// Elementwise reciprocal functor: maps x to 1/x.
typedef struct ReciprocalOp {
public:
    // Unary operator: exactly one input tensor.
    static constexpr size_t num_inputs = 1;

    /// Returns the multiplicative inverse of x.
    template <typename T>
    T operator()(const T &x) const {
        const T one = static_cast<T>(1);
        return one / x;
    }
} ReciprocalOp;
}
// namespace op::reciprocal::cpu
#endif // __RECIPROCAL_CPU_H__
src/infiniop/ops/reciprocal/cuda/kernel.cuh
0 → 100644
View file @
21c6af2d
#ifndef __RECIPROCAL_CUDA_H__
#define __RECIPROCAL_CUDA_H__
namespace
op
::
reciprocal
::
cuda
{
// Elementwise reciprocal functor for the CUDA backend: y = 1 / x.
typedef struct ReciprocalOp {
public:
    // Unary operator: exactly one input tensor.
    static constexpr size_t num_inputs = 1;

    template <typename T>
    __device__ __forceinline__ T operator()(const T &x) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Packed half2 reciprocal intrinsic.
            return h2rcp(x);
        } else if constexpr (std::is_same_v<T, half>) {
            return hrcp(x);
        } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
            // bfloat16 does not have a direct hrcp intrinsic in some versions,
            // so compute in float and convert back.
            return __float2bfloat16(1.0f / __bfloat162float(x));
        } else if constexpr (std::is_same_v<T, float>) {
            // Round-to-nearest reciprocal; matches IEEE 1.0f / x.
            // (Was __frcp_rd, which rounds toward -inf and disagrees with the
            // rounding used by every other dtype branch.)
            return __frcp_rn(x);
        } else {
            return static_cast<T>(1) / x;
        }
    }
} ReciprocalOp;
}
// namespace op::reciprocal::cuda
#endif // __RECIPROCAL_CUDA_H__
src/infiniop/ops/reciprocal/metax/reciprocal_metax.h
0 → 100644
View file @
21c6af2d
#ifndef __RECIPROCAL_METAX_API_H__
#define __RECIPROCAL_METAX_API_H__

#include "../../../elementwise/metax/elementwise_metax_api.h"

// Generates the op::reciprocal::metax::Descriptor declaration with the
// standard elementwise create/calculate interface.
ELEMENTWISE_DESCRIPTOR(reciprocal, metax)

#endif // __RECIPROCAL_METAX_API_H__
src/infiniop/ops/reciprocal/metax/reciprocal_metax.maca
0 → 100644
View file @
21c6af2d
#include "../../../elementwise/metax/elementwise_metax.h"
#include "reciprocal_metax.h"
#include "reciprocal_metax_kernel.h"
namespace op::reciprocal::metax {
Descriptor::~Descriptor() = default;

/**
 * Validates dtype/shape and builds the METAX elementwise descriptor for
 * reciprocal. Expects exactly one input descriptor in input_desc_vec.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &x_desc = input_desc_vec.at(0);
    const auto &y_shape = out_desc->shape();
    const auto &x_shape = x_desc->shape();

    // Reciprocal typically only supports floating point types.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(y_shape, x_shape);

    // Build the METAX elementwise descriptor.
    CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}
/**
 * Dispatches the METAX elementwise reciprocal kernel by dtype.
 * Fails with INSUFFICIENT_WORKSPACE if the caller's workspace is too small.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, metax::ReciprocalOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, metax::ReciprocalOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, metax::ReciprocalOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, metax::ReciprocalOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // The unreachable trailing `return INFINI_STATUS_SUCCESS;` was removed:
    // every switch path already returns.
}
} // namespace op::reciprocal::metax
src/infiniop/ops/reciprocal/metax/reciprocal_metax_kernel.h
0 → 100644
View file @
21c6af2d
#ifndef __RECIPROCAL_METAX_KERNEL_H__
#define __RECIPROCAL_METAX_KERNEL_H__
/*
 * This file contains the Reciprocal operation implementation for the METAX
 * (MACA) backend.
 *
 * It follows the consistent code structure to ensure alignment across the
 * different hardware backends supported by this project.
 */
namespace
op
::
reciprocal
::
metax
{
// Elementwise reciprocal functor: maps a to 1/a, with per-type handling.
typedef struct ReciprocalOp {
public:
    // Unary operator: exactly one input tensor.
    static constexpr size_t num_inputs = 1;

    template <typename T>
    __device__ __forceinline__ T operator()(const T &a) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Compute through float2 instead of relying on a packed
            // reciprocal instruction.
            float2 f2 = __half22float2(a);
            f2.x = 1.0f / f2.x;
            f2.y = 1.0f / f2.y;
            return __float22half2_rn(f2);
        } else if constexpr (std::is_same_v<T, half>) {
            // Promote to float for numerical stability, then round back.
            return __float2half(1.0f / __half2float(a));
        } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
            // BF16 is computed through float as well.
            float a_f = __bfloat162float(a);
            return __float2bfloat16_rn(1.0f / a_f);
        } else if constexpr (std::is_same_v<T, float>) {
            // Plain division; the compiler may lower this to a hardware
            // reciprocal (round-to-nearest).
            return 1.0f / a;
        } else if constexpr (std::is_same_v<T, double>) {
            return 1.0 / a;
        } else {
            // Integer types: standard C++ integer division semantics.
            return static_cast<T>(1) / a;
        }
    }
} ReciprocalOp;
}
// namespace op::reciprocal::metax
#endif // __RECIPROCAL_METAX_KERNEL_H__
src/infiniop/ops/reciprocal/moore/reciprocal_moore.h
0 → 100644
View file @
21c6af2d
#ifndef __RECIPROCAL_MOORE_API_H__
#define __RECIPROCAL_MOORE_API_H__

// 1. Pull in the Moore-platform elementwise API definitions.
#include "../../../elementwise/moore/elementwise_moore_api.h"

// 2. Generate op::reciprocal::moore::Descriptor.
//    The macro expands to the standard create/calculate declarations.
ELEMENTWISE_DESCRIPTOR(reciprocal, moore)

#endif // __RECIPROCAL_MOORE_API_H__
src/infiniop/ops/reciprocal/moore/reciprocal_moore.mu
0 → 100644
View file @
21c6af2d
#include "../../../elementwise/moore/elementwise_moore.h"
#include "reciprocal_moore.h"
#include "reciprocal_moore_kernel.h"
namespace op::reciprocal::moore {
Descriptor::~Descriptor() = default;

/**
 * Validates dtype/shape and builds the Moore (MUSA) elementwise descriptor
 * for reciprocal. Expects exactly one input descriptor in input_desc_vec.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    // Resolve the Moore (MUSA) handle.
    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &x_desc = input_desc_vec.at(0);
    const auto &y_shape = out_desc->shape();
    const auto &x_shape = x_desc->shape();

    // Reciprocal runs on floating-point element types only.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(y_shape, x_shape);

    // The Moore elementwise macro initializes the operator metadata for MUSA.
    CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}
/// Runs the reciprocal kernel on the Moore (MUSA) backend.
/// @param workspace      scratch buffer supplied by the caller
/// @param workspace_size size of `workspace` in bytes
/// @param output         destination tensor data
/// @param inputs         source tensor data (exactly one is used)
/// @param stream         backend stream to launch on
/// @return INFINI_STATUS_SUCCESS, INFINI_STATUS_INSUFFICIENT_WORKSPACE, or
///         INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtypes
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    // Reject scratch space smaller than what create() computed.
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    // Dispatch on the output dtype; moore::ReciprocalOp supplies the
    // per-element math, 256 is the launch-configuration template parameter
    // of the elementwise backend.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, moore::ReciprocalOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        // Uses the bfloat16 type alias provided by the Moore environment.
        return _device_info->calculate<256, moore::ReciprocalOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, moore::ReciprocalOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, moore::ReciprocalOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Fix: removed an unreachable `return INFINI_STATUS_SUCCESS;` that
    // followed the switch — every case, including default, already returns.
}
} // namespace op::reciprocal::moore
src/infiniop/ops/reciprocal/moore/reciprocal_moore_kernel.h
0 → 100644
View file @
21c6af2d
#ifndef __RECIPROCAL_MOORE_KERNEL_H__
#define __RECIPROCAL_MOORE_KERNEL_H__
/*
* This file contains the Reciprocal operation implementation for the MUSA backend.
*
* It follows the consistent code structure to ensure alignment across different
* hardware platforms within the Moore Threads (MUSA) ecosystem.
*/
namespace op::reciprocal::moore {

// Element-wise reciprocal (1/x) functor for the Moore Threads (MUSA)
// elementwise backend.
// Idiom fix: plain `struct` replaces the redundant C-style
// `typedef struct ... } ReciprocalOp;`, and the redundant `public:` is
// dropped (struct members are public by default).
struct ReciprocalOp {
    // Unary operator: exactly one input tensor.
    static constexpr size_t num_inputs = 1;

    // Computes 1/a, with per-type handling selected at compile time.
    template <typename T>
    __device__ __forceinline__ T operator()(const T &a) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Packed half pair: widen to float2, invert each lane, then
            // round back to half2 (round-to-nearest-even).
            float2 f2 = __half22float2(a);
            f2.x = 1.0f / f2.x;
            f2.y = 1.0f / f2.y;
            return __float22half2_rn(f2);
        } else if constexpr (std::is_same_v<T, half>) {
            // Promote to float so the division keeps full precision.
            return __float2half(1.0f / __half2float(a));
        } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
            // BF16: compute in float, round the result back to nearest.
            float a_f = __bfloat162float(a);
            return __float2bfloat16_rn(1.0f / a_f);
        } else if constexpr (std::is_same_v<T, float>) {
            // Compilers typically lower 1.0f/a to a hardware reciprocal.
            return 1.0f / a;
        } else if constexpr (std::is_same_v<T, double>) {
            return 1.0 / a;
        } else {
            // Remaining arithmetic types: standard C++ division semantics
            // (integer reciprocal truncates to 0 except for |a| == 1).
            return static_cast<T>(1) / a;
        }
    }
};

} // namespace op::reciprocal::moore
#endif // __RECIPROCAL_MOORE_KERNEL_H__
src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cu
0 → 100644
View file @
21c6af2d
#include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
#include "../cuda/kernel.cuh"
#include "reciprocal_nvidia.cuh"
namespace
op
::
reciprocal
::
nvidia
{
// Defaulted destructor: member destructors perform all cleanup.
Descriptor::~Descriptor() = default;
/// Builds a reciprocal Descriptor for the NVIDIA (CUDA) backend.
/// Validates dtype and output/input shape agreement, then delegates
/// construction to the shared elementwise descriptor machinery.
/// @param handle_        opaque device handle (must be an NVIDIA handle)
/// @param desc_ptr       receives the newly created descriptor
/// @param out_desc       output tensor descriptor
/// @param input_desc_vec input tensor descriptors (exactly one is read)
/// @return INFINI_STATUS_SUCCESS, or an error status from the CHECK_* macros
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    // Interpret the opaque handle as an NVIDIA device handle.
    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
    auto dtype = out_desc->dtype();
    const auto &y_desc = out_desc;
    const auto &x_desc = input_desc_vec.at(0);
    const auto &y_shape = y_desc->shape();
    const auto &x_shape = x_desc->shape();
    // Reciprocal typically only supports floating point types
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(y_shape, x_shape);
    // create CUDA elementwise descriptor
    CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
    return INFINI_STATUS_SUCCESS;
}
/// Runs the reciprocal kernel on the NVIDIA (CUDA) backend.
/// @param workspace      scratch buffer supplied by the caller
/// @param workspace_size size of `workspace` in bytes
/// @param output         destination tensor data
/// @param inputs         source tensor data (exactly one is used)
/// @param stream         backend stream to launch on
/// @return INFINI_STATUS_SUCCESS, INFINI_STATUS_INSUFFICIENT_WORKSPACE, or
///         INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtypes
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    // Reject scratch space smaller than what create() computed.
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    // Dispatch on the output dtype; cuda::ReciprocalOp supplies the
    // per-element math, 256 is the launch-configuration template parameter
    // of the elementwise backend.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::ReciprocalOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::ReciprocalOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::ReciprocalOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::ReciprocalOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Fix: removed an unreachable `return INFINI_STATUS_SUCCESS;` that
    // followed the switch — every case, including default, already returns.
}
}
// namespace op::reciprocal::nvidia
src/infiniop/ops/reciprocal/nvidia/reciprocal_nvidia.cuh
0 → 100644
View file @
21c6af2d
#ifndef __RECIPROCAL_CUDA_API_H__
#define __RECIPROCAL_CUDA_API_H__

// Public descriptor API for the reciprocal operator on the NVIDIA (CUDA)
// platform.
// NOTE(review): guard is named __RECIPROCAL_CUDA_API_H__ while sibling
// platforms use the platform name (e.g. __RECIPROCAL_MOORE_API_H__);
// consider renaming for consistency.

#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"

// Generates op::reciprocal::nvidia::Descriptor with the standard
// create()/calculate() declarations.
ELEMENTWISE_DESCRIPTOR(reciprocal, nvidia)

#endif // __RECIPROCAL_CUDA_API_H__
Prev
1
2
3
4
5
6
Next
yanzy
@yanzy
mentioned in commit
18773b69
·
Apr 21, 2026
mentioned in commit
18773b69
mentioned in commit 18773b69ae7bd79b4e9cf9ac0a4e4c6ed1bf9bf8
Toggle commit list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment