Commit d93e352b authored by zhuyue's avatar zhuyue
Browse files

Issue/645 - Fix metax add rms_norm operators.

parent 85bc98ac
......@@ -35,10 +35,15 @@
// Alias HCDNN_*/hc*/hpcc* identifiers to the corresponding MCDNN_*/mc*/maca*
// identifiers, so code written against the hc* names resolves to the mc* ones.
// NOTE(review): presumably the mc* symbols are declared by the MetaX runtime
// headers included elsewhere in this file — confirm.

// DNN data-type enumerators.
#define HCDNN_DATA_INT64 MCDNN_DATA_INT64
#define HCDNN_DATA_UINT8 MCDNN_DATA_UINT8
// Event API: creation (with and without flags), flag constants, record,
// query, synchronize, destroy, and elapsed-time measurement.
#define hcEventCreate mcEventCreate
#define hcEventCreateWithFlags mcEventCreateWithFlags
#define hcEventDefault mcEventDefault
#define hcEventDisableTiming mcEventDisableTiming
#define hcEventBlockingSync mcEventBlockingSync
#define hcEventRecord mcEventRecord
#define hcEventQuery mcEventQuery
#define hcEventSynchronize mcEventSynchronize
#define hcEventDestroy mcEventDestroy
#define hcEventElapsedTime mcEventElapsedTime
// Memory allocation and type-name aliases.
#define hcMalloc mcMalloc
#define hpccDataType macaDataType
#define hcblasComputeType_t mcblasComputeType_t
......
......@@ -23,7 +23,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
......@@ -53,6 +53,10 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<256, cuda::AddOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, cuda::AddOp, double>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_I32:
return _device_info->calculate<256, cuda::AddOp, int32_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_I64:
return _device_info->calculate<256, cuda::AddOp, int64_t>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -83,6 +83,10 @@ infiniStatus_t launchKernel(
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
......
......@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/softplus_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/softplus_metax.h"
#endif
__C infiniStatus_t infiniopCreateSoftplusDescriptor(
infiniopHandle_t handle,
......@@ -37,6 +40,9 @@ __C infiniStatus_t infiniopCreateSoftplusDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -65,6 +71,9 @@ __C infiniStatus_t infiniopGetSoftplusWorkspaceSize(infiniopSoftplusDescriptor_t
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -101,6 +110,9 @@ __C infiniStatus_t infiniopSoftplus(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -131,6 +143,9 @@ infiniopDestroySoftplusDescriptor(infiniopSoftplusDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment