Unverified Commit edc11eb5 authored by gongchensu's avatar gongchensu Committed by GitHub
Browse files

Merge pull request #647 from gongchensu/feature/metax_op_fixed

Fix the Metax add and rms_norm operators.
parents 85bc98ac d93e352b
...@@ -35,10 +35,15 @@ ...@@ -35,10 +35,15 @@
#define HCDNN_DATA_INT64 MCDNN_DATA_INT64 #define HCDNN_DATA_INT64 MCDNN_DATA_INT64
#define HCDNN_DATA_UINT8 MCDNN_DATA_UINT8 #define HCDNN_DATA_UINT8 MCDNN_DATA_UINT8
#define hcEventCreate mcEventCreate #define hcEventCreate mcEventCreate
#define hcEventCreateWithFlags mcEventCreateWithFlags
#define hcEventDefault mcEventDefault
#define hcEventDisableTiming mcEventDisableTiming
#define hcEventBlockingSync mcEventBlockingSync
#define hcEventRecord mcEventRecord #define hcEventRecord mcEventRecord
#define hcEventQuery mcEventQuery #define hcEventQuery mcEventQuery
#define hcEventSynchronize mcEventSynchronize #define hcEventSynchronize mcEventSynchronize
#define hcEventDestroy mcEventDestroy #define hcEventDestroy mcEventDestroy
#define hcEventElapsedTime mcEventElapsedTime
#define hcMalloc mcMalloc #define hcMalloc mcMalloc
#define hpccDataType macaDataType #define hpccDataType macaDataType
#define hcblasComputeType_t mcblasComputeType_t #define hcblasComputeType_t mcblasComputeType_t
......
...@@ -23,7 +23,7 @@ infiniStatus_t Descriptor::create( ...@@ -23,7 +23,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape(); const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape(); const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
...@@ -53,6 +53,10 @@ infiniStatus_t Descriptor::calculate( ...@@ -53,6 +53,10 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<256, cuda::AddOp, float>(_info, workspace, output, inputs, stream); return _device_info->calculate<256, cuda::AddOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
return _device_info->calculate<256, cuda::AddOp, double>(_info, workspace, output, inputs, stream); return _device_info->calculate<256, cuda::AddOp, double>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_I32:
return _device_info->calculate<256, cuda::AddOp, int32_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_I64:
return _device_info->calculate<256, cuda::AddOp, int64_t>(_info, workspace, output, inputs, stream);
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
...@@ -83,6 +83,10 @@ infiniStatus_t launchKernel( ...@@ -83,6 +83,10 @@ infiniStatus_t launchKernel(
LAUNCH_KERNEL(__hpcc_bfloat16, float, float); LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float); LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float); LAUNCH_KERNEL(float, float, float);
} else { } else {
......
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/softplus_nvidia.cuh" #include "nvidia/softplus_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API
#include "metax/softplus_metax.h"
#endif
__C infiniStatus_t infiniopCreateSoftplusDescriptor( __C infiniStatus_t infiniopCreateSoftplusDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -37,6 +40,9 @@ __C infiniStatus_t infiniopCreateSoftplusDescriptor( ...@@ -37,6 +40,9 @@ __C infiniStatus_t infiniopCreateSoftplusDescriptor(
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -65,6 +71,9 @@ __C infiniStatus_t infiniopGetSoftplusWorkspaceSize(infiniopSoftplusDescriptor_t ...@@ -65,6 +71,9 @@ __C infiniStatus_t infiniopGetSoftplusWorkspaceSize(infiniopSoftplusDescriptor_t
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia); GET(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -101,6 +110,9 @@ __C infiniStatus_t infiniopSoftplus( ...@@ -101,6 +110,9 @@ __C infiniStatus_t infiniopSoftplus(
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -131,6 +143,9 @@ infiniopDestroySoftplusDescriptor(infiniopSoftplusDescriptor_t desc) { ...@@ -131,6 +143,9 @@ infiniopDestroySoftplusDescriptor(infiniopSoftplusDescriptor_t desc) {
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia); DELETE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment