Commit 33d0f769 authored by zhuyue's avatar zhuyue
Browse files

Issue/658 - Add Moore platform support for add, mul, and silu operations

- Implement Moore backend for add, mul, and silu elementwise operations
- Filter unsupported dtypes (BF16, F64) for Moore platform in tests
parent dfd1341e
#ifndef __ADD_MOORE_API_H__
#define __ADD_MOORE_API_H__
// Header for the Moore-backend variant of the elementwise add operator.
#include "../../../elementwise/moore/elementwise_moore_api.h"
// Expands the shared elementwise descriptor macro for (add, moore).
// NOTE(review): presumably this declares op::add::moore::Descriptor, whose members
// are defined in the matching Moore source file — confirm against the macro definition.
ELEMENTWISE_DESCRIPTOR(add, moore)
#endif // __ADD_MOORE_API_H__
#include "add_moore.h"
#include "../../../elementwise/moore/elementwise_moore.h"
#include "add_moore_kernel.h"
namespace op::add::moore {

// The descriptor needs no custom teardown; members clean up after themselves.
Descriptor::~Descriptor() = default;

/**
 * Builds an add descriptor for the Moore backend.
 *
 * @param handle_         Opaque handle; reinterpreted as device::moore::Handle.
 * @param desc_ptr        Output location for the new descriptor (used by the
 *                        CREATE_ELEMENTWISE_MOORE_DESCRIPTOR macro — presumably; confirm).
 * @param out_desc        Descriptor of the output tensor; its dtype selects the kernel type.
 * @param input_desc_vec  Descriptors of the two operands (index 0 and 1).
 * @return INFINI_STATUS_SUCCESS when every check and the creation macro pass.
 *         The CHECK_* / CREATE_* project macros may return early with their own status.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    // std::vector::at throws std::out_of_range when fewer than two inputs are given.
    const auto &lhs_desc = input_desc_vec.at(0);
    const auto &rhs_desc = input_desc_vec.at(1);

    const auto &out_shape = out_desc->shape();
    const auto &lhs_shape = lhs_desc->shape();
    const auto &rhs_shape = rhs_desc->shape();

    // Accepted element types; must stay in sync with the dispatch in calculate().
    CHECK_DTYPE(dtype,
                INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64,
                INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);

    // Output and both operands are required to share one shape.
    CHECK_SAME_SHAPE(out_shape, lhs_shape, rhs_shape);

    // Hand over to the shared Moore elementwise descriptor construction.
    CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// One switch case: forward to the elementwise implementation instantiated for TDATA.
// The leading template argument 256 is kept from the original call sites
// (presumably the thread-block size — confirm against the elementwise implementation).
#define ADD_MOORE_DISPATCH(DTYPE_ENUM, TDATA) \
    case DTYPE_ENUM:                          \
        return _device_info->calculate<256, moore::AddOp, TDATA>(_info, workspace, output, inputs, stream)

/**
 * Runs the add operation described by this descriptor.
 *
 * @param workspace       Scratch buffer forwarded to the elementwise implementation.
 * @param workspace_size  Size of `workspace`; must be at least `_workspace_size`.
 * @param output          Destination buffer.
 * @param inputs          Operand buffers, forwarded unchanged.
 * @param stream          Opaque stream pointer, forwarded unchanged.
 * @return INFINI_STATUS_INSUFFICIENT_WORKSPACE if the workspace is too small,
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for a dtype without a case below,
 *         otherwise whatever the elementwise implementation returns.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    const bool workspace_too_small = workspace_size < _workspace_size;
    if (workspace_too_small) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
        ADD_MOORE_DISPATCH(INFINI_DTYPE_F16, half);
        ADD_MOORE_DISPATCH(INFINI_DTYPE_BF16, cuda_bfloat16);
        ADD_MOORE_DISPATCH(INFINI_DTYPE_F32, float);
        ADD_MOORE_DISPATCH(INFINI_DTYPE_F64, double);
        ADD_MOORE_DISPATCH(INFINI_DTYPE_I32, int32_t);
        ADD_MOORE_DISPATCH(INFINI_DTYPE_I64, int64_t);
    default:
        break;
    }

    // Only reached when no case above matched the stored dtype.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

#undef ADD_MOORE_DISPATCH

} // namespace op::add::moore
#ifndef __ADD_MOORE_KERNEL_H__
#define __ADD_MOORE_KERNEL_H__
/*
* This file contains the Add operation implementation for the MUSA backend.
*
* It lives in the 'op::add::moore' namespace and mirrors the code structure
* and interface of the CUDA implementation, keeping the code aligned across
* different hardware platforms.
*/
namespace op::add::moore {

/**
 * Device-side functor that adds two values of the same element type.
 * The implementation is selected at compile time from the element type T.
 */
struct AddOp {
    // Arity of the operator: it consumes two operands per call.
    static constexpr size_t num_inputs = 2;

    /**
     * @param lhs  First operand.
     * @param rhs  Second operand.
     * @return lhs + rhs, computed with a type-specific intrinsic where one is used below.
     */
    template <typename T>
    __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Packed pair of half values: add both lanes at once.
            return __hadd2(lhs, rhs);
        } else if constexpr (std::is_same_v<T, half>) {
            return __hadd(lhs, rhs);
        } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
            // On MUSA, go through float: __hadd returns int here, and converting
            // that int to __mt_bfloat16 would be ambiguous.
            const float lhs_f32 = __bfloat162float(lhs);
            const float rhs_f32 = __bfloat162float(rhs);
            const float sum_f32 = lhs_f32 + rhs_f32;
            return __float2bfloat16_rn(sum_f32);
        } else if constexpr (std::is_same_v<T, float>) {
            // __fadd_rn (not __fadd_rd) is used for Moore platform compatibility.
            return __fadd_rn(lhs, rhs);
        } else {
            // Every remaining type relies on its built-in operator+.
            return lhs + rhs;
        }
    }
};

} // namespace op::add::moore
#endif // __ADD_MOORE_KERNEL_H__
......@@ -17,6 +17,9 @@
#ifdef ENABLE_CAMBRICON_API
#include "bang/add_bang.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/add_moore.h"
#endif
__C infiniStatus_t infiniopCreateAddDescriptor(
infiniopHandle_t handle,
......@@ -57,6 +60,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -93,6 +99,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -139,6 +148,9 @@ __C infiniStatus_t infiniopAdd(
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -178,6 +190,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __MUL_MOORE_API_H__
#define __MUL_MOORE_API_H__
// Header for the Moore-backend variant of the elementwise mul operator.
#include "../../../elementwise/moore/elementwise_moore_api.h"
// Expands the shared elementwise descriptor macro for (mul, moore).
// NOTE(review): presumably this declares op::mul::moore::Descriptor, whose members
// are defined in the matching Moore source file — confirm against the macro definition.
ELEMENTWISE_DESCRIPTOR(mul, moore)
#endif // __MUL_MOORE_API_H__
#include "mul_moore.h"
#include "../../../elementwise/moore/elementwise_moore.h"
#include "mul_moore_kernel.h"
namespace op::mul::moore {

// The descriptor needs no custom teardown; members clean up after themselves.
Descriptor::~Descriptor() = default;

/**
 * Builds a mul descriptor for the Moore backend.
 *
 * @param handle_         Opaque handle; reinterpreted as device::moore::Handle.
 * @param desc_ptr        Output location for the new descriptor (used by the
 *                        CREATE_ELEMENTWISE_MOORE_DESCRIPTOR macro — presumably; confirm).
 * @param out_desc        Descriptor of the output tensor; its dtype selects the kernel type.
 * @param input_desc_vec  Descriptors of the two operands (index 0 and 1).
 * @return INFINI_STATUS_SUCCESS when every check and the creation macro pass.
 *         The CHECK_* / CREATE_* project macros may return early with their own status.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    // std::vector::at throws std::out_of_range when fewer than two inputs are given.
    const auto &lhs_desc = input_desc_vec.at(0);
    const auto &rhs_desc = input_desc_vec.at(1);

    const auto &out_shape = out_desc->shape();
    const auto &lhs_shape = lhs_desc->shape();
    const auto &rhs_shape = rhs_desc->shape();

    // Accepted element types; must stay in sync with the dispatch in calculate().
    CHECK_DTYPE(dtype,
                INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);

    // Output and both operands are required to share one shape.
    CHECK_SAME_SHAPE(out_shape, lhs_shape, rhs_shape);

    // Hand over to the shared Moore elementwise descriptor construction.
    CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// One switch case: forward to the elementwise implementation instantiated for TDATA.
// The leading template argument 256 is kept from the original call sites
// (presumably the thread-block size — confirm against the elementwise implementation).
#define MUL_MOORE_DISPATCH(DTYPE_ENUM, TDATA) \
    case DTYPE_ENUM:                          \
        return _device_info->calculate<256, moore::MulOp, TDATA>(_info, workspace, output, inputs, stream)

/**
 * Runs the mul operation described by this descriptor.
 *
 * @param workspace       Scratch buffer forwarded to the elementwise implementation.
 * @param workspace_size  Size of `workspace`; must be at least `_workspace_size`.
 * @param output          Destination buffer.
 * @param inputs          Operand buffers, forwarded unchanged.
 * @param stream          Opaque stream pointer, forwarded unchanged.
 * @return INFINI_STATUS_INSUFFICIENT_WORKSPACE if the workspace is too small,
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for a dtype without a case below,
 *         otherwise whatever the elementwise implementation returns.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    const bool workspace_too_small = workspace_size < _workspace_size;
    if (workspace_too_small) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
        MUL_MOORE_DISPATCH(INFINI_DTYPE_F16, half);
        MUL_MOORE_DISPATCH(INFINI_DTYPE_BF16, cuda_bfloat16);
        MUL_MOORE_DISPATCH(INFINI_DTYPE_F32, float);
        MUL_MOORE_DISPATCH(INFINI_DTYPE_F64, double);
    default:
        break;
    }

    // Only reached when no case above matched the stored dtype.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

#undef MUL_MOORE_DISPATCH

} // namespace op::mul::moore
#ifndef __MUL_MOORE_KERNEL_H__
#define __MUL_MOORE_KERNEL_H__
/*
* This file contains the Mul operation implementation for the MUSA backend.
*
* It lives in the 'op::mul::moore' namespace and mirrors the code structure
* and interface of the CUDA implementation, keeping the code aligned across
* different hardware platforms.
*/
namespace op::mul::moore {

/**
 * Device-side functor that multiplies two values of the same element type.
 * The implementation is selected at compile time from the element type T.
 */
struct MulOp {
    // Arity of the operator: it consumes two operands per call.
    static constexpr size_t num_inputs = 2;

    /**
     * @param lhs  First factor.
     * @param rhs  Second factor.
     * @return lhs * rhs, computed with a type-specific intrinsic where one is used below.
     */
    template <typename T>
    __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Packed pair of half values: multiply both lanes at once.
            return __hmul2(lhs, rhs);
        } else if constexpr (std::is_same_v<T, half>) {
            return __hmul(lhs, rhs);
        } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
            // On MUSA, widen to float, multiply there, then round back to bfloat16.
            const float lhs_f32 = __bfloat162float(lhs);
            const float rhs_f32 = __bfloat162float(rhs);
            const float product_f32 = lhs_f32 * rhs_f32;
            return __float2bfloat16_rn(product_f32);
        } else if constexpr (std::is_same_v<T, float>) {
            // __fmul_rn is used for Moore platform compatibility.
            return __fmul_rn(lhs, rhs);
        } else {
            // Every remaining type relies on its built-in operator*.
            return lhs * rhs;
        }
    }
};

} // namespace op::mul::moore
#endif // __MUL_MOORE_KERNEL_H__
......@@ -14,6 +14,9 @@
#ifdef ENABLE_KUNLUN_API
#include "kunlun/mul_kunlun.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/mul_moore.h"
#endif
__C infiniStatus_t infiniopCreateMulDescriptor(
infiniopHandle_t handle,
......@@ -51,6 +54,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor(
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -85,6 +91,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -128,6 +137,9 @@ __C infiniStatus_t infiniopMul(
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -164,6 +176,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __SILU_MOORE_API_H__
#define __SILU_MOORE_API_H__
// Header for the Moore-backend variant of the elementwise silu operator.
#include "../../../elementwise/moore/elementwise_moore_api.h"
// Expands the shared elementwise descriptor macro for (silu, moore).
// NOTE(review): presumably this declares op::silu::moore::Descriptor, whose members
// are defined in the matching Moore source file — confirm against the macro definition.
ELEMENTWISE_DESCRIPTOR(silu, moore)
#endif // __SILU_MOORE_API_H__
#include "silu_moore.h"
#include "../../../elementwise/moore/elementwise_moore.h"
#include "silu_moore_kernel.h"
namespace op::silu::moore {

// The descriptor needs no custom teardown; members clean up after themselves.
Descriptor::~Descriptor() = default;

/**
 * Builds a silu descriptor for the Moore backend.
 *
 * @param handle_         Opaque handle; reinterpreted as device::moore::Handle.
 * @param desc_ptr        Output location for the new descriptor (used by the
 *                        CREATE_ELEMENTWISE_MOORE_DESCRIPTOR macro — presumably; confirm).
 * @param out_desc        Descriptor of the output tensor; its dtype selects the kernel type.
 * @param input_desc_vec  Descriptor list; only index 0 (the single input) is inspected here.
 * @return INFINI_STATUS_SUCCESS when every check and the creation macro pass.
 *         The CHECK_* / CREATE_* project macros may return early with their own status.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    // std::vector::at throws std::out_of_range when the input list is empty.
    const auto &x_desc = input_desc_vec.at(0);

    const auto &y_shape = out_desc->shape();
    const auto &x_shape = x_desc->shape();

    // Accepted element types; must stay in sync with the dispatch in calculate().
    CHECK_DTYPE(dtype,
                INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // Input and output are required to share one shape.
    CHECK_SAME_SHAPE(y_shape, x_shape);

    // Hand over to the shared Moore elementwise descriptor construction.
    CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// One switch case: forward to the elementwise implementation instantiated for TDATA.
// The leading template argument 256 is kept from the original call sites
// (presumably the thread-block size — confirm against the elementwise implementation).
#define SILU_MOORE_DISPATCH(DTYPE_ENUM, TDATA) \
    case DTYPE_ENUM:                           \
        return _device_info->calculate<256, moore::SiluOp, TDATA>(_info, workspace, output, inputs, stream)

/**
 * Runs the silu operation described by this descriptor.
 *
 * @param workspace       Scratch buffer forwarded to the elementwise implementation.
 * @param workspace_size  Size of `workspace`; must be at least `_workspace_size`.
 * @param output          Destination buffer.
 * @param inputs          Input buffers, forwarded unchanged.
 * @param stream          Opaque stream pointer, forwarded unchanged.
 * @return INFINI_STATUS_INSUFFICIENT_WORKSPACE if the workspace is too small,
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for a dtype without a case below,
 *         otherwise whatever the elementwise implementation returns.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    const bool workspace_too_small = workspace_size < _workspace_size;
    if (workspace_too_small) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
        SILU_MOORE_DISPATCH(INFINI_DTYPE_BF16, cuda_bfloat16);
        SILU_MOORE_DISPATCH(INFINI_DTYPE_F16, half);
        SILU_MOORE_DISPATCH(INFINI_DTYPE_F32, float);
        SILU_MOORE_DISPATCH(INFINI_DTYPE_F64, double);
    default:
        break;
    }

    // Only reached when no case above matched the stored dtype.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

#undef SILU_MOORE_DISPATCH

} // namespace op::silu::moore
#ifndef __SILU_MOORE_KERNEL_H__
#define __SILU_MOORE_KERNEL_H__
#include <cmath>
namespace op::silu::moore {

/**
 * Device-side functor computing SiLU, silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
 *
 * The implementation is selected at compile time from the element type T.
 * Supported types: half2, half, cuda_bfloat16, float and double. Any other type
 * is rejected at compile time instead of falling off the end of the function
 * without returning a value.
 */
typedef struct SiluOp {
public:
    // Arity of the operator: it consumes one operand per call.
    static constexpr size_t num_inputs = 1;

    /**
     * @param x  Input value.
     * @return silu(x) in the same type as the input.
     */
    template <typename T>
    __device__ __forceinline__ T operator()(const T &x) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Packed pair of half values: x * (1 / (1 + exp(-x))) on both lanes.
            return __hmul2(x, __h2div(__float2half2_rn(1.0f),
                                      __hadd2(__float2half2_rn(1.0f), h2exp(__hneg2(x)))));
        } else if constexpr (std::is_same_v<T, half>) {
            // FP16: evaluate in float and convert back, for MUSA platform compatibility.
            float x_f = __half2float(x);
            float sigmoid_f = 1.0f / (1.0f + __expf(-x_f));
            return __float2half(x_f * sigmoid_f);
        } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
            // BF16: evaluate in float and round back to bfloat16.
            float x_f = __bfloat162float(x);
            float sigmoid_f = 1.0f / (1.0f + __expf(-x_f));
            return __float2bfloat16_rn(x_f * sigmoid_f);
        } else if constexpr (std::is_same_v<T, float>) {
            // FP32: __frcp_rn and __expf are used for Moore platform compatibility.
            return __fmul_rn(x, __frcp_rn(__fadd_rn(1.0f, __expf(-x))));
        } else if constexpr (std::is_same_v<T, double>) {
            // FP64: plain double-precision arithmetic.
            return x / (1.0 + exp(-x));
        } else {
            // The condition depends on T, so it is only evaluated when this branch
            // is instantiated, i.e. for an element type that none of the branches
            // above handles.
            static_assert(!std::is_same_v<T, T>,
                          "SiluOp: unsupported element type (expected half2, half, cuda_bfloat16, float or double)");
        }
    }
} SiluOp;

} // namespace op::silu::moore
#endif // __SILU_MOORE_KERNEL_H__
......@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API
#include "metax/silu_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/silu_moore.h"
#endif
__C infiniStatus_t infiniopCreateSiluDescriptor(
infiniopHandle_t handle,
......@@ -40,6 +43,9 @@ __C infiniStatus_t infiniopCreateSiluDescriptor(
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -67,6 +73,9 @@ __C infiniStatus_t infiniopGetSiluWorkspaceSize(infiniopSiluDescriptor_t desc, s
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -103,6 +112,9 @@ __C infiniStatus_t infiniopSilu(
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -133,6 +145,9 @@ infiniopDestroySiluDescriptor(infiniopSiluDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
......@@ -465,6 +465,9 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
def filter_tensor_dtypes_by_device(device, tensor_dtypes):
    """Select the entries of ``tensor_dtypes`` that should be used on ``device``.

    * CPU and NVIDIA: the input list is returned as-is (same object, no copy).
    * MOORE: a new list without ``InfiniDtype.BF16`` and ``InfiniDtype.F64``.
    * Any other device: a new list without ``torch.bfloat16``.
    """
    unrestricted_devices = (InfiniDeviceEnum.CPU, InfiniDeviceEnum.NVIDIA)
    if device in unrestricted_devices:
        return tensor_dtypes

    if device == InfiniDeviceEnum.MOORE:
        # Drop BF16 and F64: PyTorch on the Moore platform does not support
        # some operations for these types.
        kept = []
        for dt in tensor_dtypes:
            if dt != InfiniDtype.BF16 and dt != InfiniDtype.F64:
                kept.append(dt)
        return kept

    # Drop torch.bfloat16.
    # NOTE(review): this branch compares against torch.bfloat16 while the MOORE
    # branch compares against InfiniDtype.BF16 — confirm which dtype
    # representation callers actually pass in.
    return [dt for dt in tensor_dtypes if dt != torch.bfloat16]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment