Commit 18773b69 authored by wooway777
Browse files

Revert "Merge pull request #1069 from InfiniTensor/issue/1031_T1_1_15"

This reverts commit 21c6af2d, reversing
changes made to 99a802dd.
parent bfead271
#ifndef __ATANH_METAX_KERNEL_H__
#define __ATANH_METAX_KERNEL_H__
/*
 * Atanh elementwise kernel functor for the METAX backend.
 *
 * NOTE(review): the original header comment described this file as part of
 * the MUSA (Moore Threads) ecosystem, but the include guard and namespace
 * are METAX; it follows the same per-backend structure as the other
 * platforms' kernels.
 */
namespace op::atanh::metax {
typedef struct AtanhOp {
public:
// Unary operator: consumes exactly one input tensor.
static constexpr size_t num_inputs = 1;
template <typename T>
__device__ __forceinline__ T operator()(const T &a) const {
if constexpr (std::is_same_v<T, half2>) {
// half2: widen both lanes to float, apply atanhf, round back to half2.
float2 f2 = __half22float2(a);
f2.x = atanhf(f2.x);
f2.y = atanhf(f2.y);
return __float22half2_rn(f2);
} else if constexpr (std::is_same_v<T, half>) {
// half: promote to float for precision, then convert back.
return __float2half(atanhf(__half2float(a)));
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
// bfloat16: promote to float to avoid ambiguous conversions.
float a_f = __bfloat162float(a);
return __float2bfloat16_rn(atanhf(a_f));
} else if constexpr (std::is_same_v<T, float>) {
// float: single-precision inverse hyperbolic tangent.
return atanhf(a);
} else if constexpr (std::is_same_v<T, double>) {
// double: double-precision overload.
return ::atanh(a);
} else {
// Fallback: compute in float and cast back (e.g. for integral types).
return static_cast<T>(atanhf(static_cast<float>(a)));
}
}
} AtanhOp;
} // namespace op::atanh::metax
#endif // __ATANH_METAX_KERNEL_H__
#ifndef __ATANH_MOORE_API_H__
#define __ATANH_MOORE_API_H__
// Pull in the elementwise API definitions for the Moore platform.
#include "../../../elementwise/moore/elementwise_moore_api.h"
// The ELEMENTWISE_DESCRIPTOR macro with platform argument `moore`
// generates the declaration of the op::atanh::moore::Descriptor class.
ELEMENTWISE_DESCRIPTOR(atanh, moore)
#endif // __ATANH_MOORE_API_H__
#include "../../../elementwise/moore/elementwise_moore.h"
#include "atanh_moore.h"
#include "atanh_moore_kernel.h"
namespace op::atanh::moore {
Descriptor::~Descriptor() = default;
// Builds a Moore (MUSA) atanh descriptor: validates dtype and shape, then
// delegates to the shared elementwise descriptor machinery.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
// Cast the generic handle to the Moore platform handle type.
auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &a_desc = input_desc_vec.at(0);
const auto &y_shape = out_desc->shape();
const auto &a_shape = a_desc->shape();
// Supported data types for this backend.
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_F64);
// Output shape must match the (single) input's shape.
CHECK_SAME_SHAPE(y_shape, a_shape);
// Instantiate the Moore elementwise descriptor (sets *desc_ptr).
CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
// Dispatches the atanh MUSA kernel by dtype. `inputs` holds one pointer
// (unary operator).
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
// Dispatch to the concrete kernel implementation by data type.
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, moore::AtanhOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
// Uses the bfloat16 type name available in the Moore environment
// (cuda_bfloat16 here, in place of nv_bfloat16).
return _device_info->calculate<256, moore::AtanhOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, moore::AtanhOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, moore::AtanhOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS; // unreachable: every switch arm returns
}
} // namespace op::atanh::moore
#ifndef __ATANH_MOORE_KERNEL_H__
#define __ATANH_MOORE_KERNEL_H__
/*
 * Atanh elementwise kernel functor for the MUSA (Moore Threads) backend.
 *
 * It follows the consistent code structure used across the other hardware
 * platforms so the backends stay aligned.
 */
namespace op::atanh::moore {
typedef struct AtanhOp {
public:
    // Unary operator: consumes exactly one input tensor.
    static constexpr size_t num_inputs = 1;

    template <typename T>
    __device__ __forceinline__ T operator()(const T &x) const {
        if constexpr (std::is_same_v<T, double>) {
            // Double precision uses the double overload directly.
            return ::atanh(x);
        } else if constexpr (std::is_same_v<T, float>) {
            // Single precision maps onto the MUSA math library.
            return atanhf(x);
        } else if constexpr (std::is_same_v<T, half2>) {
            // Packed half pair: widen both lanes, evaluate, repack.
            float2 lanes = __half22float2(x);
            lanes.x = atanhf(lanes.x);
            lanes.y = atanhf(lanes.y);
            return __float22half2_rn(lanes);
        } else if constexpr (std::is_same_v<T, half>) {
            // Scalar half is promoted to float for accuracy.
            return __float2half(atanhf(__half2float(x)));
        } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
            // bfloat16 is likewise promoted to float before rounding back.
            return __float2bfloat16_rn(atanhf(__bfloat162float(x)));
        } else {
            // Fallback: evaluate in float and cast to the requested type.
            return static_cast<T>(atanhf(static_cast<float>(x)));
        }
    }
} AtanhOp;
} // namespace op::atanh::moore
#endif // __ATANH_MOORE_KERNEL_H__
#include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
#include "../cuda/kernel.cuh"
#include "atanh_nvidia.cuh"
namespace op::atanh::nvidia {
Descriptor::~Descriptor() = default;
// Builds an NVIDIA atanh descriptor: validates dtype and shape, then
// delegates to the shared CUDA elementwise descriptor machinery.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &a_desc = input_desc_vec.at(0);
const auto &y_shape = out_desc->shape();
const auto &a_shape = a_desc->shape();
// Supported data types for this backend.
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_F64);
// Output shape must match the (single) input's shape.
CHECK_SAME_SHAPE(y_shape, a_shape);
// Instantiate the CUDA elementwise descriptor (sets *desc_ptr).
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
// Dispatches the CUDA atanh kernel by dtype with a 256-thread block size.
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, cuda::AtanhOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, cuda::AtanhOp, nv_bfloat16>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, cuda::AtanhOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, cuda::AtanhOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS; // unreachable: every switch arm returns
}
} // namespace op::atanh::nvidia
#ifndef __ATANH_CUDA_API_H__
#define __ATANH_CUDA_API_H__
#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"
// Declares op::atanh::nvidia::Descriptor via the shared elementwise macro.
ELEMENTWISE_DESCRIPTOR(atanh, nvidia)
#endif // __ATANH_CUDA_API_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/atanh.h"
#ifdef ENABLE_CPU_API
#include "cpu/atanh_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/atanh_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/atanh_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/atanh_kunlun.h"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/atanh_bang.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/atanh_moore.h"
#endif
// Creates an atanh descriptor for the device owned by `handle`, dispatching
// to the matching backend's Descriptor::create.
__INFINI_C infiniStatus_t infiniopCreateAtanhDescriptor(
infiniopHandle_t handle,
infiniopAtanhDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t a_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::atanh::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::atanh::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \
{a_desc}) // unary operator: only a_desc is passed
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
// Reports the workspace size (in bytes) required by an atanh descriptor.
__INFINI_C infiniStatus_t infiniopGetAtanhWorkspaceSize(infiniopAtanhDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::atanh::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
}
// Executes atanh: y = atanh(a), dispatching to the backend's calculate().
__INFINI_C infiniStatus_t infiniopAtanh(
infiniopAtanhDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *a,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::atanh::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, {a}, stream) // unary operator: only `a` is passed
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
// Destroys an atanh descriptor created by infiniopCreateAtanhDescriptor.
__INFINI_C infiniStatus_t
infiniopDestroyAtanhDescriptor(infiniopAtanhDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::atanh::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
#ifndef __BINARY_CROSS_ENTROPY_WITH_LOGITS_H__
#define __BINARY_CROSS_ENTROPY_WITH_LOGITS_H__
#include "../../operator.h"
#include "info.h"
/**
 * # Notes on the `BCEWithLogits` operator descriptor
 * * Uses the PImpl design pattern: backend-specific state (native CUDA
 * kernels, CPU loops, or vendor-library calls) is encapsulated in the
 * `Opaque` struct.
 * * On creation the descriptor performs shape validation and stride
 * analysis, and determines the required workspace size.
 */
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::bce_with_logits::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
infiniDtype_t _dtype; \
BCEWithLogitsInfo _info; /* dims and strides of each in/out tensor */ \
size_t _workspace_size; \
infiniopReduction_t _reduction; \
\
Descriptor( \
infiniDtype_t dtype, \
BCEWithLogitsInfo info, \
infiniopReduction_t reduction, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_dtype(dtype), \
_info(info), \
_workspace_size(workspace_size_), \
_reduction(reduction) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t logits_desc, \
infiniopTensorDescriptor_t target_desc, \
infiniopTensorDescriptor_t weight_desc, \
infiniopTensorDescriptor_t pos_weight_desc, \
infiniopReduction_t reduction); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *out, \
const void *logits, \
const void *target, \
const void *weight, /* optional, may be nullptr */ \
const void *pos_weight, /* optional, may be nullptr */ \
void *stream) const; \
}; \
}
#endif // __BINARY_CROSS_ENTROPY_WITH_LOGITS_H__
#include "binary_cross_entropy_with_logits_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include <algorithm>
#include <cmath>
namespace op::bce_with_logits::cpu {
Descriptor::~Descriptor() = default;
// Builds a CPU BCEWithLogits descriptor: validates dtype, extracts the
// tensors' layout info, and stores it. No workspace is needed on CPU.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t logits_desc,
infiniopTensorDescriptor_t target_desc,
infiniopTensorDescriptor_t weight_desc,
infiniopTensorDescriptor_t pos_weight_desc,
infiniopReduction_t reduction) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = logits_desc->dtype();
// 1. Supported data types.
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
// 2. Parse dimension/stride information (see BCEWithLogitsInfo).
auto result = BCEWithLogitsInfo::create(out_desc, logits_desc, target_desc,
weight_desc, pos_weight_desc, reduction);
CHECK_RESULT(result);
// 3. Instantiate the descriptor (workspace size 0, no opaque state).
*desc_ptr = new Descriptor(
dtype, result.take(), reduction, 0,
nullptr,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
/**
 * Numerically stable core: L = -w * [pw * y * log(sigmoid(x)) + (1-y) * log(1-sigmoid(x))]
 * Implemented (matching PyTorch) as:
 *   max_val    = max(-x, 0)
 *   log_weight = 1 + (pw - 1) * y
 *   L          = w * [(1 - y) * x + log_weight * (log1p(exp(-|x|)) + max_val)]
 * For pw = 1 this reduces to: L = w * [max(x, 0) - x * y + log(1 + exp(-|x|))]
 */
template <typename Tdata>
void calculate_bce(
const BCEWithLogitsInfo &info,
void *out,
const void *logits,
const void *target,
const void *weight,
const void *pos_weight) {
size_t n = info.num_elements;
// float accumulator for mean/sum reductions (combined via OpenMP below)
float total_loss = 0.0f;
// Typed views over the raw tensor pointers.
const Tdata *l_ptr = reinterpret_cast<const Tdata *>(logits);
const Tdata *t_ptr = reinterpret_cast<const Tdata *>(target);
const Tdata *w_ptr = reinterpret_cast<const Tdata *>(weight);
const Tdata *pw_ptr = reinterpret_cast<const Tdata *>(pos_weight);
Tdata *o_ptr = reinterpret_cast<Tdata *>(out);
auto &logits_info = info.logits;
auto &target_info = info.target;
auto &weight_info = info.weight;
auto &out_info = info.out;
#pragma omp parallel for reduction(+ : total_loss)
for (ptrdiff_t i = 0; i < (ptrdiff_t)n; ++i) {
size_t idx = static_cast<size_t>(i);
// Convert the logical index to per-tensor memory offsets via strides,
// so arbitrary (non-contiguous) layouts are supported.
size_t logits_offset = op::common_cpu::indexToOffset(
idx,
logits_info.ndim,
logits_info.dims.data(),
logits_info.stride.data());
size_t target_offset = op::common_cpu::indexToOffset(
idx,
target_info.ndim,
target_info.dims.data(),
target_info.stride.data());
float x = utils::cast<float>(l_ptr[logits_offset]);
float y = utils::cast<float>(t_ptr[target_offset]);
// pos_weight broadcast (assumes logits shape [..., C] with a contiguous
// pos_weight of shape [C]).
float pw = 1.0f;
if (pw_ptr && info.pos_weight.total_elements > 0) {
size_t c = idx % info.pos_weight.total_elements;
pw = utils::cast<float>(pw_ptr[c]);
}
// weight handling:
// - same shape as logits: exact stride-based indexing;
// - vector [C]: indexToOffset broadcasts along the last dimension.
float w = 1.0f;
if (w_ptr && weight_info.ndim > 0) {
size_t weight_offset = op::common_cpu::indexToOffset(
idx,
weight_info.ndim,
weight_info.dims.data(),
weight_info.stride.data());
w = utils::cast<float>(w_ptr[weight_offset]);
}
// Numerically stable BCEWithLogits (mirrors the PyTorch implementation):
// max_val = max(-x, 0)
// log_weight = 1 + (pos_weight - 1) * y
// loss = (1 - y) * x + log_weight * (log(1 + exp(-|x|)) + max_val)
float max_val = std::max(-x, 0.0f);
float log_weight = 1.0f + (pw - 1.0f) * y;
float loss = (1.0f - y) * x + log_weight * (std::log1p(std::exp(-std::abs(x))) + max_val);
loss *= w;
if (info.reduction == INFINIOP_REDUCTION_NONE) {
// Element-wise write-back also honors the output tensor's strides.
size_t out_offset = op::common_cpu::indexToOffset(
idx,
out_info.ndim,
out_info.dims.data(),
out_info.stride.data());
o_ptr[out_offset] = utils::cast<Tdata>(loss);
} else {
total_loss += loss;
}
}
// Finalize reductions into the scalar output slot.
if (info.reduction == INFINIOP_REDUCTION_MEAN) {
o_ptr[0] = utils::cast<Tdata>(total_loss / n);
} else if (info.reduction == INFINIOP_REDUCTION_SUM) {
o_ptr[0] = utils::cast<Tdata>(total_loss);
}
}
// Dispatches the CPU implementation by dtype. `workspace` and `stream` are
// unused on CPU; half/bfloat16 values are converted via utils::cast.
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *out,
const void *logits,
const void *target,
const void *weight,
const void *pos_weight,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
cpu::calculate_bce<fp16_t>(_info, out, logits, target, weight, pos_weight);
return INFINI_STATUS_SUCCESS;
case INFINI_DTYPE_BF16:
cpu::calculate_bce<bf16_t>(_info, out, logits, target, weight, pos_weight);
return INFINI_STATUS_SUCCESS;
case INFINI_DTYPE_F32:
cpu::calculate_bce<float>(_info, out, logits, target, weight, pos_weight);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::bce_with_logits::cpu
#ifndef __BINARY_CROSS_ENTROPY_WITH_LOGITS_CPU_H__
#define __BINARY_CROSS_ENTROPY_WITH_LOGITS_CPU_H__
#include "../binary_cross_entropy_with_logits.h"
/**
 * Instantiates the DESCRIPTOR macro defined in bce_with_logits.h.
 * * This generates the Descriptor class in namespace op::bce_with_logits::cpu.
 * The class derives from InfiniopDescriptor and contains:
 * - BCEWithLogitsInfo _info (validated dimensions and strides)
 * - a static create() method (instantiates the CPU descriptor)
 * - a calculate() method (runs the numerically stable CPU implementation)
 */
DESCRIPTOR(cpu)
#endif // __BINARY_CROSS_ENTROPY_WITH_LOGITS_CPU_H__
#ifndef __BINARY_CROSS_ENTROPY_WITH_LOGITS_INFO_H__
#define __BINARY_CROSS_ENTROPY_WITH_LOGITS_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include "infiniop/ops/binary_cross_entropy_with_logits.h"
#include <numeric>
#include <vector>
namespace op::bce_with_logits {
/**
 * Memory-layout description of one tensor used by the BCE operator.
 * dims and stride are allocated dynamically, so tensors of any rank are
 * supported.
 */
struct BCETensorInfo {
size_t total_elements = 0;
size_t ndim = 0;
std::vector<size_t> dims; // per-dimension sizes
std::vector<ptrdiff_t> stride; // per-dimension strides (in elements)
BCETensorInfo() = default;
// Extracts dims/strides/element count from a tensor descriptor.
static utils::Result<BCETensorInfo> create(infiniopTensorDescriptor_t desc) {
if (desc == nullptr) {
// NOTE(review): for a null descriptor this returns a Result built from
// INFINI_STATUS_SUCCESS — presumably yielding an empty/default info;
// confirm against utils::Result's status-constructor semantics.
return INFINI_STATUS_SUCCESS;
}
BCETensorInfo info;
info.ndim = desc->ndim();
info.total_elements = 1;
// Pre-size the vectors to the tensor rank.
info.dims.reserve(info.ndim);
info.stride.reserve(info.ndim);
for (size_t i = 0; i < info.ndim; ++i) {
size_t d = desc->dim(i);
info.dims.push_back(d);
info.stride.push_back(desc->stride(i));
info.total_elements *= d;
}
return utils::Result<BCETensorInfo>(std::move(info));
}
// Helper: size of the innermost dimension (used for pos_weight checks).
size_t last_dim() const {
return ndim > 0 ? dims.back() : 0;
}
};
/**
 * Aggregated, validated layout information for all tensors of the
 * BCEWithLogits operator, plus the reduction mode and element count.
 */
class BCEWithLogitsInfo {
public:
BCETensorInfo logits;
BCETensorInfo target;
BCETensorInfo weight;
BCETensorInfo pos_weight;
BCETensorInfo out;
size_t num_elements;
infiniopReduction_t reduction;
// BCETensorInfo stores vectors internally, so this class is safely movable.
BCEWithLogitsInfo() = default;
// Validates and captures the layout of every tensor involved in the op.
static utils::Result<BCEWithLogitsInfo> create(
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t logits_desc,
infiniopTensorDescriptor_t target_desc,
infiniopTensorDescriptor_t weight_desc,
infiniopTensorDescriptor_t pos_weight_desc,
infiniopReduction_t reduction) {
auto logits_res = BCETensorInfo::create(logits_desc);
CHECK_RESULT(logits_res);
auto target_res = BCETensorInfo::create(target_desc);
CHECK_RESULT(target_res);
auto out_res = BCETensorInfo::create(out_desc);
CHECK_RESULT(out_res);
BCEWithLogitsInfo info;
info.logits = logits_res.take();
info.target = target_res.take();
info.out = out_res.take();
info.reduction = reduction;
info.num_elements = info.logits.total_elements;
// 1. logits and target must have identical shapes.
if (info.logits.ndim != info.target.ndim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
for (size_t i = 0; i < info.logits.ndim; ++i) {
if (info.logits.dims[i] != info.target.dims[i]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
// 2. Validate weight.
if (weight_desc) {
auto w_res = BCETensorInfo::create(weight_desc);
CHECK_RESULT(w_res);
info.weight = w_res.take();
// Two layouts are accepted:
// 1. exactly the same element count as logits;
// 2. a vector whose length equals logits' last dimension
//    (the common broadcast case).
bool is_full_match = (info.weight.total_elements == info.logits.total_elements);
bool is_last_dim_match = (info.weight.total_elements == info.logits.last_dim());
if (!is_full_match && !is_last_dim_match) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
// 3. Record pos_weight info.
// Broadcasting is handled by the compute kernels based on its length
// alone, so no shape constraint is imposed here; any valid length is
// accepted to avoid spurious Bad Tensor Shape errors.
if (pos_weight_desc) {
auto pw_res = BCETensorInfo::create(pos_weight_desc);
CHECK_RESULT(pw_res);
info.pos_weight = pw_res.take();
}
// 4. Output shape.
// The element count of out is not strictly checked against logits here;
// the higher-level API is responsible for creating a sensible output
// tensor. The backends only rely on `_info.out`'s strides for the
// element-wise write-back when reduction == NONE.
return utils::Result<BCEWithLogitsInfo>(std::move(info));
}
};
} // namespace op::bce_with_logits
#endif // __BINARY_CROSS_ENTROPY_WITH_LOGITS_INFO_H__
#ifndef __BINARY_CROSS_ENTROPY_WITH_LOGITS_METAX_CUH__
#define __BINARY_CROSS_ENTROPY_WITH_LOGITS_METAX_CUH__
#include "../binary_cross_entropy_with_logits.h"
/**
 * Instantiates the DESCRIPTOR macro defined in bce_with_logits.h, generating
 * the Descriptor class for the METAX device in op::bce_with_logits::metax.
 * * In the METAX-side implementation the Opaque struct may carry
 * backend-specific state such as library handles, tensor descriptors, or
 * kernel launch configuration; the implementation in this repo currently
 * defines it empty.
 */
DESCRIPTOR(metax)
#endif // __BINARY_CROSS_ENTROPY_WITH_LOGITS_METAX_CUH__
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_handle.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "binary_cross_entropy_with_logits_metax.h"
#include <mc_runtime.h>
#include <type_traits>
namespace op::bce_with_logits::metax {
using device::metax::indexToOffset;
// No backend-specific state is needed for this implementation.
struct Descriptor::Opaque {};
Descriptor::~Descriptor() = default;
// Simplified tensor descriptor passed by value to kernels: fixed maximum
// rank, with explicit strides so non-contiguous layouts are supported.
constexpr int BCE_MAX_DIMS = 8;
struct BCETensorInfoDevice {
size_t ndim;
size_t shape[BCE_MAX_DIMS];
ptrdiff_t strides[BCE_MAX_DIMS];
};
// Converts host-side BCETensorInfo into the fixed-capacity device struct.
//
// Fix: dev.ndim is now clamped to BCE_MAX_DIMS. Previously it was copied
// unclamped while shape/strides hold only BCE_MAX_DIMS entries, so a tensor
// of rank > 8 made the device-side indexToOffset read past the arrays
// (out-of-bounds, undefined behavior on device).
static inline BCETensorInfoDevice make_device_info(const BCETensorInfo &info) {
    BCETensorInfoDevice dev{};
    const size_t max_dims = static_cast<size_t>(BCE_MAX_DIMS);
    const size_t ndim = info.ndim < max_dims ? info.ndim : max_dims;
    dev.ndim = ndim;
    for (size_t i = 0; i < ndim; ++i) {
        dev.shape[i] = info.dims[i];
        dev.strides[i] = info.stride[i];
    }
    return dev;
}
// Builds a METAX BCEWithLogits descriptor: validates dtype and rank,
// extracts layout info, and sizes the reduction workspace.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t logits_desc,
    infiniopTensorDescriptor_t target_desc,
    infiniopTensorDescriptor_t weight_desc,
    infiniopTensorDescriptor_t pos_weight_desc,
    infiniopReduction_t reduction) {
    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
    auto dtype = logits_desc->dtype();
    // The METAX implementation supports F16 / F32 / BF16.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    // Reject ranks beyond the fixed-capacity device-side descriptor
    // (BCETensorInfoDevice holds at most BCE_MAX_DIMS dims/strides).
    // Previously such tensors silently overran those arrays on device.
    if (logits_desc->ndim() > static_cast<size_t>(BCE_MAX_DIMS)
        || out_desc->ndim() > static_cast<size_t>(BCE_MAX_DIMS)) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    auto result = BCEWithLogitsInfo::create(out_desc, logits_desc, target_desc,
                                            weight_desc, pos_weight_desc, reduction);
    CHECK_RESULT(result);
    auto info = result.take();
    // F16/BF16 reductions accumulate into a float scalar workspace.
    size_t workspace_size = 0;
    if (reduction != INFINIOP_REDUCTION_NONE && (dtype == INFINI_DTYPE_F16 || dtype == INFINI_DTYPE_BF16)) {
        workspace_size = sizeof(float);
    }
    *desc_ptr = new Descriptor(
        dtype, std::move(info), reduction, workspace_size,
        nullptr,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Promotes an arbitrary scalar type to float.
template <typename T>
__device__ __forceinline__ float to_float(T v) {
    if constexpr (std::is_same_v<T, half>) {
        return __half2float(v);
    } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
        return __bfloat162float(v);
    } else if constexpr (std::is_same_v<T, float>) {
        return v;
    } else {
        return static_cast<float>(v);
    }
}
// Converts a float back to the target scalar type.
template <typename T>
__device__ __forceinline__ T from_float(float v) {
    if constexpr (std::is_same_v<T, half>) {
        return __float2half(v);
    } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
        return __float2bfloat16(v);
    } else if constexpr (std::is_same_v<T, float>) {
        return v;
    } else {
        return static_cast<T>(v);
    }
}
// --- METAX kernel: stride-aware, numerically stable BCEWithLogits ---
//
// Per-element loss (PyTorch-equivalent stable form):
//   log_weight = 1 + (pos_weight - 1) * y
//   loss = (1 - y) * x + log_weight * (log1p(exp(-|x|)) + max(-x, 0))
//
// Fix: the previous formula, max(x,0) - x*y*pw + log_weight*log(1+exp(-|x|)),
// dropped the log_weight factor on the max(-x, 0) term. For pos_weight != 1
// and x > 0 it produced wrong (even negative) losses, e.g. x=2, y=1, pw=3
// gave -3.62 instead of 0.38. The corrected form matches the CPU reference
// implementation of this operator.
//
// One thread per logical element. When reduction != NONE, out_raw points to
// a single Taccum scalar (zeroed by the caller) accumulated via atomicAdd.
template <typename Tdata, typename Taccum>
__global__ void bce_logits_kernel(
    void *out_raw,
    const Tdata *logits,
    const Tdata *target,
    const Tdata *weight,
    const Tdata *pos_weight,
    BCETensorInfoDevice logits_info,
    BCETensorInfoDevice target_info,
    BCETensorInfoDevice weight_info,
    BCETensorInfoDevice out_info,
    size_t n,
    size_t pos_weight_len,
    infiniopReduction_t reduction) {
    // size_t arithmetic avoids 32-bit overflow of blockIdx.x * blockDim.x
    // for very large element counts.
    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= n) {
        return;
    }
    // Map the logical index into each tensor (arbitrary strides supported).
    size_t logits_offset = indexToOffset(idx, logits_info.ndim,
                                         logits_info.shape, logits_info.strides);
    size_t target_offset = indexToOffset(idx, target_info.ndim,
                                         target_info.shape, target_info.strides);
    float x = to_float(logits[logits_offset]);
    float y = to_float(target[target_offset]);
    float pw = 1.0f;
    if (pos_weight && pos_weight_len > 0) {
        // Broadcast over the last dimension: pos_weight is assumed to be a
        // contiguous 1-D tensor.
        size_t c = idx % pos_weight_len;
        pw = to_float(pos_weight[c]);
    }
    float w = 1.0f;
    if (weight && weight_info.ndim > 0) {
        size_t weight_offset = indexToOffset(idx, weight_info.ndim,
                                             weight_info.shape, weight_info.strides);
        w = to_float(weight[weight_offset]);
    }
    // Numerically stable BCEWithLogits (see header comment above).
    float max_val = fmaxf(-x, 0.0f);
    float log_weight = 1.0f + (pw - 1.0f) * y;
    float loss = (1.0f - y) * x + log_weight * (log1pf(expf(-fabsf(x))) + max_val);
    loss *= w;
    if (reduction == INFINIOP_REDUCTION_NONE) {
        // Element-wise write-back honoring the output strides.
        size_t out_offset = indexToOffset(idx, out_info.ndim,
                                          out_info.shape, out_info.strides);
        auto *out_ptr = static_cast<Tdata *>(out_raw);
        out_ptr[out_offset] = from_float<Tdata>(loss);
    } else {
        // mean/sum: accumulate into the single float scalar.
        auto *out_accum = static_cast<Taccum *>(out_raw);
        atomicAdd(out_accum, static_cast<Taccum>(loss));
    }
}
// F32 mean reduction: divide the accumulated scalar by the element count.
__global__ void bce_logits_mean_scale_kernel_f32(float *val, size_t count) {
    const bool leader = (blockIdx.x == 0 && threadIdx.x == 0);
    if (leader) {
        *val /= static_cast<float>(count);
    }
}
// F16/BF16 reduction epilogue: read the float workspace scalar, optionally
// apply the mean division, and store it in the target dtype.
template <typename Tdata>
__global__ void bce_logits_reduce_finalize_kernel(
    Tdata *out,
    float *workspace,
    size_t count,
    int is_mean) {
    const bool leader = (blockIdx.x == 0 && threadIdx.x == 0);
    if (!leader) {
        return;
    }
    float v = *workspace;
    if (is_mean != 0) {
        v /= static_cast<float>(count);
    }
    *out = from_float<Tdata>(v);
}
// Launches the BCEWithLogits computation on the METAX device.
// F32 reductions accumulate directly into `out`; F16/BF16 reductions
// accumulate into a float scalar in `workspace` and are finalized into the
// target dtype by a second single-thread kernel.
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *out,
const void *logits,
const void *target,
const void *weight,
const void *pos_weight,
void *stream) const {
mcStream_t custream = (mcStream_t)stream;
size_t n = _info.num_elements;
// F16/BF16 with reduction requires a float scalar workspace.
if (_reduction != INFINIOP_REDUCTION_NONE && (_dtype == INFINI_DTYPE_F16 || _dtype == INFINI_DTYPE_BF16)) {
if (workspace_size < sizeof(float)) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
int block = 256;
int grid = static_cast<int>((n + block - 1) / block);
// Build the device-side tensor infos (with strides).
BCETensorInfoDevice logits_info = make_device_info(_info.logits);
BCETensorInfoDevice target_info = make_device_info(_info.target);
BCETensorInfoDevice out_info = make_device_info(_info.out);
BCETensorInfoDevice weight_info = {};
if (_info.weight.total_elements != 0) {
weight_info = make_device_info(_info.weight);
}
size_t pos_weight_len = _info.pos_weight.total_elements;
switch (_dtype) {
case INFINI_DTYPE_F32: {
// For reductions, zero the output scalar before accumulation.
if (_reduction != INFINIOP_REDUCTION_NONE) {
mcMemsetAsync(out, 0, sizeof(float), custream);
}
bce_logits_kernel<float, float><<<grid, block, 0, custream>>>(
out,
static_cast<const float *>(logits),
static_cast<const float *>(target),
static_cast<const float *>(weight),
static_cast<const float *>(pos_weight),
logits_info,
target_info,
weight_info,
out_info,
n,
pos_weight_len,
_reduction);
if (_reduction == INFINIOP_REDUCTION_MEAN) {
bce_logits_mean_scale_kernel_f32<<<1, 1, 0, custream>>>(
static_cast<float *>(out), n);
}
break;
}
case INFINI_DTYPE_F16: {
auto *logits_h = static_cast<const half *>(logits);
auto *target_h = static_cast<const half *>(target);
auto *weight_h = static_cast<const half *>(weight);
auto *pos_weight_h = static_cast<const half *>(pos_weight);
void *out_raw = nullptr;
float *workspace_f = nullptr;
if (_reduction == INFINIOP_REDUCTION_NONE) {
out_raw = out;
} else {
// Accumulate into a zeroed float workspace scalar instead of out.
workspace_f = static_cast<float *>(workspace);
mcMemsetAsync(workspace_f, 0, sizeof(float), custream);
out_raw = workspace_f;
}
bce_logits_kernel<half, float><<<grid, block, 0, custream>>>(
out_raw,
logits_h,
target_h,
weight_h,
pos_weight_h,
logits_info,
target_info,
weight_info,
out_info,
n,
pos_weight_len,
_reduction);
if (_reduction != INFINIOP_REDUCTION_NONE) {
// Finalize: optional mean division and conversion to half.
int is_mean = (_reduction == INFINIOP_REDUCTION_MEAN) ? 1 : 0;
bce_logits_reduce_finalize_kernel<half><<<1, 1, 0, custream>>>(
static_cast<half *>(out), workspace_f, n, is_mean);
}
break;
}
case INFINI_DTYPE_BF16: {
auto *logits_b = static_cast<const cuda_bfloat16 *>(logits);
auto *target_b = static_cast<const cuda_bfloat16 *>(target);
auto *weight_b = static_cast<const cuda_bfloat16 *>(weight);
auto *pos_weight_b = static_cast<const cuda_bfloat16 *>(pos_weight);
void *out_raw = nullptr;
float *workspace_f = nullptr;
if (_reduction == INFINIOP_REDUCTION_NONE) {
out_raw = out;
} else {
// Accumulate into a zeroed float workspace scalar instead of out.
workspace_f = static_cast<float *>(workspace);
mcMemsetAsync(workspace_f, 0, sizeof(float), custream);
out_raw = workspace_f;
}
bce_logits_kernel<cuda_bfloat16, float><<<grid, block, 0, custream>>>(
out_raw,
logits_b,
target_b,
weight_b,
pos_weight_b,
logits_info,
target_info,
weight_info,
out_info,
n,
pos_weight_len,
_reduction);
if (_reduction != INFINIOP_REDUCTION_NONE) {
// Finalize: optional mean division and conversion to bfloat16.
int is_mean = (_reduction == INFINIOP_REDUCTION_MEAN) ? 1 : 0;
bce_logits_reduce_finalize_kernel<cuda_bfloat16><<<1, 1, 0, custream>>>(
static_cast<cuda_bfloat16 *>(out), workspace_f, n, is_mean);
}
break;
}
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Surface launch/configuration errors from the kernels above.
mcError_t err = mcGetLastError();
if (err != mcSuccess) {
return INFINI_STATUS_INTERNAL_ERROR;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::bce_with_logits::metax
#ifndef __BINARY_CROSS_ENTROPY_WITH_LOGITS_MOORE_H__
#define __BINARY_CROSS_ENTROPY_WITH_LOGITS_MOORE_H__
#include "../binary_cross_entropy_with_logits.h"
/**
 * Instantiates the DESCRIPTOR macro defined in bce_with_logits.h, generating
 * the Descriptor class for the Moore device in op::bce_with_logits::moore.
 * * In the Moore-side implementation (.mu file) the Opaque struct may carry
 * backend-specific state such as MUSA library handles, kernel launch
 * configuration, or MUSA data-type tags.
 */
DESCRIPTOR(moore)
#endif // __BINARY_CROSS_ENTROPY_WITH_LOGITS_MOORE_H__
#ifndef __BINARY_CROSS_ENTROPY_WITH_LOGITS_NVIDIA_CUH__
#define __BINARY_CROSS_ENTROPY_WITH_LOGITS_NVIDIA_CUH__
#include "../binary_cross_entropy_with_logits.h"
/**
 * Instantiates the DESCRIPTOR macro defined in bce_with_logits.h, generating
 * the Descriptor class for NVIDIA devices in op::bce_with_logits::nvidia.
 * * * In the NVIDIA-side implementation (.cu file) the Opaque struct may carry
 * backend-specific state such as cuDNN handles and tensor descriptors,
 * CUDA kernel launch configuration, or CUDA data-type tags (e.g. CUDA_R_32F).
 */
DESCRIPTOR(nvidia)
#endif // __BINARY_CROSS_ENTROPY_WITH_LOGITS_NVIDIA_CUH__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/binary_cross_entropy_with_logits.h"
// 引入各硬件后端的 Descriptor 定义
#ifdef ENABLE_CPU_API
#include "cpu/binary_cross_entropy_with_logits_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/binary_cross_entropy_with_logits_nvidia.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/binary_cross_entropy_with_logits_bang.h"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/binary_cross_entropy_with_logits_ascend.h"
#endif
#ifdef ENABLE_METAX_API
#include "metax/binary_cross_entropy_with_logits_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/binary_cross_entropy_with_logits_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/binary_cross_entropy_with_logits_kunlun.h"
#endif
// -----------------------------------------------------------------------------
// 1. Create descriptor
// -----------------------------------------------------------------------------

/**
 * Creates a BCEWithLogits operator descriptor for the device owned by `handle`,
 * dispatching to the matching backend's Descriptor::create.
 *
 * @param desc_ptr        receives the newly created descriptor on success.
 * @param out_desc        tensor descriptor of the output (loss) tensor.
 * @param logits_desc     tensor descriptor of the raw (pre-sigmoid) logits.
 * @param target_desc     tensor descriptor of the target labels.
 * @param weight_desc     per-element weight tensor descriptor
 *                        (presumably may be null for unweighted BCE — TODO confirm).
 * @param pos_weight_desc positive-class weight tensor descriptor
 *                        (presumably may be null — TODO confirm).
 * @param reduction       reduction mode applied to the elementwise loss.
 * @return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED if handle->device was not
 *         compiled in; otherwise whatever the backend's create() returns.
 */
__INFINI_C infiniStatus_t infiniopCreateBCEWithLogitsDescriptor(
    infiniopHandle_t handle,
    infiniopBCEWithLogitsDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t logits_desc,
    infiniopTensorDescriptor_t target_desc,
    infiniopTensorDescriptor_t weight_desc,
    infiniopTensorDescriptor_t pos_weight_desc,
    infiniopReduction_t reduction) {

// Expands to one `case` per enabled backend, forwarding every argument to that
// backend's Descriptor::create. Undefined again right after the switch.
#define CREATE(CASE, NAMESPACE) \
    case CASE: \
        return op::bce_with_logits::NAMESPACE::Descriptor::create(handle, \
            reinterpret_cast<op::bce_with_logits::NAMESPACE::Descriptor **>(desc_ptr), \
            out_desc, logits_desc, target_desc, weight_desc, pos_weight_desc, reduction)

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
// Iluvatar, QY and Hygon reuse the NVIDIA (CUDA-compatible) implementation.
#ifdef ENABLE_ILUVATAR_API
        CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
        CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
        CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
        CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
        CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef CREATE
}
// -----------------------------------------------------------------------------
// 2. Query workspace size
// -----------------------------------------------------------------------------

/**
 * Writes into *size the scratch-workspace byte count that the backend
 * descriptor reported at creation time (via its workspaceSize() accessor).
 *
 * @param desc descriptor previously created by
 *             infiniopCreateBCEWithLogitsDescriptor.
 * @param size receives the required workspace size in bytes.
 * @return INFINI_STATUS_SUCCESS, or INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
 *         if the descriptor's device type was not compiled in.
 */
__INFINI_C infiniStatus_t infiniopGetBCEWithLogitsWorkspaceSize(
    infiniopBCEWithLogitsDescriptor_t desc,
    size_t *size) {

// Expands to one `case` per enabled backend; reads workspaceSize() off the
// backend-specific Descriptor. Undefined again right after the switch.
#define GET(CASE, NAMESPACE) \
    case CASE: \
        *size = reinterpret_cast<const op::bce_with_logits::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
// Iluvatar, QY and Hygon reuse the NVIDIA (CUDA-compatible) implementation.
#ifdef ENABLE_ILUVATAR_API
        GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
        GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        GET(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
        GET(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
        GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef GET
}
// -----------------------------------------------------------------------------
// 3. Execute
// -----------------------------------------------------------------------------

/**
 * Runs the BCEWithLogits computation by forwarding to the backend descriptor's
 * calculate() method.
 *
 * @param desc           descriptor selecting the backend and holding op state.
 * @param workspace      scratch buffer of at least workspace_size bytes
 *                       (size obtained from infiniopGetBCEWithLogitsWorkspaceSize).
 * @param out            output (loss) buffer.
 * @param logits         raw pre-sigmoid predictions.
 * @param target         target labels.
 * @param weight         per-element weights
 *                       (presumably may be null for unweighted BCE — TODO confirm).
 * @param pos_weight     positive-class weights (presumably may be null — TODO confirm).
 * @param stream         backend execution stream/queue; semantics are
 *                       backend-defined.
 * @return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED if the descriptor's device
 *         type was not compiled in; otherwise the backend's calculate() status.
 */
__INFINI_C infiniStatus_t infiniopBCEWithLogits(
    infiniopBCEWithLogitsDescriptor_t desc,
    void *workspace, size_t workspace_size,
    void *out,
    const void *logits,
    const void *target,
    const void *weight,
    const void *pos_weight,
    void *stream) {

// Expands to one `case` per enabled backend, forwarding all buffers to that
// backend's calculate(). Undefined again right after the switch.
#define CALCULATE(CASE, NAMESPACE) \
    case CASE: \
        return reinterpret_cast<const op::bce_with_logits::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, out, logits, target, weight, pos_weight, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
// Iluvatar, QY and Hygon reuse the NVIDIA (CUDA-compatible) implementation.
#ifdef ENABLE_ILUVATAR_API
        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
        CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
        CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
        CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
        CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef CALCULATE
}
// -----------------------------------------------------------------------------
// 4. Destroy descriptor
// -----------------------------------------------------------------------------

/**
 * Destroys a descriptor previously created by
 * infiniopCreateBCEWithLogitsDescriptor, releasing backend resources through
 * the backend Descriptor's destructor.
 *
 * @param desc descriptor to destroy; its device_type selects the backend.
 * @return INFINI_STATUS_SUCCESS, or INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
 *         if the descriptor's device type was not compiled in.
 */
__INFINI_C infiniStatus_t infiniopDestroyBCEWithLogitsDescriptor(infiniopBCEWithLogitsDescriptor_t desc) {

// Expands to one `case` per enabled backend. The macro body deliberately has
// no trailing semicolon (consistent with CREATE/GET/CALCULATE in this file);
// each DELETE(...); use site supplies it, avoiding a stray empty statement.
#define DELETE(CASE, NAMESPACE) \
    case CASE: \
        delete reinterpret_cast<const op::bce_with_logits::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
// Iluvatar, QY and Hygon reuse the NVIDIA (CUDA-compatible) implementation.
#ifdef ENABLE_ILUVATAR_API
        DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
        DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
        DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
        DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
        DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef DELETE
}
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment