Unverified Commit b9dd0004 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #438 from InfiniTensor/issue/434-metax

issue/434 hccl support bf16
parents f9d16628 3bb0c930
......@@ -11,6 +11,7 @@
#define TEST_INFINI_THREAD(API__) CHECK_API_OR(API__, INFINI_STATUS_SUCCESS, return nullptr)
const size_t MAX_COUNT = 8ULL * 1024 * 1024;
// const size_t MAX_COUNT = 512 * 1024; // for metax
const size_t TEST_COUNTS[] = {
128,
......@@ -19,7 +20,7 @@ const size_t TEST_COUNTS[] = {
MAX_COUNT,
};
const infiniDtype_t TEST_DTYPES[] = {INFINI_DTYPE_F32, INFINI_DTYPE_F16};
const infiniDtype_t TEST_DTYPES[] = {INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16};
const size_t WARM_UPS = 10;
......@@ -51,6 +52,11 @@ void setData(infiniDtype_t dtype, void *data, size_t count, float val) {
((fp16_t *)data)[i] = utils::cast<fp16_t>(val);
}
break;
case INFINI_DTYPE_BF16:
for (size_t i = 0; i < count; i++) {
((bf16_t *)data)[i] = utils::cast<bf16_t>(val);
}
break;
default:
std::abort();
break;
......@@ -67,6 +73,12 @@ int checkData(const T *actual_, const T *expected_, size_t count) {
if (std::abs(actual - expected) > 1e-4) {
failed += 1;
}
} else if constexpr (std::is_same<T, bf16_t>::value) {
float actual = utils::cast<float>(actual_[i]);
float expected = utils::cast<float>(expected_[i]);
if (std::abs(actual - expected) > 1e-4) {
failed += 1;
}
} else {
if (std::abs(actual_[i] - expected_[i]) > 1e-4) {
failed += 1;
......@@ -82,6 +94,8 @@ int checkData(const void *actual, const void *expected, infiniDtype_t dtype, siz
return checkData((const float *)actual, (const float *)expected, count);
case INFINI_DTYPE_F16:
return checkData((const fp16_t *)actual, (const fp16_t *)expected, count);
case INFINI_DTYPE_BF16:
return checkData((const bf16_t *)actual, (const bf16_t *)expected, count);
default:
std::abort();
return 1;
......
......@@ -23,6 +23,8 @@ inline hcclDataType_t getHcclDtype(infiniDtype_t datatype) {
return hcclFloat;
case INFINI_DTYPE_F16:
return hcclHalf;
case INFINI_DTYPE_BF16:
return hcclBfloat16;
default:
std::abort();
return hcclHalf;
......@@ -83,9 +85,7 @@ infiniStatus_t allReduce(
infinicclComm_t comm,
infinirtStream_t stream) {
if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) {
return INFINI_STATUS_BAD_PARAM;
}
CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
CHECK_HCCL(hcclAllReduce(sendbuf, recvbuf, count, getHcclDtype(datatype),
getHcclRedOp(op), getHcclComm(comm), getMacaStream(stream)));
......
#ifndef __SOFTPLUS_METAX_API_H__
#define __SOFTPLUS_METAX_API_H__
// Public API header for the softplus operator on the METAX backend.
// Declares op::softplus::metax::Descriptor via the shared elementwise
// descriptor macro; the implementation lives in softplus_metax.cc.
#include "../../../elementwise/metax/elementwise_metax_api.h"
ELEMENTWISE_DESCRIPTOR(softplus, metax)
#endif // __SOFTPLUS_METAX_API_H__
#include "softplus_metax.h"
#include "../../../elementwise/metax/elementwise_metax.h"
#include "../cuda/kernel.cuh"
namespace op::softplus::metax {

Descriptor::~Descriptor() = default;

/// @brief Build a softplus descriptor for the METAX backend.
///
/// Validates the output dtype (F16 / F32 / F64 / BF16) and requires the
/// single input tensor to have the same shape as the output, then delegates
/// construction to the shared METAX elementwise descriptor factory.
///
/// @param handle_        Device handle (reinterpreted as device::metax::Handle).
/// @param desc_ptr       Receives the newly created descriptor.
/// @param out_desc       Output tensor descriptor; its dtype drives dispatch.
/// @param input_desc_vec Exactly one input tensor descriptor is expected
///                       (`.at(0)` throws if the vector is empty).
/// @return INFINI_STATUS_SUCCESS on success; otherwise the status produced
///         by the CHECK_* / CREATE_* macros.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &x_desc = input_desc_vec.at(0);
    const auto &y_shape = out_desc->shape();
    const auto &x_shape = x_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(y_shape, x_shape);

    // create METAX elementwise descriptor
    CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

/// @brief Run softplus by dispatching to the typed elementwise kernel.
///
/// @param workspace      Scratch buffer supplied by the caller.
/// @param workspace_size Size of @p workspace in bytes; must be at least the
///                       size negotiated at create time (_workspace_size).
/// @param output         Output buffer.
/// @param inputs         Input buffers (one element for softplus).
/// @param stream         Backend stream the kernel is launched on.
/// @return INFINI_STATUS_INSUFFICIENT_WORKSPACE if the workspace is too
///         small, INFINI_STATUS_BAD_TENSOR_DTYPE for an unsupported dtype,
///         otherwise the status of the underlying kernel launch.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Every case (including default) returns, so the switch is exhaustive;
    // the original trailing `return INFINI_STATUS_SUCCESS;` was unreachable
    // dead code and has been removed.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::SoftplusOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::SoftplusOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::SoftplusOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::SoftplusOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace op::softplus::metax
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment