Unverified commit a8ea2306, authored by PanZezhong1725, committed by GitHub

Merge pull request #1064 from InfiniTensor/issue/1031_T1-1-4

[Operator Competition 2025 Autumn] T1-1-4
parents 7f295448 210e31d3
#ifndef __TOPK_INFO_H__
#define __TOPK_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <algorithm>
#include <cstddef>
#include <vector>
namespace op::topk {
class TopKInfo {
TopKInfo() = default;
public:
infiniDtype_t dtype;
std::vector<size_t> input_shape;
std::vector<size_t> output_shape;
std::vector<ptrdiff_t> input_strides;
std::vector<ptrdiff_t> output_strides;
size_t k;
size_t dim;
bool largest;
bool sorted;
size_t ndim;
size_t dim_elements; // number of elements along the top-k dimension
size_t n_iteration; // number of independent rows: product of all non-dim extents
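// Illustrative example (not from the original source): for an input of shape
// (2, 3, 4) with dim = 1, dim_elements = 3 and n_iteration = 2 * 4 = 8, so
// top-k runs independently over 8 rows of 3 elements each.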
static utils::Result<TopKInfo> create(
infiniopTensorDescriptor_t values_output_desc,
infiniopTensorDescriptor_t indices_output_desc,
infiniopTensorDescriptor_t input_desc,
size_t k,
size_t dim,
bool largest,
bool sorted) {
auto input_shape = input_desc->shape();
auto input_strides = input_desc->strides();
size_t input_ndim = input_desc->ndim();
size_t dim_elements = input_shape[dim];
size_t n_iteration = 1;
for (size_t i = 0; i < input_ndim; i++) {
if (i != dim) {
n_iteration *= input_shape[i];
}
}
return utils::Result<TopKInfo>(TopKInfo{input_desc->dtype(),
input_desc->shape(),
values_output_desc->shape(),
input_desc->strides(),
values_output_desc->strides(),
k,
dim,
largest,
sorted,
input_ndim,
dim_elements,
n_iteration});
}
};
} // namespace op::topk
#endif
#ifndef __TOPK_METAX_H__
#define __TOPK_METAX_H__
#include "../topk_desc.h"
DESCRIPTOR(metax);
#endif
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topk_metax.h"
#include <cub/block/block_radix_sort.cuh>
#include <cub/cub.cuh>
namespace op::topk::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t values_output_desc,
infiniopTensorDescriptor_t indices_output_desc,
infiniopTensorDescriptor_t input_desc,
size_t k,
size_t dim,
bool largest,
bool sorted) {
auto result = TopKInfo::create(values_output_desc, indices_output_desc, input_desc, k, dim, largest, sorted);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = 0;
workspace_size += (input_desc->ndim() + values_output_desc->ndim()) * (sizeof(size_t) + sizeof(ptrdiff_t));
size_t dim_elements = input_desc->shape()[dim];
size_t n_iteration = 1;
for (size_t i = 0; i < input_desc->ndim(); i++) {
if (i != dim) {
n_iteration *= input_desc->shape()[i];
}
}
size_t total = n_iteration * dim_elements;
workspace_size += 3 * total * sizeof(uint32_t);
workspace_size += 3 * total * sizeof(int32_t);
workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t));
if (sorted) {
workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t));
}
workspace_size += 5 * n_iteration * sizeof(int32_t);
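// Workspace layout, in allocation order (mirrors the pointer carving in
// launchKernel below):
//   shape/stride device copies: (input_ndim + output_ndim) * (sizeof(size_t) + sizeof(ptrdiff_t))
//   cur/ones/zeros value buffers: 3 * total * sizeof(uint32_t)
//   cur/ones/zeros index buffers: 3 * total * sizeof(int32_t)
//   selected (value, index) pairs: n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t))
//   sorted pairs (only when sorted): the same amount again
//   per-row counters (cur_n, rem_k, out_pos, ones_count, zeros_count): 5 * n_iteration * sizeof(int32_t)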
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info, workspace_size, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
namespace {
template <size_t BLOCK_SIZE, int32_t SORT_ITEMS_PER_THREAD, typename Tdata>
infiniStatus_t launchKernel(
const TopKInfo &info,
Tdata *values_output, int32_t *indices_output, const Tdata *input,
size_t k, size_t dim, bool largest, bool sorted,
hcStream_t stream, void *workspace, size_t workspace_size) {
if (dim >= info.ndim) {
return INFINI_STATUS_BAD_PARAM;
}
if (k == 0) {
return INFINI_STATUS_SUCCESS;
}
if (k > info.dim_elements) {
return INFINI_STATUS_BAD_PARAM;
}
size_t input_ndim = info.ndim;
size_t output_ndim = input_ndim;
size_t n_iteration = info.n_iteration;
size_t dim_elements = info.dim_elements;
unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
size_t workspace_offset = 0;
size_t *input_shape_hc = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
size_t *output_shape_hc = input_shape_hc + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(size_t);
ptrdiff_t *input_strides_hc = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
ptrdiff_t *output_strides_hc = input_strides_hc + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(ptrdiff_t);
CHECK_METAX(hcMemcpyAsync(input_shape_hc, info.input_shape.data(), input_ndim * sizeof(size_t), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(output_shape_hc, info.output_shape.data(), output_ndim * sizeof(size_t), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(input_strides_hc, info.input_strides.data(), input_ndim * sizeof(ptrdiff_t), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(output_strides_hc, info.output_strides.data(), output_ndim * sizeof(ptrdiff_t), hcMemcpyHostToDevice, stream));
const size_t total = n_iteration * dim_elements; // size_t, not int32_t: the product can exceed INT32_MAX for large tensors
uint32_t *cur_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(uint32_t);
uint32_t *ones_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(uint32_t);
uint32_t *zeros_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(uint32_t);
int32_t *cur_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(int32_t);
int32_t *ones_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(int32_t);
int32_t *zeros_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(int32_t);
uint32_t *sel_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(uint32_t);
int32_t *sel_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(int32_t);
uint32_t *sel_sorted_vals = nullptr;
int32_t *sel_sorted_idx = nullptr;
if (sorted) {
sel_sorted_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(uint32_t);
sel_sorted_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(int32_t);
}
int32_t *cur_n = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *rem_k = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *out_pos = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *ones_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *zeros_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
// init
{
size_t threads = 256;
size_t blocks = (n_iteration + threads - 1) / threads;
op::topk::cuda::init_row_state<<<blocks, threads, 0, stream>>>(cur_n, rem_k, out_pos, n_iteration, dim_elements, k);
}
// gather input -> cur
{
dim3 block(BLOCK_SIZE);
dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
op::topk::cuda::gather_rowwise<Tdata><<<grid, block, 0, stream>>>(
input, cur_vals, cur_idx,
n_iteration, dim_elements,
input_ndim, dim,
input_shape_hc, input_strides_hc);
}
// radix select/filter
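// MSB-first radix select, one bit per pass. Assuming gather_rowwise stores
// values through the usual order-preserving float-to-uint32 key transform
// (sign bit flipped for non-negative values, all bits flipped for negative
// ones -- an assumption; the actual transform lives in cuda/kernel.cuh),
// each pass partitions every row's remaining candidates by the current bit,
// and decide_and_compact then either narrows the search to the preferred
// bucket or emits a whole bucket into sel_vals/sel_idx and continues in the
// other one, driven by the per-row rem_k counter.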
for (int bit = 31; bit >= 0; --bit) {
{
size_t threads = 256;
size_t blocks = (n_iteration + threads - 1) / threads;
op::topk::cuda::zero_row_counters<<<blocks, threads, 0, stream>>>(ones_count, zeros_count, n_iteration);
}
{
dim3 block(BLOCK_SIZE);
dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
op::topk::cuda::partition_rowwise<BLOCK_SIZE><<<grid, block, 0, stream>>>(
cur_vals, cur_idx,
ones_vals, ones_idx,
zeros_vals, zeros_idx,
cur_n, n_iteration, dim_elements,
bit, largest,
ones_count, zeros_count);
}
{
op::topk::cuda::decide_and_compact<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
cur_vals, cur_idx,
ones_vals, ones_idx,
zeros_vals, zeros_idx,
ones_count, zeros_count,
cur_n, rem_k, out_pos,
sel_vals, sel_idx,
n_iteration, dim_elements, k);
}
}
// append remaining
op::topk::cuda::take_remaining<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
cur_vals, cur_idx,
cur_n, rem_k, out_pos,
sel_vals, sel_idx,
n_iteration, dim_elements, k);
// sort selected pairs per row (CUB device segmented radix sort)
const int32_t *final_idx = sel_idx;
if (sorted) {
std::vector<int> h_offsets(n_iteration + 1);
for (size_t i = 0; i <= n_iteration; i++) {
h_offsets[i] = i * k;
}
int *d_offsets;
CHECK_METAX(hcMalloc(&d_offsets, (n_iteration + 1) * sizeof(int)));
CHECK_METAX(hcMemcpy(d_offsets, h_offsets.data(), (n_iteration + 1) * sizeof(int), hcMemcpyHostToDevice));
void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
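// Standard CUB two-phase pattern: the first call with d_temp_storage ==
// nullptr only computes temp_storage_bytes; the second call, after
// allocation, performs the actual segmented sort.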
if (!largest) {
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
CHECK_METAX(hcMalloc(&d_temp_storage, temp_storage_bytes));
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
} else {
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
CHECK_METAX(hcMalloc(&d_temp_storage, temp_storage_bytes));
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
}
CHECK_METAX(hcFree(d_offsets));
CHECK_METAX(hcFree(d_temp_storage));
final_idx = sel_sorted_idx;
}
// scatter to output (strided write)
{
dim3 block(BLOCK_SIZE);
dim3 grid((k + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
op::topk::cuda::scatter_to_output<Tdata><<<grid, block, 0, stream>>>(
input, final_idx,
values_output, indices_output,
n_iteration, k,
input_ndim, dim,
input_shape_hc, input_strides_hc,
output_shape_hc, output_strides_hc);
}
return INFINI_STATUS_SUCCESS;
}
} // namespace
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *values_output,
void *indices_output,
const void *input,
size_t k,
size_t dim,
bool largest,
bool sorted,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
constexpr int ITEMS = 4;
#define CALCULATE_TOPK(BLOCK_SIZE, Tdata) \
launchKernel<BLOCK_SIZE, ITEMS, Tdata>( \
_info, \
(Tdata *)values_output, (int32_t *)indices_output, (const Tdata *)input, \
k, dim, largest, sorted, \
stream, workspace, workspace_size)
#define CALCULATE_TOPK_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_TOPK(BLOCK_SIZE, __hpcc_bfloat16); \
else if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_TOPK(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_TOPK(BLOCK_SIZE, float); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() >= 256) {
CALCULATE_TOPK_WITH_BLOCK_SIZE(256)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::topk::metax
#ifndef __TOPK_MOORE_H__
#define __TOPK_MOORE_H__
#include "../topk_desc.h"
DESCRIPTOR(moore);
#endif
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topk_moore.h"
#include <cub/block/block_radix_sort.cuh>
#include <cub/cub.cuh>
namespace op::topk::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t values_output_desc,
infiniopTensorDescriptor_t indices_output_desc,
infiniopTensorDescriptor_t input_desc,
size_t k,
size_t dim,
bool largest,
bool sorted) {
auto result = TopKInfo::create(values_output_desc, indices_output_desc, input_desc, k, dim, largest, sorted);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = 0;
workspace_size += (input_desc->ndim() + values_output_desc->ndim()) * (sizeof(size_t) + sizeof(ptrdiff_t));
size_t dim_elements = input_desc->shape()[dim];
size_t n_iteration = 1;
for (size_t i = 0; i < input_desc->ndim(); i++) {
if (i != dim) {
n_iteration *= input_desc->shape()[i];
}
}
size_t total = n_iteration * dim_elements;
workspace_size += 3 * total * sizeof(uint32_t);
workspace_size += 3 * total * sizeof(int32_t);
workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t));
if (sorted) {
workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t));
}
workspace_size += 5 * n_iteration * sizeof(int32_t);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info, workspace_size, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
namespace {
template <size_t BLOCK_SIZE, int32_t SORT_ITEMS_PER_THREAD, typename Tdata>
infiniStatus_t launchKernel(
const TopKInfo &info,
Tdata *values_output, int32_t *indices_output, const Tdata *input,
size_t k, size_t dim, bool largest, bool sorted,
musaStream_t stream, void *workspace, size_t workspace_size) {
if (dim >= info.ndim) {
return INFINI_STATUS_BAD_PARAM;
}
if (k == 0) {
return INFINI_STATUS_SUCCESS;
}
if (k > info.dim_elements) {
return INFINI_STATUS_BAD_PARAM;
}
size_t input_ndim = info.ndim;
size_t output_ndim = input_ndim;
size_t n_iteration = info.n_iteration;
size_t dim_elements = info.dim_elements;
unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
size_t workspace_offset = 0;
size_t *input_shape_musa = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
size_t *output_shape_musa = input_shape_musa + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(size_t);
ptrdiff_t *input_strides_musa = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
ptrdiff_t *output_strides_musa = input_strides_musa + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(ptrdiff_t);
CHECK_MOORE(musaMemcpyAsync(input_shape_musa, info.input_shape.data(), input_ndim * sizeof(size_t), musaMemcpyHostToDevice, stream));
CHECK_MOORE(musaMemcpyAsync(output_shape_musa, info.output_shape.data(), output_ndim * sizeof(size_t), musaMemcpyHostToDevice, stream));
CHECK_MOORE(musaMemcpyAsync(input_strides_musa, info.input_strides.data(), input_ndim * sizeof(ptrdiff_t), musaMemcpyHostToDevice, stream));
CHECK_MOORE(musaMemcpyAsync(output_strides_musa, info.output_strides.data(), output_ndim * sizeof(ptrdiff_t), musaMemcpyHostToDevice, stream));
const size_t total = n_iteration * dim_elements; // size_t, not int32_t: the product can exceed INT32_MAX for large tensors
uint32_t *cur_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(uint32_t);
uint32_t *ones_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(uint32_t);
uint32_t *zeros_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(uint32_t);
int32_t *cur_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(int32_t);
int32_t *ones_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(int32_t);
int32_t *zeros_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(int32_t);
uint32_t *sel_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(uint32_t);
int32_t *sel_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(int32_t);
uint32_t *sel_sorted_vals = nullptr;
int32_t *sel_sorted_idx = nullptr;
if (sorted) {
sel_sorted_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(uint32_t);
sel_sorted_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(int32_t);
}
int32_t *cur_n = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *rem_k = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *out_pos = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *ones_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *zeros_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
// init
{
size_t threads = 256;
size_t blocks = (n_iteration + threads - 1) / threads;
op::topk::cuda::init_row_state<<<blocks, threads, 0, stream>>>(cur_n, rem_k, out_pos, n_iteration, dim_elements, k);
}
// gather input -> cur
{
dim3 block(BLOCK_SIZE);
dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
op::topk::cuda::gather_rowwise<Tdata><<<grid, block, 0, stream>>>(
input, cur_vals, cur_idx,
n_iteration, dim_elements,
input_ndim, dim,
input_shape_musa, input_strides_musa);
}
// radix select/filter
for (int bit = 31; bit >= 0; --bit) {
{
size_t threads = 256;
size_t blocks = (n_iteration + threads - 1) / threads;
op::topk::cuda::zero_row_counters<<<blocks, threads, 0, stream>>>(ones_count, zeros_count, n_iteration);
}
{
dim3 block(BLOCK_SIZE);
dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
op::topk::cuda::partition_rowwise<BLOCK_SIZE><<<grid, block, 0, stream>>>(
cur_vals, cur_idx,
ones_vals, ones_idx,
zeros_vals, zeros_idx,
cur_n, n_iteration, dim_elements,
bit, largest,
ones_count, zeros_count);
}
{
op::topk::cuda::decide_and_compact<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
cur_vals, cur_idx,
ones_vals, ones_idx,
zeros_vals, zeros_idx,
ones_count, zeros_count,
cur_n, rem_k, out_pos,
sel_vals, sel_idx,
n_iteration, dim_elements, k);
}
}
// append remaining
op::topk::cuda::take_remaining<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
cur_vals, cur_idx,
cur_n, rem_k, out_pos,
sel_vals, sel_idx,
n_iteration, dim_elements, k);
// sort selected pairs per row (CUB device segmented radix sort)
const int32_t *final_idx = sel_idx;
if (sorted) {
std::vector<int> h_offsets(n_iteration + 1);
for (size_t i = 0; i <= n_iteration; i++) {
h_offsets[i] = i * k;
}
int *d_offsets;
CHECK_MOORE(musaMalloc(&d_offsets, (n_iteration + 1) * sizeof(int)));
CHECK_MOORE(musaMemcpy(d_offsets, h_offsets.data(), (n_iteration + 1) * sizeof(int), musaMemcpyHostToDevice));
void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
if (!largest) {
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
CHECK_MOORE(musaMalloc(&d_temp_storage, temp_storage_bytes));
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
} else {
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
CHECK_MOORE(musaMalloc(&d_temp_storage, temp_storage_bytes));
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
}
CHECK_MOORE(musaFree(d_offsets));
CHECK_MOORE(musaFree(d_temp_storage));
final_idx = sel_sorted_idx;
}
// scatter to output (strided write)
{
dim3 block(BLOCK_SIZE);
dim3 grid((k + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
op::topk::cuda::scatter_to_output<Tdata><<<grid, block, 0, stream>>>(
input, final_idx,
values_output, indices_output,
n_iteration, k,
input_ndim, dim,
input_shape_musa, input_strides_musa,
output_shape_musa, output_strides_musa);
}
return INFINI_STATUS_SUCCESS;
}
} // namespace
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *values_output,
void *indices_output,
const void *input,
size_t k,
size_t dim,
bool largest,
bool sorted,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
constexpr int ITEMS = 4;
#define CALCULATE_TOPK(BLOCK_SIZE, Tdata) \
launchKernel<BLOCK_SIZE, ITEMS, Tdata>( \
_info, \
(Tdata *)values_output, (int32_t *)indices_output, (const Tdata *)input, \
k, dim, largest, sorted, \
stream, workspace, workspace_size)
#define CALCULATE_TOPK_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_TOPK(BLOCK_SIZE, __mt_bfloat16); \
else if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_TOPK(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_TOPK(BLOCK_SIZE, float); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() >= 256) {
CALCULATE_TOPK_WITH_BLOCK_SIZE(256)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::topk::moore
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "topk_nvidia.cuh"
#include <cub/block/block_radix_sort.cuh>
#include <cub/cub.cuh>
namespace op::topk::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t values_output_desc,
infiniopTensorDescriptor_t indices_output_desc,
infiniopTensorDescriptor_t input_desc,
size_t k,
size_t dim,
bool largest,
bool sorted) {
auto result = TopKInfo::create(values_output_desc, indices_output_desc, input_desc, k, dim, largest, sorted);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = 0;
workspace_size += (input_desc->ndim() + values_output_desc->ndim()) * (sizeof(size_t) + sizeof(ptrdiff_t));
// compute workspace for temporary buffers
size_t dim_elements = input_desc->shape()[dim];
size_t n_iteration = 1;
for (size_t i = 0; i < input_desc->ndim(); i++) {
if (i != dim) {
n_iteration *= input_desc->shape()[i];
}
}
size_t total = n_iteration * dim_elements;
workspace_size += 3 * total * sizeof(uint32_t);
workspace_size += 3 * total * sizeof(int32_t);
workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t));
if (sorted) {
workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t));
}
workspace_size += 5 * n_iteration * sizeof(int32_t);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info, workspace_size, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
namespace {
template <size_t BLOCK_SIZE, int32_t SORT_ITEMS_PER_THREAD, typename Tdata>
infiniStatus_t launchKernel(
const TopKInfo &info,
Tdata *values_output, int32_t *indices_output, const Tdata *input,
size_t k, size_t dim, bool largest, bool sorted,
cudaStream_t stream, void *workspace, size_t workspace_size) {
if (dim >= info.ndim) {
return INFINI_STATUS_BAD_PARAM;
}
if (k == 0) {
return INFINI_STATUS_SUCCESS;
}
if (k > info.dim_elements) {
return INFINI_STATUS_BAD_PARAM;
}
size_t input_ndim = info.ndim;
size_t output_ndim = input_ndim;
size_t n_iteration = info.n_iteration;
size_t dim_elements = info.dim_elements;
unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
size_t workspace_offset = 0;
size_t *input_shape_cuda = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
size_t *output_shape_cuda = input_shape_cuda + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(size_t);
ptrdiff_t *input_strides_cuda = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
ptrdiff_t *output_strides_cuda = input_strides_cuda + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(ptrdiff_t);
CHECK_CUDA(cudaMemcpyAsync(input_shape_cuda, info.input_shape.data(), input_ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync(output_shape_cuda, info.output_shape.data(), output_ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync(input_strides_cuda, info.input_strides.data(), input_ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync(output_strides_cuda, info.output_strides.data(), output_ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream));
const size_t total = n_iteration * dim_elements; // size_t, not int32_t: the product can exceed INT32_MAX for large tensors
uint32_t *cur_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(uint32_t);
uint32_t *ones_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(uint32_t);
uint32_t *zeros_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(uint32_t);
int32_t *cur_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(int32_t);
int32_t *ones_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(int32_t);
int32_t *zeros_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += total * sizeof(int32_t);
uint32_t *sel_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(uint32_t);
int32_t *sel_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(int32_t);
uint32_t *sel_sorted_vals = nullptr;
int32_t *sel_sorted_idx = nullptr;
if (sorted) {
sel_sorted_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(uint32_t);
sel_sorted_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * k * sizeof(int32_t);
}
int32_t *cur_n = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *rem_k = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *out_pos = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *ones_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
int32_t *zeros_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
workspace_offset += n_iteration * sizeof(int32_t);
// init
{
size_t threads = 256;
size_t blocks = (n_iteration + threads - 1) / threads;
op::topk::cuda::init_row_state<<<blocks, threads, 0, stream>>>(cur_n, rem_k, out_pos, n_iteration, dim_elements, k);
}
// gather input -> cur
{
dim3 block(BLOCK_SIZE);
dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
op::topk::cuda::gather_rowwise<Tdata><<<grid, block, 0, stream>>>(
input, cur_vals, cur_idx,
n_iteration, dim_elements,
input_ndim, dim,
input_shape_cuda, input_strides_cuda);
}
// radix select/filter
for (int bit = 31; bit >= 0; --bit) {
{
size_t threads = 256;
size_t blocks = (n_iteration + threads - 1) / threads;
op::topk::cuda::zero_row_counters<<<blocks, threads, 0, stream>>>(ones_count, zeros_count, n_iteration);
}
{
dim3 block(BLOCK_SIZE);
dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
op::topk::cuda::partition_rowwise<BLOCK_SIZE><<<grid, block, 0, stream>>>(
cur_vals, cur_idx,
ones_vals, ones_idx,
zeros_vals, zeros_idx,
cur_n, n_iteration, dim_elements,
bit, largest,
ones_count, zeros_count);
}
{
op::topk::cuda::decide_and_compact<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
cur_vals, cur_idx,
ones_vals, ones_idx,
zeros_vals, zeros_idx,
ones_count, zeros_count,
cur_n, rem_k, out_pos,
sel_vals, sel_idx,
n_iteration, dim_elements, k);
}
}
// append remaining
op::topk::cuda::take_remaining<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
cur_vals, cur_idx,
cur_n, rem_k, out_pos,
sel_vals, sel_idx,
n_iteration, dim_elements, k);
// sort selected pairs per row (CUB device segmented radix sort)
const int32_t *final_idx = sel_idx;
if (sorted) {
std::vector<int> h_offsets(n_iteration + 1);
for (size_t i = 0; i <= n_iteration; i++) {
h_offsets[i] = i * k;
}
int *d_offsets;
CHECK_CUDA(cudaMalloc(&d_offsets, (n_iteration + 1) * sizeof(int)));
CHECK_CUDA(cudaMemcpy(d_offsets, h_offsets.data(), (n_iteration + 1) * sizeof(int), cudaMemcpyHostToDevice));
void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
if (!largest) {
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
CHECK_CUDA(cudaMalloc(&d_temp_storage, temp_storage_bytes));
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
} else {
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
CHECK_CUDA(cudaMalloc(&d_temp_storage, temp_storage_bytes));
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
}
CHECK_CUDA(cudaFree(d_offsets));
CHECK_CUDA(cudaFree(d_temp_storage));
final_idx = sel_sorted_idx;
}
// scatter to output (strided write)
{
dim3 block(BLOCK_SIZE);
dim3 grid((k + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
op::topk::cuda::scatter_to_output<Tdata><<<grid, block, 0, stream>>>(
input, final_idx,
values_output, indices_output,
n_iteration, k,
input_ndim, dim,
input_shape_cuda, input_strides_cuda,
output_shape_cuda, output_strides_cuda);
}
CHECK_CUDA(cudaGetLastError());
return INFINI_STATUS_SUCCESS;
}
} // namespace
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *values_output,
void *indices_output,
const void *input,
size_t k,
size_t dim,
bool largest,
bool sorted,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
constexpr int ITEMS = 4;
#define CALCULATE_TOPK(BLOCK_SIZE, Tdata) \
launchKernel<BLOCK_SIZE, ITEMS, Tdata>( \
_info, \
(Tdata *)values_output, (int32_t *)indices_output, (const Tdata *)input, \
k, dim, largest, sorted, \
stream, workspace, workspace_size)
#define CALCULATE_TOPK_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_TOPK(BLOCK_SIZE, __nv_bfloat16); \
else if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_TOPK(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_TOPK(BLOCK_SIZE, float); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() >= 256) {
CALCULATE_TOPK_WITH_BLOCK_SIZE(256)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::topk::nvidia
#ifndef __TOPK_NVIDIA_H__
#define __TOPK_NVIDIA_H__
#include "../topk_desc.h"
DESCRIPTOR(nvidia);
#endif // __TOPK_NVIDIA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/topk.h"
#include <vector>
#ifdef ENABLE_CPU_API
#include "cpu/topk_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/topk_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/topk_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/topk_kunlun.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/topk_moore.h"
#endif
__INFINI_C infiniStatus_t infiniopCreateTopKDescriptor(
infiniopHandle_t handle,
infiniopTopKDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t values_output_desc,
infiniopTensorDescriptor_t indices_output_desc,
infiniopTensorDescriptor_t input_desc,
size_t k,
size_t dim,
bool largest,
bool sorted) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::topk::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::topk::NAMESPACE::Descriptor **>(desc_ptr), \
values_output_desc, \
indices_output_desc, \
input_desc, \
k, \
dim, \
largest, \
sorted)
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__INFINI_C infiniStatus_t infiniopGetTopKWorkspaceSize(infiniopTopKDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::topk::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__INFINI_C infiniStatus_t infiniopTopK(
infiniopTopKDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *values_output,
void *indices_output,
const void *input,
size_t k,
size_t dim,
bool largest,
bool sorted,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::topk::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, values_output, indices_output, input, k, dim, largest, sorted, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__INFINI_C infiniStatus_t
infiniopDestroyTopKDescriptor(infiniopTopKDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::topk::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
#ifndef INFINIOP_TOPK_DESCRIPTOR_H_
#define INFINIOP_TOPK_DESCRIPTOR_H_
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::topk::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
TopKInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
TopKInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t values_output_desc, \
infiniopTensorDescriptor_t indices_output_desc, \
infiniopTensorDescriptor_t input_desc, \
size_t k, \
size_t dim, \
bool largest, \
bool sorted); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *values_output, \
void *indices_output, \
const void *input, \
size_t k, \
size_t dim, \
bool largest, \
bool sorted, \
void *stream) const; \
}; \
}
#endif
#include "var_cpu.h"
#include "../../../../utils.h"
#include "../../../devices/cpu/common_cpu.h"
namespace op::var::cpu {
Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t var_output_desc,
infiniopTensorDescriptor_t input_desc,
size_t *dim,
size_t dim_size,
bool unbiased,
bool keepdim) {
auto result = VarInfo::create(var_output_desc, input_desc, dim, dim_size, unbiased, keepdim);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// Welford's online variance algorithm
namespace {
// Output is NaN when the reduction is empty, or when a single element is
// reduced with Bessel's correction (division by n - 1 = 0).
bool IsNanOut(const VarInfo &info) {
return (info.reduce_num == 0) || (info.reduce_num == 1 && info.unbiased_var);
}
// accumulate directly in float
template <typename Tdata>
void computeVarUsingWelfordCpu(const Tdata *input_ptr, float &var_output, size_t start, size_t end, const VarInfo &info) {
if (start >= end) {
return;
}
float old_mean = 0.0f; // mean before the current sample
float mean = 0.0f; // running mean
float M2 = 0.0f; // running sum of squared deviations from the mean
size_t count = 0; // number of samples accumulated so far
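// Welford update, for reference:
//   mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n
//   M2_n = M2_{n-1} + (x_n - mean_{n-1}) * (x_n - mean_n)
// After n samples, M2 equals the sum of squared deviations, so the variance
// is M2 / n (population) or M2 / (n - 1) (unbiased).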
for (size_t idx = start; idx < end; ++idx) {
size_t input_offset = op::common_cpu::indexToOffset(idx, info.permuted_input_shape.size(), info.permuted_input_shape.data(), info.permuted_input_strides.data());
float value = utils::cast<float>(input_ptr[input_offset]);
count++;
old_mean = mean;
mean += (value - mean) / count;
M2 += (value - old_mean) * (value - mean);
}
var_output = M2 / (info.unbiased_var ? (count - 1) : count);
}
template <typename Tdata>
infiniStatus_t calculateVar(
const VarInfo &info,
Tdata *var_output,
const Tdata *input) {
Tdata nan_value = utils::cast<Tdata>(NAN);
bool is_scalar = (info.reduce_dim_size == info.permuted_input_shape.size());
for (size_t i = 0; i < info.output_size; ++i) {
size_t output_offset = op::common_cpu::indexToOffset(i, info.output_shape.size(), info.output_shape.data(), info.output_strides.data());
if (IsNanOut(info)) {
var_output[output_offset] = nan_value;
} else {
size_t start = is_scalar ? 0 : i * info.reduce_num;
size_t end = is_scalar ? info.input_size : (i + 1) * info.reduce_num;
float var = 0.0f;
computeVarUsingWelfordCpu(input, var, start, end, info);
var_output[output_offset] = utils::cast<Tdata>(var);
}
}
return INFINI_STATUS_SUCCESS;
}
} // namespace
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *var_output,
const void *input,
bool unbiased,
bool keepdim,
void *stream) const {
switch (_info.dtype) {
case INFINI_DTYPE_F16:
return calculateVar<fp16_t>(_info, (fp16_t *)var_output, reinterpret_cast<const fp16_t *>(input));
case INFINI_DTYPE_F32:
return calculateVar<float>(_info, (float *)var_output, reinterpret_cast<const float *>(input));
case INFINI_DTYPE_BF16:
return calculateVar<bf16_t>(_info, (bf16_t *)var_output, reinterpret_cast<const bf16_t *>(input));
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::var::cpu
#ifndef __INFINIOP_VAR_CPU_H__
#define __INFINIOP_VAR_CPU_H__
#include "../var_desc.h"
DESCRIPTOR(cpu);
#endif // __INFINIOP_VAR_CPU_H__
#ifndef __VAR_INFO_H__
#define __VAR_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <algorithm>
#include <cstddef>
#include <vector>
namespace op::var {
class VarInfo {
VarInfo() = default;
public:
infiniDtype_t dtype;
std::vector<size_t> permuted_input_shape; // input shape with reduced dims moved to the end
std::vector<size_t> output_shape;
std::vector<ptrdiff_t> permuted_input_strides; // input strides, permuted the same way
std::vector<ptrdiff_t> output_strides;
size_t reduce_dim_size; // number of reduced dimensions
size_t reduce_num; // number of elements to reduce for each output element
size_t input_size; // total number of input elements
size_t output_size; // total number of output elements
bool unbiased_var;
static utils::Result<VarInfo> create(
infiniopTensorDescriptor_t var_output_desc,
infiniopTensorDescriptor_t input_desc,
size_t *dim,
size_t dim_size,
bool unbiased,
bool keepdim) {
auto input_shape = input_desc->shape();
auto input_strides = input_desc->strides();
size_t input_ndim = input_desc->ndim();
size_t reduce_num = 1;
for (size_t i = 0; i < dim_size; i++) {
reduce_num *= input_shape[dim[i]];
}
std::vector<size_t> permute_order;
for (size_t i = 0; i < input_ndim; i++) {
if (std::find(dim, dim + dim_size, i) == dim + dim_size) {
permute_order.push_back(i);
}
}
for (size_t i = 0; i < dim_size; i++) {
permute_order.push_back(dim[i]);
}
std::vector<size_t> permuted_input_shape;
std::vector<ptrdiff_t> permuted_input_strides;
for (size_t i = 0; i < permute_order.size(); i++) {
permuted_input_shape.push_back(input_shape[permute_order[i]]);
permuted_input_strides.push_back(input_strides[permute_order[i]]);
}
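// Illustrative example: for input shape (2, 3, 4) and dim = {1},
// permute_order = {0, 2, 1}, so permuted_input_shape = (2, 4, 3) and
// reduce_num = 3; every contiguous run of reduce_num logical indices in the
// permuted layout then maps to one output element.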
return utils::Result<VarInfo>(VarInfo{input_desc->dtype(),
permuted_input_shape,
var_output_desc->shape(),
permuted_input_strides,
var_output_desc->strides(),
dim_size,
reduce_num,
input_desc->numel(),
var_output_desc->numel(),
unbiased});
}
};
} // namespace op::var
#endif
#ifndef __VAR_METAX_H__
#define __VAR_METAX_H__
#include "../var_desc.h"
DESCRIPTOR(metax);
#endif // __VAR_METAX_H__
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "var_metax.h"
namespace op::var::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t var_output_desc,
infiniopTensorDescriptor_t input_desc,
size_t *dim,
size_t dim_size,
bool unbiased,
bool keepdim) {
auto result = VarInfo::create(var_output_desc, input_desc, dim, dim_size, unbiased, keepdim);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = 0;
workspace_size += input_desc->ndim() * (sizeof(size_t) + sizeof(ptrdiff_t)); // permuted_input_shape + permuted_input_strides
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info, workspace_size, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
namespace {
bool IsNanOut(const VarInfo &info) {
return (info.reduce_num == 0) || (info.reduce_num == 1 && info.unbiased_var);
}
template <size_t BLOCK_SIZE, typename Tdata, typename ComputeType>
infiniStatus_t launchKernel(
const VarInfo &info,
Tdata *var_output, const Tdata *input,
bool unbiased, bool keepdim,
hcStream_t stream, void *workspace, size_t workspace_size) {
size_t input_ndim = info.permuted_input_shape.size();
size_t output_ndim = info.output_shape.size();
size_t input_size = info.input_size;
size_t output_size = info.output_size;
size_t reduce_num = info.reduce_num;
unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
size_t workspace_offset = 0;
size_t *permuted_input_shape_hc = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
workspace_offset += input_ndim * sizeof(size_t);
ptrdiff_t *permuted_input_strides_hc = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
workspace_offset += input_ndim * sizeof(ptrdiff_t);
CHECK_METAX(hcMemcpyAsync(permuted_input_shape_hc, info.permuted_input_shape.data(), input_ndim * sizeof(size_t), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(permuted_input_strides_hc, info.permuted_input_strides.data(), input_ndim * sizeof(ptrdiff_t), hcMemcpyHostToDevice, stream));
bool is_nan = IsNanOut(info);
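// Two paths follow: a full reduction to a scalar uses a grid-wide kernel
// with a small per-block partial buffer (presumably count/mean/M2 triples,
// hence grid_size * 3 elements; see cuda/kernel.cuh), while partial
// reductions run one Welford pass per output element.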
if (info.reduce_num == input_size) { // scalar output
ComputeType *tmp_buffer;
constexpr size_t MAX_GRID_SIZE = 128;
size_t grid_size = std::min(MAX_GRID_SIZE,
(input_size + BLOCK_SIZE - 1) / BLOCK_SIZE);
grid_size = std::max<size_t>(1, grid_size); // explicit size_t keeps this portable where unsigned long != size_t
CHECK_METAX(hcMalloc(&tmp_buffer, grid_size * 3 * sizeof(ComputeType)));
ComputeVarScalarOut<Tdata, ComputeType><<<grid_size, BLOCK_SIZE, 0, stream>>>(
input, var_output, tmp_buffer, input_size, input_ndim,
permuted_input_shape_hc, permuted_input_strides_hc, unbiased, is_nan);
CHECK_METAX(hcFree(tmp_buffer));
} else {
size_t grid_size = std::min<size_t>(256, (info.output_size + BLOCK_SIZE - 1) / BLOCK_SIZE);
grid_size = std::max<size_t>(1, grid_size);
ComputeVarUsingWelfordWrapper<Tdata, ComputeType><<<grid_size, BLOCK_SIZE, 0, stream>>>(
input, var_output, input_ndim, output_size, reduce_num,
permuted_input_shape_hc, permuted_input_strides_hc, unbiased, is_nan);
}
return INFINI_STATUS_SUCCESS;
}
} // namespace
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *var_output,
const void *input,
bool unbiased,
bool keepdim,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
#define CALCULATE_VAR(BLOCK_SIZE, Tdata, ComputeType) \
launchKernel<BLOCK_SIZE, Tdata, ComputeType>( \
_info, \
(Tdata *)var_output, (const Tdata *)input, \
unbiased, keepdim, \
stream, workspace, workspace_size)
#define CALCULATE_VAR_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_VAR(BLOCK_SIZE, __hpcc_bfloat16, double); \
else if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_VAR(BLOCK_SIZE, half, double); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_VAR(BLOCK_SIZE, float, double); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() >= 256) {
CALCULATE_VAR_WITH_BLOCK_SIZE(256)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::var::metax
#ifndef __VAR_MOORE_H__
#define __VAR_MOORE_H__
#include "../var_desc.h"
DESCRIPTOR(moore);
#endif // __VAR_MOORE_H__
#ifndef __VAR_NVIDIA_H__
#define __VAR_NVIDIA_H__
#include "../var_desc.h"
DESCRIPTOR(nvidia);
#endif // __VAR_NVIDIA_H__