Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
ef577d9d
Commit
ef577d9d
authored
Aug 12, 2025
by
wooway777
Browse files
issue/240 - added bf16 support to cambricon causal softmax and adjusted tolerance
parent
adbda4c4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
202 additions
and
2 deletions
+202
-2
src/infiniop-test/src/main.cpp
src/infiniop-test/src/main.cpp
+1
-1
src/infiniop/ops/causal_softmax/bang/causal_softmax_bang.mlu
src/infiniop/ops/causal_softmax/bang/causal_softmax_bang.mlu
+200
-0
test/infiniop/causal_softmax.py
test/infiniop/causal_softmax.py
+1
-1
No files found.
src/infiniop-test/src/main.cpp
View file @
ef577d9d
...
@@ -9,7 +9,7 @@ struct ParsedArgs {
...
@@ -9,7 +9,7 @@ struct ParsedArgs {
int
device_id
=
0
;
// CUDA device ID (if specified)
int
device_id
=
0
;
// CUDA device ID (if specified)
int
warmups
=
0
;
// Default to 0 if not given
int
warmups
=
0
;
// Default to 0 if not given
int
iterations
=
0
;
// Default to 0 if not given
int
iterations
=
0
;
// Default to 0 if not given
-    double atol = 0.001;    // Default absolute tolerance
+    double atol = 0.0015;   // Default absolute tolerance
double
rtol
=
0.001
;
// Default relative tolerance
double
rtol
=
0.001
;
// Default relative tolerance
};
};
...
...
src/infiniop/ops/causal_softmax/bang/causal_softmax_bang.mlu
0 → 100644
View file @
ef577d9d
#include "../../../devices/bang/common_bang.h"
#include "../../../reduce/bang/reduce_bang.h"
#include "causal_softmax_bang.h"
// Per-core NRAM scratch area shared by every kernel in this file; each
// function carves its own staging/conversion buffers out of it.
__nram__ char nram_buffer[NRAM_MAX_SIZE];
// Byte budget for a single staging buffer: one quarter of NRAM, leaving the
// remainder for the float conversion buffer and reduction scratch.
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;
// One elementwise pass of the causal softmax over a strided row of
// `num_elements` values, processed in NRAM-sized chunks.
//
// Two phases share this body:
//   - is_exp_phase == true:  output[k] = exp(input[k]  - scalar)   (scalar = row max)
//   - is_exp_phase == false: output[k] = output[k] * scalar        (scalar = 1/row sum)
// Note that the normalize phase reads from `output` (in place); `input` is
// only dereferenced in the exp phase.
//
// half/bfloat16 data is staged through a float buffer so the arithmetic is
// always done in fp32; float data is processed directly.
template <typename T>
__mlu_func__ void processSoftmaxStep(T *output, const T *input, float scalar, int num_elements, int stride, bool is_exp_phase) {
    // Calculate buffer sizes (split between float and T buffers)
    constexpr bool is_half = std::is_same_v<T, half>;
    constexpr bool is_bfloat16 = std::is_same_v<T, bfloat16_t>;
    constexpr bool is_float = !is_half && !is_bfloat16;
    // For 16-bit types each element needs a float slot plus a T slot, so the
    // chunk is sized as if every element cost 2*sizeof(float) bytes.
    const int chunk_size = SRC_MAX_SIZE / ((is_half || is_bfloat16) ? (2 * sizeof(float)) : sizeof(float));
    float *float_buffer = (float *)nram_buffer;
    // temp_buffer holds the raw T elements; placed right after the float buffer.
    T *temp_buffer = is_float ? nullptr : (T *)(nram_buffer + chunk_size * sizeof(float));
    // Common stride configurations (bytes between consecutive row elements).
    // NOTE(review): the same `stride` is used for both the gather and the
    // scatter, but in the exp phase the caller passes the INPUT stride while
    // writes go to `output` — this assumes x and y share the same innermost
    // stride; confirm against the caller/tensor-info invariants.
    const int src_stride = stride * sizeof(T);
    const int dst_stride = stride * sizeof(T);
    int processed = 0;
    while (processed < num_elements) {
        int curr_batch = std::min(chunk_size, num_elements - processed);
        // Gather input elements using 2D memcpy (stride `src_stride` in GDRAM,
        // packed contiguously in NRAM; count argument is segments-1).
        if constexpr (is_float) {
            __memcpy(float_buffer, (is_exp_phase ? input : output) + processed * stride, sizeof(float),
                     GDRAM2NRAM, sizeof(float), src_stride, curr_batch - 1);
        } else {
            __memcpy(temp_buffer, (is_exp_phase ? input : output) + processed * stride, sizeof(T),
                     GDRAM2NRAM, sizeof(T), src_stride, curr_batch - 1);
            // Convert to float
            if constexpr (is_half) {
                __bang_half2float(float_buffer, temp_buffer, curr_batch);
            } else if constexpr (is_bfloat16) {
                __bang_bfloat162float(float_buffer, temp_buffer, curr_batch);
            }
        }
        // Common processing for all types
        if (is_exp_phase) {
            __bang_sub_scalar(float_buffer, float_buffer, scalar, curr_batch); // scalar is max_val
            __bang_active_exphp(float_buffer, float_buffer, curr_batch);
        } else {
            __bang_mul_scalar(float_buffer, float_buffer, scalar, curr_batch); // scalar is 1.0f/sum_val
        }
        // Convert back and scatter output using 2D memcpy
        if constexpr (is_float) {
            __memcpy(output + processed * stride, float_buffer, sizeof(float),
                     NRAM2GDRAM, dst_stride, sizeof(float), curr_batch - 1);
        } else {
            // Convert back to original type
            if constexpr (is_half) {
                __bang_float2half(temp_buffer, float_buffer, curr_batch);
            } else if constexpr (is_bfloat16) {
                __bang_float2bfloat16(temp_buffer, float_buffer, curr_batch);
            }
            // Scatter output
            __memcpy(output + processed * stride, temp_buffer, sizeof(T),
                     NRAM2GDRAM, dst_stride, sizeof(T), curr_batch - 1);
        }
        processed += curr_batch;
    }
}
// Device entry point: causal softmax over y[b][i][:] for every (batch, row)
// pair. Rows are distributed contiguously across tasks; each task handles
// `tasks_per_core` consecutive (batch, row) indices.
//
// For row i, only the first `total_seq_len - seq_len + i + 1` positions are
// attendable (causal mask with a total_seq_len - seq_len offset, i.e. a
// KV-cache style prefix); the remaining tail positions are written as 0.
template <typename T>
__mlu_global__ void causalSoftmax(T *y, const T *x,
                                  size_t batch_size, size_t seq_len, size_t total_seq_len,
                                  ptrdiff_t y_stride_b, ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
                                  ptrdiff_t x_stride_b, ptrdiff_t x_stride_i, ptrdiff_t x_stride_j) {
    using namespace op::common_bang::reduce_op;
    // Get task information
    size_t task_id = taskId;
    size_t task_num = taskDimX * taskDimY;
    // Calculate elements per task with better load balancing
    size_t total_tasks = batch_size * seq_len;
    size_t tasks_per_core = (total_tasks + task_num - 1) / task_num;
    size_t start = task_id * tasks_per_core;
    size_t end = std::min(start + tasks_per_core, total_tasks);
    // Allocate NRAM buffers: `src` stages raw T elements (up to max_batch of
    // them), `dst` is float reduction scratch placed right after it.
    const int max_batch = SRC_MAX_SIZE / sizeof(T);
    T *src = (T *)nram_buffer;
    float *dst = (float *)(nram_buffer + max_batch * sizeof(T));
    for (size_t index = start; index < end; index++) {
        size_t batch = index / seq_len;
        size_t i = (index % seq_len);
        ptrdiff_t y_offset = batch * y_stride_b + i * y_stride_i;
        ptrdiff_t x_offset = batch * x_stride_b + i * x_stride_i;
        T *y_ = y + y_offset;
        const T *x_ = x + x_offset;
        // Calculate the valid sequence length for this position
        size_t valid_len = total_seq_len - seq_len + i + 1;
        // Zero out future positions (scalar GDRAM stores through the row stride)
        for (size_t j = valid_len; j < total_seq_len; j++) {
            y_[j * y_stride_j] = (T)0.0f;
        }
        // Calculate max value using optimized reduction.
        // NOTE(review): no row stride is passed to maxBatched/sumBatched, so
        // this appears to assume x_stride_j == 1 (contiguous rows) — confirm
        // against reduce_bang.h.
        float max_val = maxBatched(x_, src, dst, valid_len, max_batch);
        // Compute exp(x - max)
        processSoftmaxStep(y_, x_, max_val, valid_len, x_stride_j, true);
        // Calculate sum of exponentials
        float sum_val = sumBatched(y_, src, dst, valid_len, max_batch);
        // Normalize by sum
        processSoftmaxStep(y_, y_, 1.0f / sum_val, valid_len, y_stride_j, false);
    }
}
// Host-side launcher: configures a UNION1 task group spanning all clusters
// (x = cores per cluster, y = cluster count), launches the causalSoftmax
// kernel for element type T, and blocks until the queue drains.
//
// `workspace` is accepted for interface uniformity but is not used by the
// kernel — the computation operates directly on y/x.
template <typename T>
void causalSoftmaxUnion(void *workspace, int core_per_cluster, int cluster_count,
                        cnrtQueue_t queue, void *y, const void *x, const op::causal_softmax::CausalSoftmaxInfo *info) {
    // Task topology: one task per core across every cluster.
    cnrtDim3_t launch_dim;
    launch_dim.x = core_per_cluster;
    launch_dim.y = cluster_count;
    launch_dim.z = 1;
    const cnrtFunctionType_t launch_type = CNRT_FUNC_TYPE_UNION1;

    causalSoftmax<T><<<launch_dim, launch_type, queue>>>(
        (T *)y, (const T *)x,
        info->batch_size, info->seq_len, info->total_seq_len,
        info->y_stride_b, info->y_stride_i, info->y_stride_j,
        info->x_stride_b, info->x_stride_i, info->x_stride_j);

    // Synchronous semantics: wait for the kernel before returning.
    cnrtQueueSync(queue);
}
namespace op::causal_softmax::bang {

// Backend-private descriptor state: holds the device handle internals so the
// launcher can query cluster/core topology at calculate() time.
struct Descriptor::Opaque {
    std::shared_ptr<device::bang::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

// Validates the (y, x) tensor pair, captures shapes/strides/dtype into a
// CausalSoftmaxInfo, and builds the descriptor. Workspace size is 0 — the
// kernels operate directly on the tensors.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto *handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);

    auto info_result = CausalSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(info_result);

    *desc_ptr = new Descriptor(
        new Descriptor::Opaque{static_cast<device::bang::Handle *>(handle)->internal()},
        info_result.take(),
        0,
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Dispatches the causal-softmax launch by element dtype (f16 / bf16 / f32);
// any other dtype is rejected with INFINI_STATUS_BAD_TENSOR_DTYPE.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    const void *x,
    void *stream) const {
    auto queue = reinterpret_cast<cnrtQueue_t>(stream);
    const int cores_per_cluster = _opaque->internal->getCorePerCluster();
    const int clusters = _opaque->internal->getClusterCount();

    switch (_info.dtype) {
    case INFINI_DTYPE_F16:
        causalSoftmaxUnion<half>(workspace, cores_per_cluster, clusters, queue, y, x, &_info);
        break;
    case INFINI_DTYPE_BF16:
        causalSoftmaxUnion<bfloat16_t>(workspace, cores_per_cluster, clusters, queue, y, x, &_info);
        break;
    case INFINI_DTYPE_F32:
        causalSoftmaxUnion<float>(workspace, cores_per_cluster, clusters, queue, y, x, &_info);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::causal_softmax::bang
test/infiniop/causal_softmax.py
View file @
ef577d9d
...
@@ -41,7 +41,7 @@ _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
...
@@ -41,7 +41,7 @@ _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
_TOLERANCE_MAP
=
{
_TOLERANCE_MAP
=
{
InfiniDtype
.
F16
:
{
"atol"
:
1e-3
,
"rtol"
:
1e-2
},
InfiniDtype
.
F16
:
{
"atol"
:
1e-3
,
"rtol"
:
1e-2
},
InfiniDtype
.
BF16
:
{
"atol"
:
5e-3
,
"rtol"
:
5e-2
},
InfiniDtype
.
BF16
:
{
"atol"
:
5e-3
,
"rtol"
:
5e-2
},
-    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
+    InfiniDtype.F32: {"atol": 3e-5, "rtol": 1e-5},
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment