Unverified commit c24a52ea authored by thatPepe, committed by GitHub
Browse files

issue/278 - Cambricon Random Sample

parent f19238db
......@@ -7,6 +7,21 @@
namespace device::bang::kernel {
/// Widens any scalar type to a 32-bit float on-device.
template <typename T>
__mlu_device__ float to_float(const T &value) {
    const float widened = static_cast<float>(value);
    return widened;
}
/// Narrows any scalar type to bfloat16 on-device.
template <typename T>
__mlu_device__ bfloat16_t to_bfloat16(const T &value) {
    const bfloat16_t narrowed = static_cast<bfloat16_t>(value);
    return narrowed;
}
/// Narrows any scalar type to half precision on-device.
template <typename T>
__mlu_device__ half to_half(const T &value) {
    const half narrowed = static_cast<half>(value);
    return narrowed;
}
/**
* @brief Converts a flattened index to a reduced offset considering broadcasting.
*
......
// Public header for the Cambricon (bang) random-sample operator.
#ifndef __RANDOM_SAMPLE_BANG_H__
#define __RANDOM_SAMPLE_BANG_H__

#include "../random_sample.h"

// Presumably expands to the op::random_sample::bang::Descriptor declaration
// (macro defined in ../random_sample.h) — confirm against that header.
DESCRIPTOR(bang)

#endif // __RANDOM_SAMPLE_BANG_H__
#include "../../../devices/bang/bang_handle.h"
#include "../info.h"
#include "random_sample_bang.h"
#include "random_sample_kernel.mlu"
namespace op::random_sample::bang {
// Device-specific state hidden behind the public Descriptor interface.
struct Descriptor::Opaque {
    // Shared internals of the Cambricon device handle this descriptor was created from.
    std::shared_ptr<device::bang::Handle::Internal> internal;
};
// Releases the opaque device state allocated in create().
Descriptor::~Descriptor() {
    delete _opaque;
}
/// Creates a bang random-sample descriptor.
///
/// Validates the result/probs tensor descriptors via RandomSampleInfo,
/// computes the workspace size for the (index dtype, value dtype) pair,
/// and allocates the Descriptor.
///
/// @param handle_     infiniop handle; must wrap a device::bang::Handle.
/// @param desc_ptr    receives the newly allocated descriptor on success.
/// @param result_desc descriptor of the output (sampled index) tensor.
/// @param probs_desc  descriptor of the input probability/logit tensor.
/// @return INFINI_STATUS_SUCCESS, INFINI_STATUS_BAD_TENSOR_DTYPE for an
///         unsupported dtype combination, or an error propagated by
///         CHECK_RESULT.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t result_desc,
    infiniopTensorDescriptor_t probs_desc) {
    auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
    auto result = RandomSampleInfo::create(result_desc, probs_desc);
    CHECK_RESULT(result);
    auto info = result.take();
    // Set by exactly one CASE_P branch below; every other path returns early.
    size_t workspace_size;
// Expands one value-dtype case: computes the workspace for <Tidx, Tval>.
#define CASE_P(CASE, Tidx, Tval) \
    case CASE: { \
        auto workspace_result = calculateWorkspace<Tidx, Tval>(info.n); \
        CHECK_RESULT(workspace_result); \
        workspace_size = workspace_result.take(); \
    } break
// Expands one index-dtype case: nests a switch over the value dtype.
#define CASE_I(CASE, Tidx) \
    case CASE: \
        switch (info.dt_p) { \
            CASE_P(INFINI_DTYPE_F16, Tidx, half); \
            CASE_P(INFINI_DTYPE_BF16, Tidx, bfloat16_t); \
            CASE_P(INFINI_DTYPE_F32, Tidx, float); \
        default: \
            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
        } \
        break
    switch (info.dt_i) {
        CASE_I(INFINI_DTYPE_I8, int8_t);
        CASE_I(INFINI_DTYPE_I16, int16_t);
        CASE_I(INFINI_DTYPE_I32, int32_t);
        CASE_I(INFINI_DTYPE_I64, int64_t);
        CASE_I(INFINI_DTYPE_U8, uint8_t);
        CASE_I(INFINI_DTYPE_U16, uint16_t);
        CASE_I(INFINI_DTYPE_U32, uint32_t);
        CASE_I(INFINI_DTYPE_U64, uint64_t);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
#undef CASE_I
#undef CASE_P
    *desc_ptr = new Descriptor(
        info,
        workspace_size,
        new Opaque{handle->internal()},
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Minimum workspace size (bytes) callers must pass to calculate().
size_t Descriptor::minWorkspaceSize() const {
    return _min_workspace_size;
}
/// Runs random sampling on the given stream.
///
/// @param workspace       device buffer of at least minWorkspaceSize() bytes.
/// @param workspace_size  size of `workspace` in bytes.
/// @param result          output buffer receiving the sampled index.
/// @param probs           input logits/probabilities tensor.
/// @param random_val      random value driving the draw (presumably in [0, 1) — confirm with callers).
/// @param topp            nucleus (top-p) cumulative-probability threshold.
/// @param topk            number of top candidates to consider.
/// @param temperature     softmax temperature.
/// @param stream          opaque CNRT queue to launch kernels on.
/// @return INFINI_STATUS_INSUFFICIENT_WORKSPACE if the workspace is too
///         small, otherwise INFINI_STATUS_SUCCESS.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *result,
    const void *probs,
    float random_val,
    float topp,
    int topk,
    float temperature,
    void *stream) const {
    if (workspace_size < _min_workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    // NOTE(review): the status returned by Calculate::calculate is discarded
    // here — confirm whether it can fail and should be propagated.
    Calculate::calculate<Algo>(
        Algo{}, _info, workspace, workspace_size,
        result, probs,
        random_val, topp, topk, temperature,
        stream);
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::random_sample::bang
#include "../../../../utils/custom_types.h"
#include "../../../devices/bang/bang_kernel_common.h"
#include "../../../devices/bang/common_bang.h"
#include "infinicore.h"
#include <algorithm>
// Per-chunk staging size (bytes) for tiles copied from GDRAM into NRAM.
const int SRC_MAX_SIZE = 1024 * 32;
namespace op::random_sample::bang {
using namespace std;
using namespace device::bang::kernel;
// On-chip NRAM scratch buffer shared by all kernels in this translation unit.
__nram__ char nram_buffer[NRAM_MAX_SIZE];
/// Workspace budget for one sampling call: one index plus one value per
/// vocabulary element, plus a single extra value slot for the global sum.
template <class Tidx, class Tval>
utils::Result<size_t> calculateWorkspace(size_t n) {
    const size_t index_bytes = n * sizeof(Tidx);
    const size_t value_bytes = n * sizeof(Tval) + sizeof(Tval);
    return utils::Result<size_t>(index_bytes + value_bytes);
}
/// Exchanges two values in place (device-side std::swap substitute).
template <typename Tval>
__mlu_device__ void swap(Tval &lhs, Tval &rhs) {
    Tval held = lhs;
    lhs = rhs;
    rhs = held;
}
/// Partial selection sort over `values`/`result` in tandem: after the call,
/// the first `topk` slots hold the largest values in descending order, with
/// their indices kept aligned in `result`.
template <typename Tval, typename Tidx>
__mlu_device__ void findTopk(
    Tval *values,
    Tidx *result,
    int size,
    int topk) {
    for (int slot = 0; slot < topk; ++slot) {
        for (int cand = slot + 1; cand < size; ++cand) {
            if (!(values[slot] < values[cand])) {
                continue;
            }
            // Candidate beats the current slot holder: swap value and index.
            swap(values[slot], values[cand]);
            swap(result[slot], result[cand]);
        }
    }
}
// Pads the tail of a top-k candidate buffer when fewer than `topk` real
// elements are available: padded values get -INFINITY so they always lose
// comparisons in findTopk, and padded indices get -1 as an "invalid" marker.
// NOTE(review): for unsigned Tidx the -1 sentinel wraps to the maximum
// value — confirm downstream consumers never select a padded slot.
template <typename Tval, typename Tidx>
__mlu_device__ void initTopkBuffer(
    Tval *values,
    Tidx *result,
    int actual_size,
    int topk) {
    if (actual_size < topk) {
        int fill_size = topk - actual_size;
        __bang_write_value(values + actual_size, fill_size, -INFINITY);
        __bang_write_value(result + actual_size, fill_size, -1);
    }
}
/// Parallel argmax over `probs`.
///
/// Each task scans its strided slice in NRAM-sized chunks, publishes its
/// local best index to gdram_indices[taskId], and task 0 re-reads `probs`
/// at those indices to choose the global winner, writing it to result[0].
template <typename Tval, typename Tidx>
__mlu_global__ void argMax(
    const Tval *probs,
    Tidx *result,
    Tidx *gdram_indices,
    int vocab_size) {
    // Elements of Tval that fit into one NRAM staging tile.
    const size_t max_num = SRC_MAX_SIZE / sizeof(Tval);
    size_t task_per_load = taskDim * max_num;
    size_t repeat = vocab_size / task_per_load;
    size_t remain = vocab_size % task_per_load;
    Tval *nram_src = (Tval *)nram_buffer;
    // __bang_argmax output: [0] = max value, [1] = packed element index.
    Tval *nram_max = nram_src + max_num;
    Tval current_max = -INFINITY;
    size_t current_index = 0;
    // Process full chunks
    for (size_t r = 0; r < repeat; ++r) {
        __memcpy(nram_src, probs + r * task_per_load + taskId * max_num,
                 max_num * sizeof(Tval), GDRAM2NRAM);
        __bang_argmax(nram_max, nram_src, max_num);
        if (nram_max[0] > current_max) {
            current_max = nram_max[0];
            // NOTE(review): reinterprets nram_max[1] as an int64 to recover
            // the packed argmax index — confirm this matches __bang_argmax's
            // output layout for every Tval width (half/bfloat16/float).
            current_index = r * task_per_load + taskId * max_num + *((int64_t *)&nram_max[1]);
        }
    }
    // Process remainder
    size_t remain_start = repeat * task_per_load;
    size_t step = remain / taskDim + (taskId < remain % taskDim ? 1 : 0);
    size_t start = remain_start + taskId * (remain / taskDim) + (taskId < remain % taskDim ? taskId : remain % taskDim);
    if (step > 0) {
        // Pad the tile with -INFINITY so argmax over the full tile is safe.
        __bang_write_value(nram_src, max_num, -INFINITY);
        __memcpy(nram_src, probs + start, step * sizeof(Tval), GDRAM2NRAM);
        __bang_argmax(nram_max, nram_src, step);
        if (nram_max[0] > current_max) {
            current_max = nram_max[0];
            current_index = start + *((int64_t *)&nram_max[1]);
        }
    }
    // Reduce across tasks
    // NOTE(review): there is no barrier between the per-task store below and
    // task 0's reads; this kernel is launched with CNRT_FUNC_TYPE_BLOCK,
    // under which tasks run independently — verify the runtime guarantees
    // the ordering required here.
    gdram_indices[taskId] = current_index;
    if (taskId == 0) {
        Tval global_max = probs[gdram_indices[0]];
        Tidx global_idx = gdram_indices[0];
        for (size_t id = 1; id < taskDim; ++id) {
            if constexpr (std::is_same<Tval, bfloat16_t>::value) {
                // bfloat16 lacks a reliable native operator> here; compare in float.
                if (to_float(probs[gdram_indices[id]]) > to_float(global_max)) {
                    global_max = probs[gdram_indices[id]];
                    global_idx = gdram_indices[id];
                }
            } else {
                if (probs[gdram_indices[id]] > global_max) {
                    global_max = probs[gdram_indices[id]];
                    global_idx = gdram_indices[id];
                }
            }
        }
        result[0] = global_idx;
    }
}
/// Top-k / top-p sampling kernel for vocabularies where each task's slice
/// fits in a single NRAM tile (voc < taskDim * max_num).
///
/// Stages: per-task top-k selection -> task-0 merge of all partial top-k
/// lists -> numerically-stable softmax sum over the full distribution ->
/// task-0 cumulative top-p draw written to result[0].
template <typename Tval, typename Tidx>
__mlu_global__ void randomSampleKernel(
    const Tval *probs,
    Tidx *result,
    Tidx *gdram_indices,
    Tval *global_topk,
    Tval *global_sum,
    size_t vocab_size,
    float random_val,
    float topp,
    int topk,
    float temperature) {
    constexpr int max_num = SRC_MAX_SIZE / sizeof(Tval);
    constexpr int w_size = 128 / sizeof(Tval);
    constexpr int seg_num = max_num / w_size;
    // Handle temperature division carefully
    Tval temp_inv;
    if constexpr (std::is_same<Tval, bfloat16_t>::value) {
        temp_inv = to_bfloat16(1.0f / temperature);
    } else {
        temp_inv = static_cast<Tval>(1.0 / temperature);
    }
    // NRAM buffer allocation
    Tval *nram_src = (Tval *)nram_buffer;
    Tval *nram_partial_sum = nram_src + max_num;
    Tval *nram_sum_final = nram_partial_sum + max_num;
    Tval *nram_topk = nram_sum_final + w_size;
    Tidx *nram_indices = (Tidx *)(nram_topk + topk);
    Tidx *nram_global_indices = nram_indices + max_num;
    // Initialize buffers
    __bang_write_value(nram_src, max_num, -INFINITY);
    __bang_write_value(nram_partial_sum, max_num, 0);
    __bang_write_value(nram_sum_final, w_size, 0);
    // Calculate workload distribution
    int step = vocab_size / taskDim + (taskId < vocab_size % taskDim ? 1 : 0);
    int start_idx = taskId * (vocab_size / taskDim) + min<size_t>(taskId, vocab_size % taskDim);
    // Load and process data
    __memcpy(nram_src, probs + start_idx, step * sizeof(Tval), GDRAM2NRAM);
    for (int i = 0; i < step; ++i) {
        nram_indices[i] = start_idx + i;
    }
    // Find top-k elements
    initTopkBuffer(nram_src, nram_indices, step, topk);
    findTopk(nram_src, nram_indices, max(step, topk), topk);
    // Store results
    __memcpy(global_topk + taskId * topk, nram_src, topk * sizeof(Tval), NRAM2GDRAM);
    __memcpy(gdram_indices + taskId * topk, nram_indices, topk * sizeof(Tidx), NRAM2GDRAM);
    __sync_all();
    // Global reduction on task 0
    if (taskId == 0) {
        __memcpy(nram_topk, global_topk, taskDim * topk * sizeof(Tval), GDRAM2NRAM);
        __memcpy(nram_global_indices, gdram_indices, taskDim * topk * sizeof(Tidx), GDRAM2NRAM);
        findTopk(nram_topk, nram_global_indices, taskDim * topk, topk);
        __memcpy(global_topk, nram_topk, topk * sizeof(Tval), NRAM2GDRAM);
        __memcpy(gdram_indices, nram_global_indices, topk * sizeof(Tidx), NRAM2GDRAM);
    }
    // NOTE(review): confirm __sync_io() orders task 0's merged writes above
    // against every task's read of global_topk[0] below.
    __sync_io();
    // Softmax computation
    Tval global_max = global_topk[0];
    __bang_write_value(nram_partial_sum, max_num, 0);
    __bang_write_value(nram_sum_final, w_size, 0);
    // Reset global_sum to 0 before accumulation
    // NOTE(review): this reset by task 0 is not obviously ordered before the
    // other tasks' __bang_atomic_add further down — verify a barrier
    // guarantees the reset happens first.
    if (taskId == 0) {
        global_sum[0] = 0;
    }
    __memcpy(nram_src, probs + start_idx, step * sizeof(Tval), GDRAM2NRAM);
    // Stable softmax calculation
    __bang_sub_scalar(nram_src, nram_src, global_max, step);
    __bang_mul_scalar(nram_src, nram_src, temp_inv, step);
    Tval max_exp = static_cast<Tval>(20.0f); // Clamp threshold
    for (int i = 0; i < step; i++) {
        if (nram_src[i] > max_exp) {
            nram_src[i] = max_exp;
        }
    }
    __bang_active_exp_less_0(nram_src, nram_src, step);
    __bang_add(nram_partial_sum, nram_partial_sum, nram_src, step);
    // Reduce sum
    for (int strip = seg_num / 2; strip > 0; strip /= 2) {
#pragma unroll
        for (int i = 0; i < strip; ++i) {
            __bang_add(nram_partial_sum + i * w_size, nram_partial_sum + i * w_size,
                       nram_partial_sum + (i + strip) * w_size, w_size);
        }
    }
    __bang_reduce_sum(nram_sum_final, nram_partial_sum, w_size);
    // Atomic add to global sum
    __bang_atomic_add(nram_sum_final, global_sum, nram_sum_final, 1);
    __sync_compute();
    // Top-p sampling
    if (taskId == 0) {
        // Ensure a valid sum
        Tval sum = global_sum[0];
        if (sum <= static_cast<Tval>(0)) {
            // Fallback: just select the first index if sum is invalid
            result[0] = static_cast<Tidx>(gdram_indices[0]);
            return;
        }
        Tval global_sum_inv;
        if constexpr (std::is_same<Tval, bfloat16_t>::value) {
            global_sum_inv = to_bfloat16(1.0f / to_float(sum));
        } else {
            global_sum_inv = static_cast<Tval>(1.0 / sum);
        }
        __memcpy(nram_topk, global_topk, topk * sizeof(Tval), GDRAM2NRAM);
        // Recompute softmax for topk elements
        __bang_sub_scalar(nram_topk, nram_topk, global_max, topk);
        __bang_mul_scalar(nram_topk, nram_topk, temp_inv, topk);
        // Manual clamp again
        for (int i = 0; i < topk; i++) {
            if (nram_topk[i] > max_exp) {
                nram_topk[i] = max_exp;
            }
        }
        __bang_active_exp_less_0(nram_topk, nram_topk, topk);
        __bang_mul_scalar(nram_topk, nram_topk, global_sum_inv, topk);
        // Cumulative sum
        Tval cumsum = 0;
        int end = topk;
        for (int i = 0; i < topk; ++i) {
            cumsum += nram_topk[i];
            if (cumsum >= static_cast<Tval>(topp)) {
                end = i + 1;
                break;
            }
        }
        // Scale random value to the cumulative sum range
        random_val *= to_float(cumsum);
        // Perform sampling
        cumsum = 0;
        for (int i = 0; i < end; ++i) {
            cumsum += nram_topk[i];
            if (random_val < to_float(cumsum)) {
                result[0] = static_cast<Tidx>(gdram_indices[i]);
                return;
            }
        }
        // Fallback if no index was selected
        result[0] = static_cast<Tidx>(gdram_indices[end - 1]);
    }
}
/// Top-k / top-p sampling kernel for large vocabularies
/// (voc >= taskDim * max_num): each task streams its share of the
/// vocabulary through NRAM in multiple rounds, maintaining a running
/// 2*topk merge buffer, then task 0 merges, normalizes, and samples.
template <typename Tval, typename Tidx>
__mlu_global__ void randomSampleKernelLarge(
    const Tval *probs,
    Tidx *result,
    Tidx *gdram_indices,
    Tval *global_topk,
    Tval *global_sum,
    size_t vocab_size,
    float random_val,
    float topp,
    int topk,
    float temperature) {
    const int max_num = SRC_MAX_SIZE / sizeof(Tval);
    const int w_size = 128 / sizeof(Tval);
    const int seg_num = max_num / w_size;
    // Handle temperature division carefully
    Tval temp_inv;
    if constexpr (std::is_same<Tval, bfloat16_t>::value) {
        temp_inv = to_bfloat16(1.0f / temperature);
    } else {
        temp_inv = static_cast<Tval>(1.0 / temperature);
    }
    const int task_size = taskDim * max_num;
    const int remain = vocab_size % task_size;
    const int repeat = (vocab_size - remain) / task_size;
    const int remain_t = remain % taskDim;
    const int step_easy = (remain - remain_t) / taskDim;
    const int step_hard = step_easy + 1;
    const int step = (taskId < remain_t ? step_hard : step_easy);
    const int start_idx = (taskId < remain_t ? taskId * step_hard : remain_t * step_hard + (taskId - remain_t) * step_easy);
    // NRAM buffer allocation
    // Layout: Tval region [src | 2*topk merge | partial sums | final sum |
    // taskDim*topk merged topk], then the Tidx region mirrors it.
    char *nram_buffer_ind = nram_buffer + (2 * max_num + w_size + 2 * topk + taskDim * topk) * sizeof(Tval);
    Tval *nram_src = (Tval *)nram_buffer;
    Tval *nram_topk_buffer = nram_src + max_num;
    Tval *nram_partial_sum = nram_topk_buffer + 2 * topk;
    Tval *nram_sum_final = nram_partial_sum + max_num;
    Tval *nram_global_topk = nram_sum_final + w_size;
    Tidx *nram_indices = (Tidx *)nram_buffer_ind;
    Tidx *nram_topk_indices = nram_indices + max_num;
    Tidx *nram_global_indices = nram_topk_indices + 2 * topk;
    // Initialize buffers
    for (int i = 0; i < 2 * topk; ++i) {
        nram_topk_buffer[i] = -INFINITY;
    }
    for (int j = 0; j < max_num; ++j) {
        nram_indices[j] = taskId * max_num + j;
    }
    // Process full chunks
    for (int r = 0; r < repeat; ++r) {
        if (r > 0) {
            // NOTE(review): advances the index buffer by reinterpreting Tidx
            // storage as shorts and adding task_size to max_num/sizeof(short)
            // of them — this only looks correct for 2-byte Tidx with no carry
            // across lanes; verify for 4/8-byte index types.
            __bang_add_scalar((short *)nram_indices, (short *)nram_indices, task_size, max_num / sizeof(short));
        }
        __memcpy(nram_src, probs + r * task_size + taskId * max_num, max_num * sizeof(Tval), GDRAM2NRAM);
        findTopk(nram_src, nram_indices, max_num, topk);
        // Merge this round's topk with the running best into the 2*topk buffer.
        __memcpy(nram_topk_buffer + topk, nram_src, topk * sizeof(Tval), NRAM2NRAM);
        __memcpy(nram_topk_indices + topk, nram_indices, topk * sizeof(Tidx), NRAM2NRAM);
        findTopk(nram_topk_buffer, nram_topk_indices, 2 * topk, topk);
    }
    // Handle remaining elements
    if (step) {
        for (int j = 0; j < step; ++j) {
            nram_indices[j] = repeat * task_size + start_idx + j;
        }
        __memcpy(nram_src, probs + repeat * task_size + start_idx, step * sizeof(Tval), GDRAM2NRAM);
        if (step >= topk) {
            findTopk(nram_src, nram_indices, step, topk);
            __memcpy(nram_topk_buffer + topk, nram_src, topk * sizeof(Tval), NRAM2NRAM);
            __memcpy(nram_topk_indices + topk, nram_indices, topk * sizeof(Tidx), NRAM2NRAM);
        } else {
            initTopkBuffer(nram_src, nram_indices, step, topk);
            // NOTE(review): comma operator joins the next two statements —
            // almost certainly a typo for ';' (behavior is unchanged).
            __memcpy(nram_topk_buffer + topk, nram_src, step * sizeof(Tval), NRAM2NRAM),
                __memcpy(nram_topk_indices + topk, nram_indices, step * sizeof(Tidx), NRAM2NRAM);
        }
        findTopk(nram_topk_buffer, nram_topk_indices, 2 * topk, topk);
    }
    // Store results to global memory
    __memcpy(global_topk + taskId * topk, nram_topk_buffer, topk * sizeof(Tval), NRAM2GDRAM);
    __memcpy(gdram_indices + taskId * topk, nram_topk_indices, topk * sizeof(Tidx), NRAM2GDRAM);
    __sync_all();
    // Task 0 merges all partial results
    if (taskId == 0) {
        __memcpy(nram_global_topk, global_topk, taskDim * topk * sizeof(Tval), GDRAM2NRAM);
        __memcpy(nram_global_indices, gdram_indices, taskDim * topk * sizeof(Tidx), GDRAM2NRAM);
        findTopk(nram_global_topk, nram_global_indices, taskDim * topk, topk);
        __memcpy(global_topk, nram_global_topk, topk * sizeof(Tval), NRAM2GDRAM);
        __memcpy(gdram_indices, nram_global_indices, topk * sizeof(Tidx), NRAM2GDRAM);
    }
    __sync_all();
    // Softmax transformation
    Tval global_max = global_topk[0];
    __bang_write_value(nram_partial_sum, max_num, 0);
    __bang_write_value(nram_sum_final, w_size, 0);
    // Process full chunks
    for (int r = 0; r < repeat; ++r) {
        __memcpy(nram_src, probs + r * task_size + taskId * max_num, max_num * sizeof(Tval), GDRAM2NRAM);
        __bang_sub_scalar(nram_src, nram_src, global_max, max_num);
        __bang_mul_scalar(nram_src, nram_src, temp_inv, max_num);
        __bang_active_exp_less_0(nram_src, nram_src, max_num);
        __bang_add(nram_partial_sum, nram_partial_sum, nram_src, max_num);
    }
    // Process remaining elements
    if (step) {
        __bang_write_value(nram_src, max_num, 0);
        __memcpy(nram_src, probs + repeat * task_size + start_idx, step * sizeof(Tval), GDRAM2NRAM);
        __bang_sub_scalar(nram_src, nram_src, global_max, step);
        __bang_mul_scalar(nram_src, nram_src, temp_inv, step);
        __bang_active_exp_less_0(nram_src, nram_src, step);
        // Adds the full padded tile; the tail past `step` was zeroed above.
        __bang_add(nram_partial_sum, nram_partial_sum, nram_src, max_num);
    }
    // Reduce sum
    if (max_num >= w_size) {
        for (int strip = seg_num / 2; strip > 0; strip = strip / 2) {
            for (int i = 0; i < strip; ++i) {
                __bang_add(nram_partial_sum + i * w_size, nram_partial_sum + i * w_size,
                           nram_partial_sum + (i + strip) * w_size, w_size);
            }
        }
        for (int i = 0; i < w_size; ++i) {
            if constexpr (std::is_same<Tval, bfloat16_t>::value) {
                nram_sum_final[0] = to_bfloat16(
                    to_float(nram_sum_final[0]) + to_float(nram_partial_sum[i]));
            } else {
                nram_sum_final[0] += nram_partial_sum[i];
            }
        }
    } else {
        for (int i = 0; i < max_num; ++i) {
            if constexpr (std::is_same<Tval, bfloat16_t>::value) {
                nram_sum_final[0] = to_bfloat16(
                    to_float(nram_sum_final[0]) + to_float(nram_partial_sum[i]));
            } else {
                nram_sum_final[0] += nram_partial_sum[i];
            }
        }
    }
    // Every task writes 0 before the barrier, so the reset is race-free.
    if constexpr (std::is_same<Tval, bfloat16_t>::value) {
        global_sum[0] = to_bfloat16(0.0f);
    } else {
        global_sum[0] = 0.0;
    }
    __sync_all();
    __bang_atomic_add(nram_sum_final, global_sum, nram_sum_final, 1);
    // Task 0 performs the final sampling
    // NOTE(review): no barrier between the atomic adds above and task 0's
    // read of global_sum[0] below — confirm ordering is guaranteed.
    if (taskId == 0) {
        Tval global_sum_inv;
        if constexpr (std::is_same<Tval, bfloat16_t>::value) {
            global_sum_inv = to_bfloat16(1.0f / to_float(global_sum[0]));
        } else {
            global_sum_inv = static_cast<Tval>(1.0 / global_sum[0]);
        }
        __memcpy(nram_global_topk, global_topk, topk * sizeof(Tval), GDRAM2NRAM);
        // Softmax for topk elements
        __bang_sub_scalar(nram_global_topk, nram_global_topk, global_max, topk);
        __bang_mul_scalar(nram_global_topk, nram_global_topk, temp_inv, topk);
        __bang_active_exp_less_0(nram_global_topk, nram_global_topk, topk);
        __bang_mul_scalar(nram_global_topk, nram_global_topk, global_sum_inv, topk);
        // Compute cumulative sum for sampling
        __bang_write_value(nram_topk_buffer, 2 * topk, 0);
        nram_topk_buffer[0] = nram_global_topk[0];
        for (int i = 1; i < topk; ++i) {
            if constexpr (std::is_same<Tval, bfloat16_t>::value) {
                nram_topk_buffer[i] = to_bfloat16(
                    to_float(nram_topk_buffer[i - 1]) + to_float(nram_global_topk[i]));
            } else {
                nram_topk_buffer[i] = nram_topk_buffer[i - 1] + nram_global_topk[i];
            }
        }
        // Find the cutoff point for top-p sampling
        int end = 0;
        for (end = 0; end < topk; ++end) {
            if constexpr (std::is_same<Tval, bfloat16_t>::value) {
                if (to_float(nram_topk_buffer[end]) >= topp) {
                    break;
                }
            } else {
                if (nram_topk_buffer[end] >= static_cast<Tval>(topp)) {
                    break;
                }
            }
        }
        end = (end < topk - 1) ? end + 1 : topk;
        // Perform the sampling
        if constexpr (std::is_same<Tval, bfloat16_t>::value) {
            random_val *= to_float(nram_topk_buffer[end - 1]);
        } else {
            random_val *= nram_topk_buffer[end - 1];
        }
        // NOTE(review): if the loop below never fires (random_val equals the
        // final cumsum exactly), result[0] is left unwritten — verify the
        // strict '<' cannot exclude the last candidate.
        for (int i = 0; i < end; ++i) {
            if (random_val < to_float(nram_topk_buffer[i])) {
                result[0] = gdram_indices[i];
                break;
            }
        }
        __memcpy(global_topk, nram_global_topk, topk * sizeof(Tval), NRAM2GDRAM);
    }
}
/// Host-side launch policy for the bang random-sample operator.
///
/// Maps the framework value types (float / CustomFloat16 / CustomBFloat16)
/// onto the device types (float / half / bfloat16_t) and launches either the
/// plain argmax kernel or one of the two sampling kernels.
struct Algo {
    /// Launches the pure argmax kernel (greedy path).
    ///
    /// @param workspace  device scratch; holds one partial index per task.
    /// @param result_    output buffer for the winning index (Tidx).
    /// @param probs      input logits, reinterpreted per Tval_.
    /// @param voc        vocabulary size (element count of probs).
    /// @param stream_    CNRT queue.
    template <class Tidx, class Tval_>
    infiniStatus_t argmax(
        void *workspace, size_t workspace_size,
        void *result_, const void *probs, size_t voc,
        void *stream_) const {
        cnrtDim3_t dim = {4, 1, 1};
        // The kernel stores one candidate index per task in the workspace.
        if (workspace_size < dim.x * dim.y * dim.z * sizeof(Tidx)) {
            return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
        }
        auto queue = reinterpret_cast<cnrtQueue_t>(stream_);
        auto result = reinterpret_cast<Tidx *>(result_);
        auto gdram_indices = reinterpret_cast<Tidx *>(workspace);
        if constexpr (std::is_same<Tval_, float>::value) {
            argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(
                reinterpret_cast<const float *>(probs), result, gdram_indices, static_cast<int>(voc));
        } else if constexpr (std::is_same<Tval_, CustomFloat16>::value) {
            argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(
                reinterpret_cast<const half *>(probs), result, gdram_indices, static_cast<int>(voc));
        } else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) {
            argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(
                reinterpret_cast<const bfloat16_t *>(probs), result, gdram_indices, static_cast<int>(voc));
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        cnrtQueueSync(queue);
        return INFINI_STATUS_SUCCESS;
    }

    /// Launches top-k / top-p sampling; see launchRandom for the mechanics.
    template <class Tidx, class Tval_>
    infiniStatus_t random(
        void *workspace, size_t workspace_size,
        void *result_, const void *probs, size_t voc,
        float random_val, float topp, int topk, float temperature,
        void *stream_) const {
        cnrtDim3_t dim = {4, 1, 1};
        auto queue = reinterpret_cast<cnrtQueue_t>(stream_);
        auto result = reinterpret_cast<Tidx *>(result_);
        infiniStatus_t status;
        if constexpr (std::is_same<Tval_, float>::value) {
            status = launchRandom<Tidx, float>(
                workspace, workspace_size, result, probs, voc,
                random_val, topp, topk, temperature, dim, queue);
        } else if constexpr (std::is_same<Tval_, CustomFloat16>::value) {
            status = launchRandom<Tidx, half>(
                workspace, workspace_size, result, probs, voc,
                random_val, topp, topk, temperature, dim, queue);
        } else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) {
            status = launchRandom<Tidx, bfloat16_t>(
                workspace, workspace_size, result, probs, voc,
                random_val, topp, topk, temperature, dim, queue);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (status != INFINI_STATUS_SUCCESS) {
            return status;
        }
        cnrtQueueSync(queue);
        return INFINI_STATUS_SUCCESS;
    }

private:
    /// Carves the workspace into [Tidx x voc | Tval x voc | Tval(global sum)]
    /// — matching calculateWorkspace<Tidx, Tval> — and dispatches the small-
    /// or large-vocabulary sampling kernel.
    ///
    /// Fix over the original: global_top_k was computed as
    /// `(Tval*)gdram_indices + task_num * topk * sizeof(Tidx)`, which mixes
    /// element and byte arithmetic; the value region now starts exactly
    /// sizeof(Tidx) * voc bytes past the index region.
    template <class Tidx, class Tval>
    infiniStatus_t launchRandom(
        void *workspace, size_t workspace_size,
        Tidx *result, const void *probs, size_t voc,
        float random_val, float topp, int topk, float temperature,
        cnrtDim3_t dim, cnrtQueue_t queue) const {
        const size_t required = voc * (sizeof(Tidx) + sizeof(Tval)) + sizeof(Tval);
        if (required > workspace_size) {
            return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
        }
        const int task_num = dim.x * dim.y * dim.z;
        auto logits = reinterpret_cast<const Tval *>(probs);
        auto gdram_indices = reinterpret_cast<Tidx *>(workspace);
        auto global_top_k = reinterpret_cast<Tval *>(
            reinterpret_cast<char *>(workspace) + sizeof(Tidx) * voc);
        Tval *global_sum = global_top_k + task_num * topk;
        const size_t max_num = SRC_MAX_SIZE / sizeof(Tval);
        if (voc >= static_cast<size_t>(task_num) * max_num) {
            randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>(
                logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
        } else {
            randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>(
                logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
        }
        return INFINI_STATUS_SUCCESS;
    }
};
} // namespace op::random_sample::bang
......@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/random_sample_nvidia.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/random_sample_bang.h"
#endif
#ifdef ENABLE_METAX_API
#include "metax/random_sample_metax.h"
#endif
......@@ -41,6 +44,9 @@ infiniopCreateRandomSampleDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
......@@ -77,6 +83,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
......@@ -123,6 +132,9 @@ __C infiniStatus_t infiniopRandomSample(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
......@@ -156,6 +168,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
......
......@@ -56,7 +56,14 @@ def random_sample(data, random_val, topp, topk, voc, temperature):
sorted_vals, sorted_indices = torch.sort(data, descending=True)
scaled_vals = (sorted_vals - sorted_vals[0]) / temperature
probs = torch.softmax(scaled_vals, dim=0)
try:
probs = torch.softmax(scaled_vals, dim=0)
except RuntimeError as e:
if "not implemented for 'Half'" in str(e):
scaled_vals = scaled_vals.to(torch.float32)
probs = torch.softmax(scaled_vals, dim=0)
else:
raise
cum_probs = torch.cumsum(probs, dim=0)
k_index = min(topk, voc) - 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment