jerrrrry / infinicore · Commit 82e43d35 (unverified)

Merge pull request #364 from InfiniTensor/issue/240

Issue/240 - Cambricon Reduce

Authored Aug 20, 2025 by PanZezhong1725; committed by GitHub on Aug 20, 2025.
Parents: d4b03cf7, 00c30a4a

Showing 10 changed files with 735 additions and 62 deletions (+735 −62):
src/infiniop-test/src/main.cpp (+1 −1)
src/infiniop/ops/causal_softmax/bang/causal_softmax_bang.h (+7 −0)
src/infiniop/ops/causal_softmax/bang/causal_softmax_bang.mlu (+200 −0)
src/infiniop/ops/causal_softmax/operator.cc (+14 −43)
src/infiniop/ops/rms_norm/bang/rms_norm_bang.h (+8 −0)
src/infiniop/ops/rms_norm/bang/rms_norm_bang.mlu (+211 −0)
src/infiniop/ops/rms_norm/operator.cc (+11 −16)
src/infiniop/reduce/bang/reduce_bang.h (+281 −0)
test/infiniop/causal_softmax.py (+1 −1)
test/infiniop/rms_norm.py (+1 −1)
src/infiniop-test/src/main.cpp

@@ -9,7 +9,7 @@ struct ParsedArgs {
     int device_id = 0;    // CUDA device ID (if specified)
     int warmups = 0;      // Default to 0 if not given
     int iterations = 0;   // Default to 0 if not given
-    double atol = 0.001;  // Default absolute tolerance
+    double atol = 0.0015; // Default absolute tolerance
     double rtol = 0.001;  // Default relative tolerance
 };
src/infiniop/ops/causal_softmax/bang/causal_softmax_bang.h (new file, mode 100644)

#ifndef __CAUSAL_SOFTMAX_BANG_H__
#define __CAUSAL_SOFTMAX_BANG_H__

#include "../causal_softmax.h"

DESCRIPTOR(bang)

#endif // __CAUSAL_SOFTMAX_BANG_H__
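For orientation: the DESCRIPTOR(bang) macro is defined in ../causal_softmax.h, which is not part of this commit. Judging from the members referenced in causal_softmax_bang.mlu below, it presumably stamps out a declaration roughly like the following sketch; the exact class shape and signatures are assumptions, not part of the patch.

// Hypothetical expansion of DESCRIPTOR(bang), inferred from the members
// used in causal_softmax_bang.mlu; the real macro may differ.
namespace op::causal_softmax::bang {
class Descriptor {
    struct Opaque;           // device-specific state, defined in the .mlu file
    Opaque *_opaque;
    CausalSoftmaxInfo _info; // shape/stride/dtype metadata
    size_t _workspace_size;

public:
    ~Descriptor();
    static infiniStatus_t create(infiniopHandle_t handle,
                                 Descriptor **desc_ptr,
                                 infiniopTensorDescriptor_t y_desc,
                                 infiniopTensorDescriptor_t x_desc);
    infiniStatus_t calculate(void *workspace, size_t workspace_size,
                             void *y, const void *x, void *stream) const;
};
} // namespace op::causal_softmax::bang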
src/infiniop/ops/causal_softmax/bang/causal_softmax_bang.mlu (new file, mode 100644)

#include "../../../devices/bang/common_bang.h"
#include "../../../reduce/bang/reduce_bang.h"
#include "causal_softmax_bang.h"

__nram__ char nram_buffer[NRAM_MAX_SIZE];
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;

template <typename T>
__mlu_func__ void processSoftmaxStep(T *output, const T *input, float scalar, int num_elements, int stride, bool is_exp_phase) {
    // Calculate buffer sizes (split between float and T buffers)
    constexpr bool is_half = std::is_same_v<T, half>;
    constexpr bool is_bfloat16 = std::is_same_v<T, bfloat16_t>;
    constexpr bool is_float = !is_half && !is_bfloat16;
    const int chunk_size = SRC_MAX_SIZE / ((is_half || is_bfloat16) ? (2 * sizeof(float)) : sizeof(float));
    float *float_buffer = (float *)nram_buffer;
    T *temp_buffer = is_float ? nullptr : (T *)(nram_buffer + chunk_size * sizeof(float));
    // Common stride configurations
    const int src_stride = stride * sizeof(T);
    const int dst_stride = stride * sizeof(T);
    int processed = 0;
    while (processed < num_elements) {
        int curr_batch = std::min(chunk_size, num_elements - processed);
        // Gather input elements using 2D memcpy
        if constexpr (is_float) {
            __memcpy(float_buffer, (is_exp_phase ? input : output) + processed * stride, sizeof(float),
                     GDRAM2NRAM, sizeof(float), src_stride, curr_batch - 1);
        } else {
            __memcpy(temp_buffer, (is_exp_phase ? input : output) + processed * stride, sizeof(T),
                     GDRAM2NRAM, sizeof(T), src_stride, curr_batch - 1);
            // Convert to float
            if constexpr (is_half) {
                __bang_half2float(float_buffer, temp_buffer, curr_batch);
            } else if constexpr (is_bfloat16) {
                __bang_bfloat162float(float_buffer, temp_buffer, curr_batch);
            }
        }
        // Common processing for all types
        if (is_exp_phase) {
            __bang_sub_scalar(float_buffer, float_buffer, scalar, curr_batch); // scalar is max_val
            __bang_active_exphp(float_buffer, float_buffer, curr_batch);
        } else {
            __bang_mul_scalar(float_buffer, float_buffer, scalar, curr_batch); // scalar is 1.0f/sum_val
        }
        // Convert back and scatter output using 2D memcpy
        if constexpr (is_float) {
            __memcpy(output + processed * stride, float_buffer, sizeof(float),
                     NRAM2GDRAM, dst_stride, sizeof(float), curr_batch - 1);
        } else {
            // Convert back to original type
            if constexpr (is_half) {
                __bang_float2half(temp_buffer, float_buffer, curr_batch);
            } else if constexpr (is_bfloat16) {
                __bang_float2bfloat16(temp_buffer, float_buffer, curr_batch);
            }
            // Scatter output
            __memcpy(output + processed * stride, temp_buffer, sizeof(T),
                     NRAM2GDRAM, dst_stride, sizeof(T), curr_batch - 1);
        }
        processed += curr_batch;
    }
}

template <typename T>
__mlu_global__ void causalSoftmax(T *y, const T *x,
                                  size_t batch_size, size_t seq_len, size_t total_seq_len,
                                  ptrdiff_t y_stride_b, ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
                                  ptrdiff_t x_stride_b, ptrdiff_t x_stride_i, ptrdiff_t x_stride_j) {
    using namespace op::common_bang::reduce_op;
    // Get task information
    size_t task_id = taskId;
    size_t task_num = taskDimX * taskDimY;
    // Calculate elements per task with better load balancing
    size_t total_tasks = batch_size * seq_len;
    size_t tasks_per_core = (total_tasks + task_num - 1) / task_num;
    size_t start = task_id * tasks_per_core;
    size_t end = std::min(start + tasks_per_core, total_tasks);
    // Allocate NRAM buffers
    const int max_batch = SRC_MAX_SIZE / sizeof(T);
    T *src = (T *)nram_buffer;
    float *dst = (float *)(nram_buffer + max_batch * sizeof(T));
    for (size_t index = start; index < end; index++) {
        size_t batch = index / seq_len;
        size_t i = (index % seq_len);
        ptrdiff_t y_offset = batch * y_stride_b + i * y_stride_i;
        ptrdiff_t x_offset = batch * x_stride_b + i * x_stride_i;
        T *y_ = y + y_offset;
        const T *x_ = x + x_offset;
        // Calculate the valid sequence length for this position
        size_t valid_len = total_seq_len - seq_len + i + 1;
        // Zero out future positions
        for (size_t j = valid_len; j < total_seq_len; j++) {
            y_[j * y_stride_j] = (T)0.0f;
        }
        // Calculate max value using optimized reduction
        float max_val = maxBatched(x_, src, dst, valid_len, max_batch);
        // Compute exp(x - max)
        processSoftmaxStep(y_, x_, max_val, valid_len, x_stride_j, true);
        // Calculate sum of exponentials
        float sum_val = sumBatched(y_, src, dst, valid_len, max_batch);
        // Normalize by sum
        processSoftmaxStep(y_, y_, 1.0f / sum_val, valid_len, y_stride_j, false);
    }
}

template <typename T>
void causalSoftmaxUnion(void *workspace, int core_per_cluster, int cluster_count,
                        cnrtQueue_t queue, void *y, const void *x, const op::causal_softmax::CausalSoftmaxInfo *info) {
    cnrtDim3_t kernel_dim;
    cnrtFunctionType_t kernel_type;
    // Configure kernel dimensions
    kernel_dim.x = core_per_cluster;
    kernel_dim.y = cluster_count;
    kernel_dim.z = 1;
    kernel_type = CNRT_FUNC_TYPE_UNION1;
    // Launch kernel
    causalSoftmax<T><<<kernel_dim, kernel_type, queue>>>(
        (T *)y, (const T *)x,
        info->batch_size, info->seq_len, info->total_seq_len,
        info->y_stride_b, info->y_stride_i, info->y_stride_j,
        info->x_stride_b, info->x_stride_i, info->x_stride_j);
    cnrtQueueSync(queue);
}

namespace op::causal_softmax::bang {

struct Descriptor::Opaque {
    std::shared_ptr<device::bang::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);
    auto result = CausalSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(result);
    auto info = result.take();
    *desc_ptr = new Descriptor(
        new Descriptor::Opaque{static_cast<device::bang::Handle *>(handle)->internal()},
        info,
        0,
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    const void *x,
    void *stream) const {
    auto queue = reinterpret_cast<cnrtQueue_t>(stream);
    int core_per_cluster = _opaque->internal->getCorePerCluster();
    int cluster_count = _opaque->internal->getClusterCount();
    // Dispatch based on data type
    if (_info.dtype == INFINI_DTYPE_F16) {
        causalSoftmaxUnion<half>(workspace, core_per_cluster, cluster_count, queue, y, x, &_info);
    } else if (_info.dtype == INFINI_DTYPE_BF16) {
        causalSoftmaxUnion<bfloat16_t>(workspace, core_per_cluster, cluster_count, queue, y, x, &_info);
    } else if (_info.dtype == INFINI_DTYPE_F32) {
        causalSoftmaxUnion<float>(workspace, core_per_cluster, cluster_count, queue, y, x, &_info);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::causal_softmax::bang
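For reference, the per-row logic of the kernel above is the usual numerically stable softmax with a causal cutoff: mask future positions to zero, subtract the row max before exponentiating, then divide by the sum. A minimal scalar C++ sketch of one (batch, i) row, assuming contiguous storage; the row_softmax helper below is illustrative only and not part of the patch.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Scalar reference for one row of causal softmax: positions past
// valid_len are masked to zero, the rest get exp(x - max) / sum.
void row_softmax(float *y, const float *x,
                 size_t i, size_t seq_len, size_t total_seq_len) {
    size_t valid_len = total_seq_len - seq_len + i + 1; // causal cutoff
    for (size_t j = valid_len; j < total_seq_len; ++j) {
        y[j] = 0.0f;                                    // future positions
    }
    float max_val = -INFINITY;
    for (size_t j = 0; j < valid_len; ++j) {
        max_val = std::max(max_val, x[j]);              // pass 1: row max
    }
    float sum = 0.0f;
    for (size_t j = 0; j < valid_len; ++j) {
        y[j] = std::exp(x[j] - max_val);                // pass 2: exp(x - max)
        sum += y[j];
    }
    for (size_t j = 0; j < valid_len; ++j) {
        y[j] /= sum;                                    // pass 3: normalize
    }
}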
src/infiniop/ops/causal_softmax/operator.cc

@@ -14,6 +14,9 @@
 #ifdef ENABLE_ASCEND_API
 #include "ascend/causal_softmax_ascend.h"
 #endif
+#ifdef ENABLE_CAMBRICON_API
+#include "bang/causal_softmax_bang.h"
+#endif

 __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
     infiniopHandle_t handle,

@@ -39,22 +42,14 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
 #ifdef ENABLE_ILUVATAR_API
     CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_CAMBRICON_API
+    CREATE(INFINI_DEVICE_CAMBRICON, bang)
+#endif
 #ifdef ENABLE_METAX_API
     CREATE(INFINI_DEVICE_METAX, metax)
 #endif
 #ifdef ENABLE_ASCEND_API
     CREATE(INFINI_DEVICE_ASCEND, ascend)
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxBangDescriptor_t *) desc_ptr, y_desc);
-        // return cnnlCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxCnnlDescriptor_t *) desc_ptr, y_desc);
-    }
-#endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu: {
-        return musaCreateCausalSoftmaxDescriptor((MusaHandle_t) handle, (CausalSoftmaxMusaDescriptor_t *) desc_ptr, y_desc);
-    }
-#endif
     }
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;

@@ -83,17 +78,8 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
 #ifdef ENABLE_ASCEND_API
     GET(INFINI_DEVICE_ASCEND, ascend)
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangGetCausalSoftmaxWorkspaceSize((CausalSoftmaxBangDescriptor_t) desc, size);
-        // return cnnlGetCausalSoftmaxWorkspaceSize((CausalSoftmaxCnnlDescriptor_t) desc, size);
-    }
-#endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu: {
-        return musaGetCausalSoftmaxWorkspaceSize((CausalSoftmaxMusaDescriptor_t) desc, size);
-    }
-#endif
+#ifdef ENABLE_CAMBRICON_API
+    GET(INFINI_DEVICE_CAMBRICON, bang)
+#endif
     }
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;

@@ -121,22 +107,14 @@ __C infiniStatus_t infiniopCausalSoftmax(
 #ifdef ENABLE_ILUVATAR_API
     CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_CAMBRICON_API
+    CALCULATE(INFINI_DEVICE_CAMBRICON, bang)
+#endif
 #ifdef ENABLE_METAX_API
     CALCULATE(INFINI_DEVICE_METAX, metax)
 #endif
 #ifdef ENABLE_ASCEND_API
     CALCULATE(INFINI_DEVICE_ASCEND, ascend)
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangCausalSoftmax((CausalSoftmaxBangDescriptor_t) desc, workspace, workspace_size, data, stream);
-        // return cnnlCausalSoftmax((CausalSoftmaxCnnlDescriptor_t) desc, workspace, workspace_size, data, stream);
-    }
-#endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu: {
-        return musaCausalSoftmax((CausalSoftmaxMusaDescriptor_t) desc, workspace, workspace_size, data, stream);
-    }
-#endif
     }
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;

@@ -159,21 +137,14 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
 #ifdef ENABLE_ILUVATAR_API
     DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_CAMBRICON_API
+    DESTROY(INFINI_DEVICE_CAMBRICON, bang)
+#endif
 #ifdef ENABLE_METAX_API
     DESTROY(INFINI_DEVICE_METAX, metax)
 #endif
 #ifdef ENABLE_ASCEND_API
     DESTROY(INFINI_DEVICE_ASCEND, ascend)
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangDestroyCausalSoftmaxDescriptor((CausalSoftmaxBangDescriptor_t) desc);
-        // return cnnlDestroyCausalSoftmaxDescriptor((CausalSoftmaxCnnlDescriptor_t) desc);
-    }
-#endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu:
-        return musaDestroyCausalSoftmaxDescriptor((CausalSoftmaxMusaDescriptor_t) desc);
-#endif
     }
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
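The net effect of this diff is to route Cambricon dispatch through the same CREATE/GET/CALCULATE/DESTROY macros the other back-ends use, replacing hand-written case blocks. The macros themselves are defined earlier in operator.cc, outside the hunks shown; the following is only a plausible sketch of CREATE, inferred from the Descriptor::create signature in the bang back-end, not the repository's actual definition.

// Hypothetical sketch of the CREATE dispatch macro; the real definition
// lives in operator.cc outside the diff hunks and may differ.
#define CREATE(CASE, NAMESPACE)                                                       \
    case CASE:                                                                        \
        return op::causal_softmax::NAMESPACE::Descriptor::create(                    \
            handle,                                                                   \
            reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor **>(desc_ptr), \
            y_desc,                                                                   \
            x_desc);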
src/infiniop/ops/rms_norm/bang/rms_norm_bang.h (new file, mode 100644)

#ifndef __RMS_NORM_BANG_H__
#define __RMS_NORM_BANG_H__

#include "../rms_norm.h"

DESCRIPTOR(bang)

#endif // __RMS_NORM_BANG_H__
src/infiniop/ops/rms_norm/bang/rms_norm_bang.mlu (new file, mode 100644)

#include "../../../devices/bang/common_bang.h"
#include "../../../reduce/bang/reduce_bang.h"
#include "rms_norm_bang.h"

__nram__ char nram_buffer[NRAM_MAX_SIZE];
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;

template <typename T, typename Tw>
__mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight,
                            size_t *shape, ptrdiff_t *output_strides, ptrdiff_t *input_strides,
                            float epsilon, int num_dims, int norm_dim_size) {
    // Calculate problem dimensions
    int batch_volume = 1;
    for (int dim_idx = 0; dim_idx < num_dims - 1; ++dim_idx) {
        batch_volume *= shape[dim_idx];
    }
    int vector_size = shape[num_dims - 1];
    // Determine maximum batch size for NRAM operations
    int max_batch_size = (vector_size >= SRC_MAX_SIZE / sizeof(Tw) ? SRC_MAX_SIZE / sizeof(Tw) : norm_dim_size);
    constexpr int reduce_buffer_size = 128 / sizeof(float);
    // Task distribution across cores
    int remaining_tasks = batch_volume % taskDim;
    int base_tasks_per_core = batch_volume / taskDim;
    int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
    int task_start_idx = (taskId < remaining_tasks ? taskId * base_tasks_per_core + taskId : taskId * base_tasks_per_core + remaining_tasks);
    // NRAM buffer allocation
    int half_type_offset = (sizeof(T) == 2 ? max_batch_size : 0);
    char *input_buffer = nram_buffer + reduce_buffer_size * sizeof(float);
    char *weight_buffer = input_buffer + (max_batch_size + half_type_offset) * sizeof(T);
    float *reduction_result = (float *)nram_buffer;
    T *input_cache = (T *)input_buffer;
    Tw *weight_cache = (Tw *)weight_buffer;
    // Process vectors assigned to current core
    int processed_tasks = 0;
    while (processed_tasks < actual_tasks) {
        int input_offset = 0;
        int output_offset = 0;
        int current_index = task_start_idx + processed_tasks;
        // Calculate memory offsets for current task
        for (int dim = num_dims - 2; dim >= 0; --dim) {
            input_offset += (current_index % shape[dim]) * input_strides[dim];
            output_offset += (current_index % shape[dim]) * output_strides[dim];
            current_index = current_index / shape[dim];
        }
        // Compute sum of squares
        __bang_write_zero(reduction_result, reduce_buffer_size);
        float sum_squared = op::common_bang::reduce_op::sumSquaredBatched<T>(
            input + input_offset, input_cache, reduction_result, vector_size, max_batch_size);
        // Compute normalization factor
        float rms_value = sum_squared / vector_size;
        rms_value += epsilon;
        rms_value = sqrtf(rms_value);
        float inv_rms = 1.0f / rms_value;
        // Process vector in chunks
        size_t processed_elements = 0;
        while (processed_elements < vector_size) {
            size_t current_batch = std::min((size_t)max_batch_size, vector_size - processed_elements);
            // Load data
            __memcpy(input_cache, input + input_offset + processed_elements, current_batch * sizeof(T), GDRAM2NRAM);
            __memcpy(weight_cache, weight + processed_elements, current_batch * sizeof(Tw), GDRAM2NRAM);
            // Normalization and scaling
            if constexpr (std::is_same<T, bfloat16_t>::value && std::is_same<Tw, float>::value) {
                // Special handling for BF16 input with F32 weights
                __bang_bfloat162float((float *)input_cache, input_cache, current_batch);
                __bang_mul((float *)input_cache, (float *)input_cache, weight_cache, current_batch);
                __bang_mul_scalar((float *)input_cache, (float *)input_cache, inv_rms, current_batch);
                __bang_float2bfloat16(input_cache, (float *)input_cache, current_batch);
            } else {
                if constexpr (std::is_same<T, half>::value && std::is_same<Tw, float>::value) {
                    __bang_float2half_dn((T *)weight_cache, weight_cache, current_batch);
                }
                __bang_mul(input_cache, input_cache, (T *)weight_cache, current_batch);
                __bang_mul_scalar(input_cache, input_cache, inv_rms, current_batch);
            }
            // Store results
            __memcpy(output + output_offset + processed_elements, input_cache, current_batch * sizeof(T), NRAM2GDRAM);
            processed_elements += current_batch;
        }
        processed_tasks++;
    }
}

template <typename T, typename Tw>
void rmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrtQueue_t queue, void *y, const void *x, const void *w, const size_t *shape, const ptrdiff_t *y_strides, const ptrdiff_t *x_strides, float eps, int ndim) {
    cnrtDim3_t kernel_dim;
    cnrtFunctionType_t kernel_type;
    // Configure kernel dimensions
    kernel_dim.x = core_per_cluster;
    kernel_dim.y = cluster_count;
    kernel_dim.z = 1;
    kernel_type = CNRT_FUNC_TYPE_UNION1; // Can choose others, but must adapt kernel_type accordingly
    int dimsize = shape[ndim - 1]; // Length of the reduced dimension
    int dim_s;                     // dim_s is the smallest power of 2 not less than dimsize
    float mi = log2(dimsize);
    if (floor(mi) == mi) {
        dim_s = dimsize;
    } else {
        dim_s = pow(2, floor(mi) + 1);
    }
    constexpr int reduce_num = 128 / sizeof(float); // Cambricon __bang_reduce_sum can only reduce 128 bytes at a time
    if (dim_s < reduce_num) {
        dim_s = reduce_num; // Force dim_s >= reduce_num
    }
    // Prepare device pointers
    auto y_ = reinterpret_cast<T *>(y);
    auto x_ = reinterpret_cast<const T *>(x);
    auto w_ = reinterpret_cast<const Tw *>(w);
    char *tmp_device = reinterpret_cast<char *>(workspace);
    char *tmp_stride = tmp_device + ndim * sizeof(size_t);
    size_t *mlu_shape = (size_t *)tmp_device;
    ptrdiff_t *mlu_x_strides = (ptrdiff_t *)tmp_stride;
    ptrdiff_t *mlu_y_strides = mlu_x_strides + ndim;
    // Copy shape and stride information to device
    CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast<size_t *>(shape), ndim * sizeof(size_t), queue, cnrtMemcpyHostToDev)); // const not supported
    CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast<ptrdiff_t *>(x_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast<ptrdiff_t *>(y_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
    // Launch kernel
    rmsnorm<T, Tw><<<kernel_dim, kernel_type, queue>>>(y_, x_, w_, mlu_shape, mlu_y_strides, mlu_x_strides, eps, ndim, dim_s);
    cnrtQueueSync(queue);
}

namespace op::rms_norm::bang {

struct Descriptor::Opaque {
    std::shared_ptr<device::bang::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t w_desc,
    float epsilon) {
    auto handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);
    auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
    CHECK_RESULT(result);
    auto info = result.take();
    size_t workspace_size = info.ndim() * (sizeof(size_t) + 2 * sizeof(ptrdiff_t));
    *desc_ptr = new Descriptor(
        new Descriptor::Opaque{static_cast<device::bang::Handle *>(handle)->internal()},
        info,
        workspace_size,
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                     void *y, const void *x, const void *w,
                                     void *stream) const {
    auto queue = reinterpret_cast<cnrtQueue_t>(stream);
    int core_per_cluster = _opaque->internal->getCorePerCluster();
    int cluster_count = _opaque->internal->getClusterCount();
    // Dispatch based on data types
    if (_info.atype == INFINI_DTYPE_F16) {
        if (_info.wtype == INFINI_DTYPE_F16) {
            rmsnormUnion<half, half>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
        } else if (_info.wtype == INFINI_DTYPE_F32) {
            rmsnormUnion<half, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_info.atype == INFINI_DTYPE_F32) {
        if (_info.wtype == INFINI_DTYPE_F32) {
            rmsnormUnion<float, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_info.atype == INFINI_DTYPE_BF16) {
        if (_info.wtype == INFINI_DTYPE_BF16) {
            rmsnormUnion<bfloat16_t, bfloat16_t>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
        } else if (_info.wtype == INFINI_DTYPE_F32) {
            rmsnormUnion<bfloat16_t, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::rms_norm::bang
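The per-vector math the kernel implements is standard RMSNorm: scale each element by its weight and by 1 / sqrt(mean(x²) + ε), which is exactly the sumSquaredBatched + inv_rms path above. A minimal scalar C++ reference for one vector; the rms_norm_row name is illustrative only and not part of the patch.

#include <cmath>
#include <cstddef>

// Scalar reference for one normalized vector: y = x * w / rms,
// where rms = sqrt(mean(x^2) + epsilon).
void rms_norm_row(float *y, const float *x, const float *w,
                  size_t n, float epsilon) {
    float sum_squared = 0.0f;
    for (size_t j = 0; j < n; ++j) {
        sum_squared += x[j] * x[j];   // sum of squares over the vector
    }
    float inv_rms = 1.0f / std::sqrt(sum_squared / n + epsilon);
    for (size_t j = 0; j < n; ++j) {
        y[j] = x[j] * w[j] * inv_rms; // scale by weight and normalize
    }
}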
src/infiniop/ops/rms_norm/operator.cc

@@ -11,6 +11,9 @@
 #ifdef ENABLE_ASCEND_API
 #include "ascend/rms_norm_aclnn.h"
 #endif
+#ifdef ENABLE_CAMBRICON_API
+#include "bang/rms_norm_bang.h"
+#endif
 #ifdef ENABLE_METAX_API
 #include "metax/rms_norm_metax.cuh"
 #endif

@@ -52,10 +55,8 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
 #ifdef ENABLE_KUNLUN_API
     CREATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangCreateRMSNormDescriptor((BangHandle_t) handle, (RMSNormBangDescriptor_t *) desc_ptr, y_desc, x_desc, w_desc, epsilon);
-    }
-#endif
+#ifdef ENABLE_CAMBRICON_API
+    CREATE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
 #ifdef ENABLE_ASCEND_API
     CREATE(INFINI_DEVICE_ASCEND, ascend);

@@ -93,10 +94,8 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
 #ifdef ENABLE_KUNLUN_API
     GET(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangGetRMSNormWorkspaceSize((RMSNormBangDescriptor_t) desc, size);
-    }
-#endif
+#ifdef ENABLE_CAMBRICON_API
+    GET(INFINI_DEVICE_CAMBRICON, bang);
+#endif
 #ifdef ENABLE_ASCEND_API
     GET(INFINI_DEVICE_ASCEND, ascend);

@@ -135,10 +134,8 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
 #ifdef ENABLE_KUNLUN_API
     CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangRMSNorm((RMSNormBangDescriptor_t) desc, workspace, workspace_size, y, x, w, stream);
-    }
-#endif
+#ifdef ENABLE_CAMBRICON_API
+    CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
 #ifdef ENABLE_ASCEND_API
     CALCULATE(INFINI_DEVICE_ASCEND, ascend);

@@ -176,10 +173,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
 #ifdef ENABLE_KUNLUN_API
     DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangDestroyRMSNormDescriptor((RMSNormBangDescriptor_t) desc);
-    }
-#endif
+#ifdef ENABLE_CAMBRICON_API
+    DESTROY(INFINI_DEVICE_CAMBRICON, bang);
+#endif
 #ifdef ENABLE_ASCEND_API
     DESTROY(INFINI_DEVICE_ASCEND, ascend);
src/infiniop/reduce/bang/reduce_bang.h (new file, mode 100644)

#ifndef __INFINIOP_REDUCE_BANG_H__
#define __INFINIOP_REDUCE_BANG_H__

#include "../../devices/bang/common_bang.h"

namespace op::common_bang::reduce_op {

constexpr int batch_size = 128 / sizeof(float);

__mlu_func__ void sumInternal(float *dst, float *src, int max_batch) {
    const int width = max_batch / batch_size;
    // Use vectorized reduction
    if (width >= 4) {
        __bang_sumpool(dst, src, batch_size, 1, width, 1, width, 1, 1);
        __bang_reduce_sum(dst, dst, batch_size);
    } else {
        // Fallback for small batches
        float sum = 0.0f;
        for (int i = 0; i < max_batch; ++i) {
            sum += src[i];
        }
        dst[0] = sum;
    }
}

template <typename T>
__mlu_func__ void sumTyped(float *result, T *data, size_t len) {
    if constexpr (std::is_same_v<T, half>) {
        __bang_half2float((float *)data, data + len, len);
        sumInternal(result, (float *)data, len);
    } else if constexpr (std::is_same_v<T, bfloat16_t>) {
        __bang_bfloat162float((float *)data, data + len, len);
        sumInternal(result, (float *)data, len);
    } else {
        sumInternal(result, data, len);
    }
}

template <typename T>
__mlu_func__ float sum(const T *source, T *src, float *dst, int num_elements, int max_batch) {
    float res = 0.0f;
    int offset = (sizeof(T) == 2 ? max_batch : 0);
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
        if (curr_batch < max_batch) {
            __bang_write_zero(src, max_batch + offset);
        }
        __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
        sumTyped(dst, src, max_batch);
        res += dst[0];
        processed += curr_batch;
    }
    return res;
}

template <typename T>
__mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_elements, int max_batch) {
    constexpr int min_vector_size = 32;
    // For small vectors, use safer element-wise computation
    if (num_elements < min_vector_size) {
        return sum(source, src, dst, num_elements, max_batch);
    }
    float res = 0.0f;
    int offset = (sizeof(T) == 2 ? max_batch : 0);
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
        size_t aligned_batch = (curr_batch / batch_size) * batch_size;
        size_t remainder = curr_batch % batch_size;
        // Ensure NRAM buffer is zeroed
        __bang_write_zero(src, max_batch + offset);
        // Copy data to NRAM
        __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
        if constexpr (std::is_same_v<T, half>) {
            __bang_half2float((float *)(src + offset), src + offset, curr_batch);
        } else if constexpr (std::is_same_v<T, bfloat16_t>) {
            __bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
        }
        // Process aligned portion
        if (aligned_batch > 0) {
            sumInternal(dst, (float *)(src + offset), aligned_batch);
            res += dst[0];
        }
        // Process unaligned tail
        if (remainder > 0) {
            for (size_t i = aligned_batch; i < curr_batch; ++i) {
                res += ((float *)(src + offset))[i];
            }
        }
        processed += curr_batch;
    }
    return res;
}

template <typename T>
__mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_elements, int max_batch) {
    float res = 0.0f;
    int offset = (sizeof(T) == 2 ? max_batch : 0);
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
        if (curr_batch < max_batch) {
            __bang_write_zero(src, max_batch + offset);
        }
        __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
        float sum = 0.0f;
        for (size_t i = 0; i < curr_batch; ++i) {
            float val = 0.0f;
            if constexpr (std::is_same_v<T, half>) {
                val = __half2float(src[offset + i]);
            } else if constexpr (std::is_same_v<T, bfloat16_t>) {
                val = __bfloat162float(src[offset + i]);
            } else {
                val = src[offset + i];
            }
            sum += val * val;
        }
        res += sum;
        processed += curr_batch;
    }
    return res;
}

template <typename T>
__mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int num_elements, int max_batch) {
    constexpr int min_vector_size = 32;
    // For small vectors, use safer element-wise computation
    if (num_elements < min_vector_size) {
        return sumSquared(source, src, dst, num_elements, max_batch);
    }
    float res = 0.0f;
    int offset = (sizeof(T) == 2 ? max_batch : 0);
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
        size_t aligned_batch = (curr_batch / batch_size) * batch_size;
        size_t remainder = curr_batch % batch_size;
        // Ensure NRAM buffer is zeroed
        __bang_write_zero(src, max_batch + offset);
        // Copy data to NRAM
        __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
        if constexpr (std::is_same_v<T, half>) {
            __bang_half2float((float *)(src + offset), src + offset, curr_batch);
        } else if constexpr (std::is_same_v<T, bfloat16_t>) {
            __bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
        }
        // Process aligned portion
        if (aligned_batch > 0) {
            __bang_mul((float *)(src + offset), (float *)(src + offset), (float *)(src + offset), aligned_batch);
            sumInternal(dst, (float *)(src + offset), aligned_batch);
            res += dst[0];
        }
        // Process unaligned tail
        if (remainder > 0) {
            for (size_t i = aligned_batch; i < curr_batch; ++i) {
                float val = ((float *)(src + offset))[i];
                res += val * val;
            }
        }
        processed += curr_batch;
    }
    return res;
}

__mlu_func__ void maxInternal(float *dst, float *src, int max_batch) {
    __bang_maxpool(dst, src,
                   batch_size,             // channel size
                   1,                      // height
                   max_batch / batch_size, // width
                   1,                      // kernel_height
                   max_batch / batch_size, // kernel_width
                   1,                      // stride_height
                   1);                     // stride_width
    __bang_argmax(dst, dst, batch_size);
}

template <typename T>
__mlu_func__ void maxTyped(float *result, T *data, size_t len) {
    if constexpr (std::is_same_v<T, half>) {
        __bang_half2float((float *)data, data + len, len);
        maxInternal(result, (float *)data, len);
    } else if constexpr (std::is_same_v<T, bfloat16_t>) {
        __bang_bfloat162float((float *)data, data + len, len);
        maxInternal(result, (float *)data, len);
    } else {
        maxInternal(result, data, len);
    }
}

template <typename T>
__mlu_func__ float max(const T *source, T *src, float *dst, int num_elements, int max_batch) {
    float max_val = -INFINITY;
    int offset = (sizeof(T) == 2 ? max_batch : 0);
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
        if (curr_batch < max_batch) {
            __bang_write_zero(src, max_batch + offset);
        }
        __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
        maxTyped(dst, src, max_batch);
        max_val = std::max(max_val, dst[0]);
        processed += curr_batch;
    }
    return max_val;
}

template <typename T>
__mlu_func__ float maxBatched(const T *source, T *src, float *dst, int num_elements, int max_batch) {
    constexpr int min_vector_size = 32;
    // For small vectors, use safer element-wise computation
    if (num_elements < min_vector_size) {
        return max(source, src, dst, num_elements, max_batch);
    }
    float max_val = -INFINITY;
    int offset = (sizeof(T) == 2 ? max_batch : 0);
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
        if (curr_batch < max_batch) {
            __bang_write_zero(src, max_batch + offset);
        }
        __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
        maxTyped(dst, src, max_batch);
        max_val = std::max(max_val, dst[0]);
        processed += curr_batch;
    }
    return max_val;
}

} // namespace op::common_bang::reduce_op

#endif // __INFINIOP_REDUCE_BANG_H__
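sumBatched and sumSquaredBatched both split each copied chunk into an aligned portion (a multiple of batch_size, i.e. 32 floats, which the pooled __bang reductions can handle) and a scalar tail. A host-side C++ sketch of that decomposition, for intuition only; sum_with_tail is an illustrative name, not part of the patch.

#include <cstddef>

// Host-side illustration of the aligned/tail split used by sumBatched:
// grouped accumulation stands in for __bang_sumpool over full 32-float
// groups, and the remainder is folded in element by element.
float sum_with_tail(const float *data, size_t n) {
    constexpr size_t batch_size = 128 / sizeof(float); // 32 floats per group
    size_t aligned = (n / batch_size) * batch_size;
    float res = 0.0f;
    for (size_t i = 0; i < aligned; i += batch_size) {
        float group = 0.0f;                // one pooled reduction per group
        for (size_t j = 0; j < batch_size; ++j) {
            group += data[i + j];
        }
        res += group;
    }
    for (size_t i = aligned; i < n; ++i) { // unaligned tail
        res += data[i];
    }
    return res;
}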
test/infiniop/causal_softmax.py

@@ -41,7 +41,7 @@ _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
 _TOLERANCE_MAP = {
     InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
     InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
-    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
+    InfiniDtype.F32: {"atol": 3e-5, "rtol": 1e-5},
 }
test/infiniop/rms_norm.py

@@ -46,7 +46,7 @@ _TEST_CASES = [
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
     InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
-    InfiniDtype.BF16: {"atol": 8e-3, "rtol": 8e-3},
+    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
 }

 DEBUG = False