Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
b9dd0004
Unverified
Commit
b9dd0004
authored
Sep 16, 2025
by
PanZezhong1725
Committed by
GitHub
Sep 16, 2025
Browse files
Merge pull request #438 from InfiniTensor/issue/434-metax
issue/434 hccl support bf16
parents
f9d16628
3bb0c930
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
86 additions
and
4 deletions
+86
-4
src/infiniccl-test/infiniccl_test.cpp
src/infiniccl-test/infiniccl_test.cpp
+15
-1
src/infiniccl/metax/infiniccl_metax.cc
src/infiniccl/metax/infiniccl_metax.cc
+3
-3
src/infiniop/ops/softplus/metax/softplus_metax.h
src/infiniop/ops/softplus/metax/softplus_metax.h
+8
-0
src/infiniop/ops/softplus/metax/softplus_metax.maca
src/infiniop/ops/softplus/metax/softplus_metax.maca
+60
-0
No files found.
src/infiniccl-test/infiniccl_test.cpp
View file @
b9dd0004
...
...
@@ -11,6 +11,7 @@
#define TEST_INFINI_THREAD(API__) CHECK_API_OR(API__, INFINI_STATUS_SUCCESS, return nullptr)
const
size_t
MAX_COUNT
=
8ULL
*
1024
*
1024
;
// const size_t MAX_COUNT = 512 * 1024; // for metax
const
size_t
TEST_COUNTS
[]
=
{
128
,
...
...
@@ -19,7 +20,7 @@ const size_t TEST_COUNTS[] = {
MAX_COUNT
,
};
const
infiniDtype_t
TEST_DTYPES
[]
=
{
INFINI_DTYPE_F32
,
INFINI_DTYPE_F16
};
const
infiniDtype_t
TEST_DTYPES
[]
=
{
INFINI_DTYPE_F32
,
INFINI_DTYPE_F16
,
INFINI_DTYPE_BF16
};
const
size_t
WARM_UPS
=
10
;
...
...
@@ -51,6 +52,11 @@ void setData(infiniDtype_t dtype, void *data, size_t count, float val) {
((
fp16_t
*
)
data
)[
i
]
=
utils
::
cast
<
fp16_t
>
(
val
);
}
break
;
case
INFINI_DTYPE_BF16
:
for
(
size_t
i
=
0
;
i
<
count
;
i
++
)
{
((
bf16_t
*
)
data
)[
i
]
=
utils
::
cast
<
bf16_t
>
(
val
);
}
break
;
default:
std
::
abort
();
break
;
...
...
@@ -67,6 +73,12 @@ int checkData(const T *actual_, const T *expected_, size_t count) {
if
(
std
::
abs
(
actual
-
expected
)
>
1e-4
)
{
failed
+=
1
;
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
bf16_t
>::
value
)
{
float
actual
=
utils
::
cast
<
float
>
(
actual_
[
i
]);
float
expected
=
utils
::
cast
<
float
>
(
expected_
[
i
]);
if
(
std
::
abs
(
actual
-
expected
)
>
1e-4
)
{
failed
+=
1
;
}
}
else
{
if
(
std
::
abs
(
actual_
[
i
]
-
expected_
[
i
])
>
1e-4
)
{
failed
+=
1
;
...
...
@@ -82,6 +94,8 @@ int checkData(const void *actual, const void *expected, infiniDtype_t dtype, siz
return
checkData
((
const
float
*
)
actual
,
(
const
float
*
)
expected
,
count
);
case
INFINI_DTYPE_F16
:
return
checkData
((
const
fp16_t
*
)
actual
,
(
const
fp16_t
*
)
expected
,
count
);
case
INFINI_DTYPE_BF16
:
return
checkData
((
const
bf16_t
*
)
actual
,
(
const
bf16_t
*
)
expected
,
count
);
default:
std
::
abort
();
return
1
;
...
...
src/infiniccl/metax/infiniccl_metax.cc
View file @
b9dd0004
...
...
@@ -23,6 +23,8 @@ inline hcclDataType_t getHcclDtype(infiniDtype_t datatype) {
return
hcclFloat
;
case
INFINI_DTYPE_F16
:
return
hcclHalf
;
case
INFINI_DTYPE_BF16
:
return
hcclBfloat16
;
default:
std
::
abort
();
return
hcclHalf
;
...
...
@@ -83,9 +85,7 @@ infiniStatus_t allReduce(
infinicclComm_t
comm
,
infinirtStream_t
stream
)
{
if
(
datatype
!=
INFINI_DTYPE_F32
&&
datatype
!=
INFINI_DTYPE_F16
)
{
return
INFINI_STATUS_BAD_PARAM
;
}
CHECK_DTYPE
(
datatype
,
INFINI_DTYPE_F32
,
INFINI_DTYPE_F16
,
INFINI_DTYPE_BF16
);
CHECK_HCCL
(
hcclAllReduce
(
sendbuf
,
recvbuf
,
count
,
getHcclDtype
(
datatype
),
getHcclRedOp
(
op
),
getHcclComm
(
comm
),
getMacaStream
(
stream
)));
...
...
src/infiniop/ops/softplus/metax/softplus_metax.h
0 → 100644
View file @
b9dd0004
#ifndef __SOFTPLUS_METAX_API_H__
#define __SOFTPLUS_METAX_API_H__
#include "../../../elementwise/metax/elementwise_metax_api.h"
// NOTE(review): ELEMENTWISE_DESCRIPTOR is provided by elementwise_metax_api.h (included above).
// It presumably declares op::softplus::metax::Descriptor, whose member functions are
// defined in softplus_metax.maca — confirm against the macro definition.
ELEMENTWISE_DESCRIPTOR
(
softplus
,
metax
)
#endif // __SOFTPLUS_METAX_API_H__
src/infiniop/ops/softplus/metax/softplus_metax.maca
0 → 100644
View file @
b9dd0004
#include "softplus_metax.h"
#include "../../../elementwise/metax/elementwise_metax.h"
#include "../cuda/kernel.cuh"
namespace op::softplus::metax {
// Defaulted destructor: no hand-written cleanup is performed for the softplus descriptor.
Descriptor::~Descriptor() = default;
// Creates a softplus descriptor for the METAX backend.
//
// Parameters:
//   handle_        - opaque handle, reinterpreted as device::metax::Handle *.
//   desc_ptr       - out-parameter; no statement visible here writes to it, so it is
//                    presumably assigned inside CREATE_ELEMENTWISE_METAX_DESCRIPTOR — confirm.
//   out_desc       - output tensor descriptor; supplies the dtype and the output shape.
//   input_desc_vec - input tensor descriptors; only element 0 is read directly in this function.
//
// Returns INFINI_STATUS_SUCCESS when control reaches the end of the function. The CHECK_* and
// CREATE_* macros presumably return an error status early on failure — confirm against their
// definitions.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
// The dtype is taken from the output descriptor only; the input descriptor's dtype is not
// inspected by any statement in this function.
auto dtype = out_desc->dtype();
// std::vector::at throws std::out_of_range if input_desc_vec is empty.
const auto &x_desc = input_desc_vec.at(0);
const auto &y_shape = out_desc->shape();
const auto &x_shape = x_desc->shape();
// Accepted dtypes: F16, F32, F64, BF16 (the same set dispatched in Descriptor::calculate).
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
// Presumably requires the output shape to equal the input shape — confirm against the macro.
CHECK_SAME_SHAPE(y_shape, x_shape);
// create METAX elementwise descriptor
// NOTE(review): the macro has no trailing semicolon here, so its expansion must supply its own
// statement terminators; it may also declare locals in this scope — check before renaming anything.
CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
// Runs the softplus elementwise computation for the dtype stored in this descriptor.
//
// Parameters:
//   workspace      - scratch buffer forwarded unchanged to the elementwise implementation.
//   workspace_size - size of `workspace`; must be at least _workspace_size.
//   output         - destination buffer, forwarded unchanged.
//   inputs         - input buffers, forwarded unchanged.
//   stream         - opaque stream handle, forwarded unchanged.
//
// Returns:
//   INFINI_STATUS_INSUFFICIENT_WORKSPACE if workspace_size is smaller than _workspace_size,
//   INFINI_STATUS_BAD_TENSOR_DTYPE if _dtype is not one of F16 / BF16 / F32 / F64,
//   otherwise whatever _device_info->calculate<...>() returns for the matching element type.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    // Reject undersized workspaces before touching the device.
    const bool workspace_ok = !(workspace_size < _workspace_size);
    if (!workspace_ok) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Dispatch on the stored dtype; each branch instantiates the elementwise kernel launcher
    // with cuda::SoftplusOp and the corresponding device element type.
    if (_dtype == INFINI_DTYPE_F16) {
        return _device_info->calculate<256, cuda::SoftplusOp, half>(_info, workspace, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_BF16) {
        return _device_info->calculate<256, cuda::SoftplusOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_F32) {
        return _device_info->calculate<256, cuda::SoftplusOp, float>(_info, workspace, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_F64) {
        return _device_info->calculate<256, cuda::SoftplusOp, double>(_info, workspace, output, inputs, stream);
    }

    // Any other dtype is unsupported.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::softplus::metax
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment