Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
93191613
Unverified
Commit
93191613
authored
Mar 13, 2026
by
thatPepe
Committed by
GitHub
Mar 13, 2026
Browse files
Merge pull request #1075 from InfiniTensor/RevertT_1-1-4
Revert T1-1-4
parents
6ab911c3
def22a08
Changes
203
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
2217 deletions
+0
-2217
src/infiniop/ops/sum/info.h
src/infiniop/ops/sum/info.h
+0
-64
src/infiniop/ops/sum/metax/sum_metax.h
src/infiniop/ops/sum/metax/sum_metax.h
+0
-8
src/infiniop/ops/sum/metax/sum_metax.maca
src/infiniop/ops/sum/metax/sum_metax.maca
+0
-116
src/infiniop/ops/sum/moore/sum_moore.h
src/infiniop/ops/sum/moore/sum_moore.h
+0
-8
src/infiniop/ops/sum/moore/sum_moore.mu
src/infiniop/ops/sum/moore/sum_moore.mu
+0
-133
src/infiniop/ops/sum/nvidia/sum_nvidia.cu
src/infiniop/ops/sum/nvidia/sum_nvidia.cu
+0
-118
src/infiniop/ops/sum/nvidia/sum_nvidia.cuh
src/infiniop/ops/sum/nvidia/sum_nvidia.cuh
+0
-8
src/infiniop/ops/sum/operator.cc
src/infiniop/ops/sum/operator.cc
+0
-194
src/infiniop/ops/sum/sum_desc.h
src/infiniop/ops/sum/sum_desc.h
+0
-50
src/infiniop/ops/topk/cpu/topk_cpu.cc
src/infiniop/ops/topk/cpu/topk_cpu.cc
+0
-130
src/infiniop/ops/topk/cpu/topk_cpu.h
src/infiniop/ops/topk/cpu/topk_cpu.h
+0
-8
src/infiniop/ops/topk/cuda/kernel.cuh
src/infiniop/ops/topk/cuda/kernel.cuh
+0
-253
src/infiniop/ops/topk/info.h
src/infiniop/ops/topk/info.h
+0
-60
src/infiniop/ops/topk/metax/topk_metax.h
src/infiniop/ops/topk/metax/topk_metax.h
+0
-8
src/infiniop/ops/topk/metax/topk_metax.maca
src/infiniop/ops/topk/metax/topk_metax.maca
+0
-280
src/infiniop/ops/topk/moore/topk_moore.h
src/infiniop/ops/topk/moore/topk_moore.h
+0
-8
src/infiniop/ops/topk/moore/topk_moore.mu
src/infiniop/ops/topk/moore/topk_moore.mu
+0
-280
src/infiniop/ops/topk/nvidia/topk_nvidia.cu
src/infiniop/ops/topk/nvidia/topk_nvidia.cu
+0
-283
src/infiniop/ops/topk/nvidia/topk_nvidia.cuh
src/infiniop/ops/topk/nvidia/topk_nvidia.cuh
+0
-8
src/infiniop/ops/topk/operator.cc
src/infiniop/ops/topk/operator.cc
+0
-200
No files found.
src/infiniop/ops/sum/info.h
deleted
100644 → 0
View file @
6ab911c3
#ifndef __SUM_INFO_H__
#define __SUM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <algorithm>
#include <cstddef>
#include <vector>
namespace op::sum {

/// Metadata precomputed once per descriptor for a sum reduction.
/// The reduced axes of the input are permuted to the end so kernels can treat
/// the slowest-varying trailing dimensions as the reduction dimensions.
class SumInfo {
    SumInfo() = default;

public:
    infiniDtype_t dtype;
    std::vector<size_t> permuted_input_shape;      // input shape, reduced axes moved last
    std::vector<size_t> output_shape;
    std::vector<ptrdiff_t> permuted_input_strides; // input strides, same permutation
    std::vector<ptrdiff_t> output_strides;
    size_t reduce_dim_size; // number of reduced axes (dim_size)
    size_t reduce_num;      // number of input elements reduced per output element
    size_t input_size;      // total number of input elements
    size_t output_size;     // total number of output elements

    /// Validate/derive reduction metadata from the tensor descriptors.
    /// `dim`/`dim_size` name the reduced axes; `keepdim` is accepted for API
    /// symmetry but not consulted here (the output descriptor already encodes it).
    static utils::Result<SumInfo> create(
        infiniopTensorDescriptor_t output_desc,
        infiniopTensorDescriptor_t input_desc,
        size_t *dim,
        size_t dim_size,
        bool keepdim) {
        const auto in_shape = input_desc->shape();
        const auto in_strides = input_desc->strides();
        const size_t ndim = input_desc->ndim();

        size_t reduce_count = 1;
        for (size_t d = 0; d < dim_size; ++d) {
            reduce_count *= in_shape[dim[d]];
        }

        // Permutation listing the kept axes first, then the reduced axes.
        std::vector<size_t> order;
        order.reserve(ndim);
        for (size_t axis = 0; axis < ndim; ++axis) {
            const bool reduced = std::find(dim, dim + dim_size, axis) != dim + dim_size;
            if (!reduced) {
                order.push_back(axis);
            }
        }
        for (size_t d = 0; d < dim_size; ++d) {
            order.push_back(dim[d]);
        }

        std::vector<size_t> perm_shape;
        std::vector<ptrdiff_t> perm_strides;
        perm_shape.reserve(order.size());
        perm_strides.reserve(order.size());
        for (size_t axis : order) {
            perm_shape.push_back(in_shape[axis]);
            perm_strides.push_back(in_strides[axis]);
        }

        return utils::Result<SumInfo>(SumInfo{
            input_desc->dtype(),
            perm_shape,
            output_desc->shape(),
            perm_strides,
            output_desc->strides(),
            dim_size,
            reduce_count,
            input_desc->numel(),
            output_desc->numel()});
    }
};

} // namespace op::sum
#endif
src/infiniop/ops/sum/metax/sum_metax.h
deleted
100644 → 0
View file @
6ab911c3
#ifndef __SUM_METAX_H__
#define __SUM_METAX_H__

#include "../sum_desc.h"

// Instantiate the shared sum Descriptor declaration (DESCRIPTOR macro from
// sum_desc.h) for the metax backend.
DESCRIPTOR(metax);

#endif
src/infiniop/ops/sum/metax/sum_metax.maca
deleted
100644 → 0
View file @
6ab911c3
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "sum_metax.h"
namespace op::sum::metax {
// Backend-private state: keeps the METAX handle internals alive for the
// lifetime of this descriptor.
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Releases the backend-private state owned by this descriptor.
Descriptor::~Descriptor() {
delete _opaque;
}
// Build a metax sum descriptor: validates shapes via SumInfo::create and sizes
// a workspace large enough to stage the shape/stride arrays on the device.
// `keepdim` is forwarded to SumInfo::create only.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
size_t *dim,
size_t dim_size,
bool keepdim) {
auto result = SumInfo::create(output_desc, input_desc, dim, dim_size, keepdim);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = 0;
// One size_t (shape) plus one ptrdiff_t (stride) per input and output axis.
workspace_size += (input_desc->ndim() + output_desc->ndim()) * (sizeof(size_t) + sizeof(ptrdiff_t));
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info, workspace_size, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
namespace {
// Stages shape/stride metadata into `workspace` via async H2D copies on
// `stream`, then launches either the full reduction (sumAllKernel, when every
// element reduces into one output) or the per-output-element sumKernel.
// `workspace_size` is not re-checked here; create() sized the workspace.
template <size_t BLOCK_SIZE, typename T>
infiniStatus_t launchKernel(
const SumInfo &info,
T *output, const T *input,
hcStream_t stream, void *workspace, size_t workspace_size) {
size_t input_ndim = info.permuted_input_shape.size();
size_t output_ndim = info.output_shape.size();
size_t input_size = info.input_size;
size_t output_size = info.output_size;
size_t reduce_num = info.reduce_num;
// Workspace layout: [input shape | output shape] as size_t, then
// [input strides | output strides] as ptrdiff_t.
unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
size_t workspace_offset = 0;
size_t *permuted_input_shape_hc = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
size_t *output_shape_hc = permuted_input_shape_hc + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(size_t);
ptrdiff_t *permuted_input_strides_hc = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
ptrdiff_t *output_strides_hc = permuted_input_strides_hc + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(ptrdiff_t);
// Source vectors live in `info` (the descriptor's _info), so the host
// pointers remain valid while these copies are in flight.
CHECK_METAX(hcMemcpyAsync(permuted_input_shape_hc, info.permuted_input_shape.data(), input_ndim * sizeof(size_t), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(output_shape_hc, info.output_shape.data(), output_ndim * sizeof(size_t), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(output_strides_hc, info.output_strides.data(), output_ndim * sizeof(ptrdiff_t), hcMemcpyHostToDevice, stream));
CHECK_METAX(hcMemcpyAsync(permuted_input_strides_hc, info.permuted_input_strides.data(), input_ndim * sizeof(ptrdiff_t), hcMemcpyHostToDevice, stream));
if (info.reduce_num == input_size) {
// Full reduction into a single output element, which must start at zero.
T zero = static_cast<T>(0.0f);
// NOTE(review): `zero` is a stack local handed to an *async* copy; it goes
// out of scope at the end of this block. Safe only if hcMemcpyAsync stages
// pageable host memory synchronously — confirm, or prefer an async memset.
CHECK_METAX(hcMemcpyAsync(output, &zero, sizeof(T), hcMemcpyHostToDevice, stream));
size_t grid_size = (input_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
sumAllKernel<BLOCK_SIZE, T, T><<<grid_size, BLOCK_SIZE, BLOCK_SIZE * sizeof(T), stream>>>(
output, input, input_size, input_ndim, permuted_input_shape_hc, permuted_input_strides_hc);
} else {
// One thread per output element.
size_t grid_size = (info.output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
sumKernel<BLOCK_SIZE, T><<<grid_size, BLOCK_SIZE, 0, stream>>>(
output, input, input_ndim, output_ndim, output_size, reduce_num,
permuted_input_shape_hc, output_shape_hc, permuted_input_strides_hc, output_strides_hc);
}
return INFINI_STATUS_SUCCESS;
}
} // namespace
// Runtime dispatch: selects the block size from the device's thread limit and
// the element type from _info.dtype, then forwards to launchKernel.
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
const void *input,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
#define CALCULATE_SUM(BLOCK_SIZE, T) \
launchKernel<BLOCK_SIZE, T>( \
_info, \
(T *)output, (const T *)input, \
stream, workspace, workspace_size)
#define CALCULATE_SUM_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_SUM(BLOCK_SIZE, __hpcc_bfloat16); \
else if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_SUM(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_SUM(BLOCK_SIZE, float); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Only 1024- and 512-thread devices are supported by this backend.
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CALCULATE_SUM_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CALCULATE_SUM_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
// Unreachable: every branch above returns.
return INFINI_STATUS_SUCCESS;
}
} // namespace op::sum::metax
src/infiniop/ops/sum/moore/sum_moore.h
deleted
100644 → 0
View file @
6ab911c3
#ifndef __SUM_MOORE_H__
#define __SUM_MOORE_H__

#include "../sum_desc.h"

// Instantiate the shared sum Descriptor declaration (DESCRIPTOR macro from
// sum_desc.h) for the moore backend.
DESCRIPTOR(moore);

#endif
src/infiniop/ops/sum/moore/sum_moore.mu
deleted
100644 → 0
View file @
6ab911c3
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "sum_moore.h"
namespace op::sum::moore {
// Backend-private state: keeps the Moore handle internals alive for the
// lifetime of this descriptor.
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
// Releases the backend-private state owned by this descriptor.
Descriptor::~Descriptor() {
delete _opaque;
}
// Build a moore sum descriptor: validates shapes via SumInfo::create and sizes
// a workspace large enough to stage the shape/stride arrays on the device.
// `keepdim` is forwarded to SumInfo::create only.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
size_t *dim,
size_t dim_size,
bool keepdim) {
auto result = SumInfo::create(output_desc, input_desc, dim, dim_size, keepdim);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = 0;
// One size_t (shape) plus one ptrdiff_t (stride) per input and output axis.
workspace_size += (input_desc->ndim() + output_desc->ndim()) * (sizeof(size_t) + sizeof(ptrdiff_t));
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info, workspace_size, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
namespace {
// Stages shape/stride metadata into `workspace` via async H2D copies on
// `stream`, then launches either the full reduction (sumAllKernel) or the
// per-output-element sumKernel. `workspace_size` is not re-checked here.
template <size_t BLOCK_SIZE, typename T>
infiniStatus_t launchKernel(
const SumInfo &info,
T *output, const T *input,
musaStream_t stream, void *workspace, size_t workspace_size) {
size_t input_ndim = info.permuted_input_shape.size();
size_t output_ndim = info.output_shape.size();
size_t input_size = info.input_size;
size_t output_size = info.output_size;
size_t reduce_num = info.reduce_num;
// Workspace layout: [input shape | output shape] as size_t, then
// [input strides | output strides] as ptrdiff_t.
unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
size_t workspace_offset = 0;
size_t *permuted_input_shape_musa = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
size_t *output_shape_musa = permuted_input_shape_musa + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(size_t);
ptrdiff_t *permuted_input_strides_musa = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
ptrdiff_t *output_strides_musa = permuted_input_strides_musa + input_ndim;
workspace_offset += (input_ndim + output_ndim) * sizeof(ptrdiff_t);
// Source vectors live in `info` (the descriptor's _info), so the host
// pointers remain valid while these copies are in flight.
CHECK_MOORE(musaMemcpyAsync(permuted_input_shape_musa, info.permuted_input_shape.data(), input_ndim * sizeof(size_t), musaMemcpyHostToDevice, stream));
CHECK_MOORE(musaMemcpyAsync(output_shape_musa, info.output_shape.data(), output_ndim * sizeof(size_t), musaMemcpyHostToDevice, stream));
CHECK_MOORE(musaMemcpyAsync(output_strides_musa, info.output_strides.data(), output_ndim * sizeof(ptrdiff_t), musaMemcpyHostToDevice, stream));
CHECK_MOORE(musaMemcpyAsync(permuted_input_strides_musa, info.permuted_input_strides.data(), input_ndim * sizeof(ptrdiff_t), musaMemcpyHostToDevice, stream));
if (info.reduce_num == input_size) {
if constexpr (std::is_same_v<T, __mt_bfloat16>) {
// Workaround: Moore lacks atomic add support for bf16, so accumulate
// into a temporary float on the device instead.
float zero = 0.0f;
float *tmp_output;
// NOTE(review): musaMalloc/musaFree per call is costly; this buffer could
// live in the workspace. Also `zero` is a stack local handed to an async
// copy — safe only if pageable H2D copies are staged synchronously.
CHECK_MOORE(musaMalloc(&tmp_output, sizeof(float)));
CHECK_MOORE(musaMemcpyAsync(tmp_output, &zero, sizeof(float), musaMemcpyHostToDevice, stream));
size_t grid_size = (input_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
sumAllKernel<BLOCK_SIZE, T, float><<<grid_size, BLOCK_SIZE, BLOCK_SIZE * sizeof(float), stream>>>(
tmp_output, input, input_size, input_ndim, permuted_input_shape_musa, permuted_input_strides_musa);
// A dedicated kernel could convert float -> T on the device; here the
// value is simply round-tripped through the host with memcpy.
// NOTE(review): the synchronous musaMemcpy below must observe the kernel
// launched on `stream` above — confirm it synchronizes with that stream,
// otherwise a musaStreamSynchronize(stream) is needed first.
float host_val;
CHECK_MOORE(musaMemcpy(&host_val, tmp_output, sizeof(float), musaMemcpyDeviceToHost));
T out_val = static_cast<T>(host_val);
CHECK_MOORE(musaMemcpyAsync(output, &out_val, sizeof(T), musaMemcpyHostToDevice, stream));
CHECK_MOORE(musaFree(tmp_output));
} else {
// Full reduction into a single output element, which must start at zero.
T zero = static_cast<T>(0.0f);
// NOTE(review): stack local passed to an async copy — see note above.
CHECK_MOORE(musaMemcpyAsync(output, &zero, sizeof(T), musaMemcpyHostToDevice, stream));
size_t grid_size = (input_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
sumAllKernel<BLOCK_SIZE, T, T><<<grid_size, BLOCK_SIZE, BLOCK_SIZE * sizeof(T), stream>>>(
output, input, input_size, input_ndim, permuted_input_shape_musa, permuted_input_strides_musa);
}
} else {
// One thread per output element.
size_t grid_size = (info.output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
sumKernel<BLOCK_SIZE, T><<<grid_size, BLOCK_SIZE, 0, stream>>>(
output, input, input_ndim, output_ndim, output_size, reduce_num,
permuted_input_shape_musa, output_shape_musa, permuted_input_strides_musa, output_strides_musa);
}
return INFINI_STATUS_SUCCESS;
}
} // namespace
// Runtime dispatch: selects the block size from the device's thread limit and
// the element type from _info.dtype, then forwards to launchKernel.
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
const void *input,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
#define CALCULATE_SUM(BLOCK_SIZE, T) \
launchKernel<BLOCK_SIZE, T>( \
_info, \
(T *)output, (const T *)input, \
stream, workspace, workspace_size)
#define CALCULATE_SUM_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_SUM(BLOCK_SIZE, __mt_bfloat16); \
else if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_SUM(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_SUM(BLOCK_SIZE, float); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Only 1024- and 512-thread devices are supported by this backend.
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CALCULATE_SUM_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CALCULATE_SUM_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_512)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
// Unreachable: every branch above returns.
return INFINI_STATUS_SUCCESS;
}
} // namespace op::sum::moore
src/infiniop/ops/sum/nvidia/sum_nvidia.cu
deleted
100644 → 0
View file @
6ab911c3
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "sum_nvidia.cuh"
namespace
op
::
sum
::
nvidia
{
// Backend-private state: keeps the NVIDIA handle internals alive for the
// lifetime of this descriptor.
struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
// Releases the backend-private state owned by this descriptor.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Build an NVIDIA sum descriptor: validates shapes via SumInfo::create and
// sizes a workspace large enough to stage the shape/stride arrays on the
// device. `keepdim` is forwarded to SumInfo::create only.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t output_desc,
    infiniopTensorDescriptor_t input_desc,
    size_t *dim,
    size_t dim_size,
    bool keepdim) {
    auto result = SumInfo::create(output_desc, input_desc, dim, dim_size, keepdim);
    CHECK_RESULT(result);
    auto info = result.take();

    // One size_t (shape) plus one ptrdiff_t (stride) per input and output axis.
    size_t workspace_size = 0;
    workspace_size += (input_desc->ndim() + output_desc->ndim()) * (sizeof(size_t) + sizeof(ptrdiff_t));

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        info, workspace_size, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
namespace {
// Stages shape/stride metadata into `workspace` via async H2D copies on
// `stream`, then launches either the full reduction (sumAllKernel, when every
// input element reduces into one output) or the per-output-element sumKernel.
//
// Workspace layout (sized by Descriptor::create):
//   [input_ndim | output_ndim] size_t    — permuted input shape, output shape
//   [input_ndim | output_ndim] ptrdiff_t — permuted input strides, output strides
// `workspace_size` is accepted for interface symmetry and not re-checked here.
template <size_t BLOCK_SIZE, typename T>
infiniStatus_t launchKernel(
    const SumInfo &info,
    T *output, const T *input,
    cudaStream_t stream, void *workspace, size_t workspace_size) {
    size_t input_ndim = info.permuted_input_shape.size();
    size_t output_ndim = info.output_shape.size();
    size_t input_size = info.input_size;
    size_t output_size = info.output_size;
    size_t reduce_num = info.reduce_num;

    unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
    size_t workspace_offset = 0;
    size_t *permuted_input_shape_cuda = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
    size_t *output_shape_cuda = permuted_input_shape_cuda + input_ndim;
    workspace_offset += (input_ndim + output_ndim) * sizeof(size_t);
    ptrdiff_t *permuted_input_strides_cuda = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
    ptrdiff_t *output_strides_cuda = permuted_input_strides_cuda + input_ndim;
    workspace_offset += (input_ndim + output_ndim) * sizeof(ptrdiff_t);

    // The source vectors are members of `info` (the caller passes the
    // descriptor's _info), so the host pointers stay valid while these
    // asynchronous copies are in flight.
    CHECK_CUDA(cudaMemcpyAsync(permuted_input_shape_cuda, info.permuted_input_shape.data(), input_ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream));
    CHECK_CUDA(cudaMemcpyAsync(output_shape_cuda, info.output_shape.data(), output_ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream));
    CHECK_CUDA(cudaMemcpyAsync(permuted_input_strides_cuda, info.permuted_input_strides.data(), input_ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream));
    CHECK_CUDA(cudaMemcpyAsync(output_strides_cuda, info.output_strides.data(), output_ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream));

    if (info.reduce_num == input_size) {
        // Full reduction: the single output element is initialized to zero
        // before the kernel accumulates into it.
        // BUGFIX: this previously did cudaMemcpyAsync from a stack-local
        // `T zero`, whose lifetime ends before the async copy is guaranteed to
        // have read it. All dtypes dispatched here (F16/BF16/F32) represent
        // zero as all-zero bytes, so an async memset is equivalent and safe.
        CHECK_CUDA(cudaMemsetAsync(output, 0, sizeof(T), stream));
        size_t grid_size = (input_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
        sumAllKernel<BLOCK_SIZE, T, T><<<grid_size, BLOCK_SIZE, BLOCK_SIZE * sizeof(T), stream>>>(
            output, input, input_size, input_ndim, permuted_input_shape_cuda, permuted_input_strides_cuda);
    } else {
        // Partial reduction: one thread per output element.
        size_t grid_size = (info.output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
        sumKernel<BLOCK_SIZE, T><<<grid_size, BLOCK_SIZE, 0, stream>>>(
            output, input, input_ndim, output_ndim, output_size, reduce_num,
            permuted_input_shape_cuda, output_shape_cuda, permuted_input_strides_cuda, output_strides_cuda);
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace
// Runtime dispatch: selects the block size from the device's thread limit and
// the element type from _info.dtype, then forwards to launchKernel.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    const void *input,
    void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;
#define CALCULATE_SUM(BLOCK_SIZE, T) \
    launchKernel<BLOCK_SIZE, T>( \
        _info, \
        (T *)output, (const T *)input, \
        stream, workspace, workspace_size)
#define CALCULATE_SUM_WITH_BLOCK_SIZE(BLOCK_SIZE) \
    { \
        if (_info.dtype == INFINI_DTYPE_BF16) \
            return CALCULATE_SUM(BLOCK_SIZE, __nv_bfloat16); \
        else if (_info.dtype == INFINI_DTYPE_F16) \
            return CALCULATE_SUM(BLOCK_SIZE, half); \
        else if (_info.dtype == INFINI_DTYPE_F32) \
            return CALCULATE_SUM(BLOCK_SIZE, float); \
        else \
            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
    }
    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        CALCULATE_SUM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
        CALCULATE_SUM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
        CALCULATE_SUM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    // Unreachable: every branch above returns.
    return INFINI_STATUS_SUCCESS;
}
}
// namespace op::sum::nvidia
src/infiniop/ops/sum/nvidia/sum_nvidia.cuh
deleted
100644 → 0
View file @
6ab911c3
#ifndef __SUM_NVIDIA_H__
#define __SUM_NVIDIA_H__

#include "../sum_desc.h"

// Instantiate the shared sum Descriptor declaration (DESCRIPTOR macro from
// sum_desc.h) for the nvidia backend.
DESCRIPTOR(nvidia);

#endif // __SUM_CUDA_API_H__
src/infiniop/ops/sum/operator.cc
deleted
100644 → 0
View file @
6ab911c3
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/sum.h"
#include <vector>
#ifdef ENABLE_CPU_API
#include "cpu/sum_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/sum_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/sum_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/sum_kunlun.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/sum_moore.h"
#endif
// C ABI entry point: routes descriptor creation to the backend selected by the
// handle's device, via the per-backend op::sum::<backend>::Descriptor::create.
__INFINI_C infiniStatus_t infiniopCreateSumDescriptor(
    infiniopHandle_t handle,
    infiniopSumDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t output_desc,
    infiniopTensorDescriptor_t input_desc,
    size_t *dim,
    size_t dim_size,
    bool keepdim) {
#define CREATE(CASE, NAMESPACE) \
    case CASE: \
        return op::sum::NAMESPACE::Descriptor::create( \
            handle, \
            reinterpret_cast<op::sum::NAMESPACE::Descriptor **>(desc_ptr), \
            output_desc, \
            input_desc, \
            dim, \
            dim_size, \
            keepdim)
    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
        CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
        CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
        CREATE(INFINI_DEVICE_MOORE, moore);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef CREATE
}
// C ABI entry point: reports the workspace size recorded in the descriptor for
// the backend matching its device type.
__INFINI_C infiniStatus_t infiniopGetSumWorkspaceSize(
    infiniopSumDescriptor_t desc,
    size_t *size) {
#define GET(CASE, NAMESPACE) \
    case CASE: \
        *size = reinterpret_cast<op::sum::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS
    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
        GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
        GET(INFINI_DEVICE_MOORE, moore);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef GET
    // Unreachable: every switch path returns.
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// C ABI entry point: runs the sum for the backend matching the descriptor's
// device type. `dim`, `dim_size` and `keepdim` are accepted for API symmetry
// but are not forwarded to calculate() — they were fixed at descriptor
// creation time.
__INFINI_C infiniStatus_t infiniopSum(
    infiniopSumDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *output,
    const void *input,
    size_t *dim,
    size_t dim_size,
    bool keepdim,
    void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
    case CASE: \
        return reinterpret_cast<const op::sum::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, output, input, stream)
    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
        CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
        CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
        CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef CALCULATE
}
// C ABI entry point: destroys the descriptor through the backend type matching
// its device type so the correct destructor runs.
__INFINI_C infiniStatus_t infiniopDestroySumDescriptor(
    infiniopSumDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
    case CASE: \
        delete reinterpret_cast<const op::sum::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;
    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
        DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
        DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
        DELETE(INFINI_DEVICE_MOORE, moore);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef DELETE
}
src/infiniop/ops/sum/sum_desc.h
deleted
100644 → 0
View file @
6ab911c3
#ifndef INFINIOP_SUM_DESCRIPTOR_H_
#define INFINIOP_SUM_DESCRIPTOR_H_
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include "info.h"
// Expands to the per-backend sum Descriptor class inside op::sum::NAMESPACE.
// Each backend translation unit invokes DESCRIPTOR(<backend>) and then defines
// ~Descriptor(), create() and calculate(); Opaque holds backend-private state.
// (No comments inside the macro body: they would break the line continuations.)
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::sum::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
SumInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
SumInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t output_desc, \
infiniopTensorDescriptor_t input_desc, \
size_t *dim, \
size_t dim_size, \
bool keepdim); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *output, \
const void *input, \
void *stream) const; \
}; \
}
#endif
src/infiniop/ops/topk/cpu/topk_cpu.cc
deleted
100644 → 0
View file @
6ab911c3
#include "topk_cpu.h"
#include "../../../../utils.h"
#include "../../../devices/cpu/common_cpu.h"
#include <algorithm>
#include <vector>
namespace
op
::
topk
::
cpu
{
// The CPU backend holds no opaque state, so there is nothing to release.
Descriptor::~Descriptor() {}
// Build a CPU top-k descriptor. All validation happens in TopKInfo::create;
// the CPU path needs no opaque state and no workspace (both passed as null/0).
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t values_output_desc,
    infiniopTensorDescriptor_t indices_output_desc,
    infiniopTensorDescriptor_t input_desc,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted) {
    auto result = TopKInfo::create(values_output_desc, indices_output_desc, input_desc, k, dim, largest, sorted);
    CHECK_RESULT(result);
    *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
namespace {
// Compute top-k along axis `dim` for every slice of the (strided) input.
// For each slice the (value, position) pairs are selected with either
// partial_sort (small k) or nth_element (+ optional sort when `sorted`), and
// the k results are written to values_output / indices_output along `dim`.
// Comparisons go through utils::cast<float> so fp16_t/bf16_t order numerically.
template <typename Tdata>
infiniStatus_t calculateTopK(
    const TopKInfo &info,
    Tdata *values_output,
    int32_t *indices_output,
    const Tdata *input,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted) {
    if (k == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    using elem_t = std::pair<Tdata, size_t>;
    const auto cmp_desc = [](const elem_t &a, const elem_t &b) -> bool {
        return utils::cast<float>(a.first) > utils::cast<float>(b.first);
    };
    const auto cmp_asc = [](const elem_t &a, const elem_t &b) -> bool {
        return utils::cast<float>(a.first) < utils::cast<float>(b.first);
    };

    for (size_t i = 0; i < info.n_iteration; i++) {
        // Decompose the flat slice index i into per-axis coordinates (skipping
        // `dim`) to find this slice's starting offsets.
        // BUGFIX: this loop was `for (size_t j = info.ndim - 1; j >= 0; j--)`.
        // With an unsigned j, `j >= 0` is always true, so j wrapped past zero
        // and indexed shape[SIZE_MAX] out of bounds. Use the wrap-safe
        // countdown idiom instead (same pattern as the CUDA kernels).
        size_t index = i;
        size_t input_start = 0;
        size_t output_start = 0;
        for (size_t j = info.ndim; j-- > 0;) {
            if (j == dim) {
                continue;
            }
            input_start += (index % info.input_shape[j]) * info.input_strides[j];
            output_start += (index % info.output_shape[j]) * info.output_strides[j];
            index /= info.input_shape[j];
        }

        // Materialize (value, position-within-dim) pairs for this slice.
        std::vector<elem_t> vi_queue(info.dim_elements);
        for (size_t j = 0; j < info.dim_elements; j++) {
            vi_queue[j].first = input[input_start + j * info.input_strides[dim]];
            vi_queue[j].second = j;
        }

        // When k is small relative to the slice, partial_sort is cheaper and
        // already leaves the first k sorted; otherwise select via nth_element.
        bool use_partial_sort = static_cast<size_t>(k) * 64 <= info.dim_elements;
        if (use_partial_sort) {
            if (largest) {
                std::partial_sort(vi_queue.begin(), vi_queue.begin() + k, vi_queue.end(), cmp_desc);
            } else {
                std::partial_sort(vi_queue.begin(), vi_queue.begin() + k, vi_queue.end(), cmp_asc);
            }
        } else {
            if (largest) {
                std::nth_element(vi_queue.begin(), vi_queue.begin() + k - 1, vi_queue.end(), cmp_desc);
                if (sorted) {
                    // Note: PyTorch sorts the first k elements here (k, not k - 1).
                    std::sort(vi_queue.begin(), vi_queue.begin() + k, cmp_desc);
                }
            } else {
                std::nth_element(vi_queue.begin(), vi_queue.begin() + k - 1, vi_queue.end(), cmp_asc);
                if (sorted) {
                    // Note: PyTorch sorts the first k elements here (k, not k - 1).
                    std::sort(vi_queue.begin(), vi_queue.begin() + k, cmp_asc);
                }
            }
        }

        // Emit the k selected values and their original positions along `dim`.
        for (size_t j = 0; j < k; j++) {
            values_output[output_start + j * info.output_strides[dim]] = vi_queue[j].first;
            indices_output[output_start + j * info.output_strides[dim]] = (int32_t)vi_queue[j].second;
        }
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace
// Dispatch on _info.dtype and forward to the typed calculateTopK.
// `workspace`/`workspace_size`/`stream` are unused on the CPU path.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *values_output,
    void *indices_output,
    const void *input,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted,
    void *stream) const {
    switch (_info.dtype) {
    case INFINI_DTYPE_F16:
        return calculateTopK<fp16_t>(
            _info, (fp16_t *)values_output, (int32_t *)indices_output,
            reinterpret_cast<const fp16_t *>(input), k, dim, largest, sorted);
    case INFINI_DTYPE_F32:
        return calculateTopK<float>(
            _info, (float *)values_output, (int32_t *)indices_output,
            reinterpret_cast<const float *>(input), k, dim, largest, sorted);
    case INFINI_DTYPE_BF16:
        return calculateTopK<bf16_t>(
            _info, (bf16_t *)values_output, (int32_t *)indices_output,
            reinterpret_cast<const bf16_t *>(input), k, dim, largest, sorted);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Unreachable: every switch path returns.
    return INFINI_STATUS_SUCCESS;
}
}
// namespace op::topk::cpu
src/infiniop/ops/topk/cpu/topk_cpu.h
deleted
100644 → 0
View file @
6ab911c3
#ifndef __INFINIOP_TOPK_CPU_H__
#define __INFINIOP_TOPK_CPU_H__

#include "../topk_desc.h"

// Instantiate the shared topk Descriptor declaration (DESCRIPTOR macro from
// topk_desc.h) for the cpu backend.
DESCRIPTOR(cpu);

#endif // __INFINIOP_TOPK_CPU_H__
src/infiniop/ops/topk/cuda/kernel.cuh
deleted
100644 → 0
View file @
6ab911c3
#ifndef __TOPK_CUDA_KERNEL_CUH__
#define __TOPK_CUDA_KERNEL_CUH__
#include <cmath>   // NAN
#include <cstring> // memcpy (bit-exact float -> uint32 conversion)
#include <stdint.h>

#include <cub/block/block_radix_sort.cuh>
namespace
op
::
topk
::
cuda
{
// Maps a flat "row" index (enumerating all axes except `dim`) to the element
// offset implied by `shape`/`strides`, contributing nothing for axis `dim`.
__forceinline__ __device__ __host__ size_t baseOffsetExcludingDim(
    size_t flat_row, size_t ndim, const size_t *shape, const ptrdiff_t *strides, size_t dim) {
    size_t offset = 0;
    for (size_t axis = ndim; axis-- > 0;) {
        if (axis == dim) {
            continue;
        }
        offset += (flat_row % shape[axis]) * strides[axis];
        flat_row /= shape[axis];
    }
    return offset;
}
// Maps a flat index over all axes to the element offset implied by
// `shape`/`strides` (row-major decomposition, last axis fastest).
__forceinline__ __device__ __host__ size_t indexToOffset(
    size_t flat_index, size_t ndim, const size_t *shape, const ptrdiff_t *strides) {
    size_t offset = 0;
    for (size_t axis = ndim; axis-- > 0;) {
        offset += (flat_index % shape[axis]) * strides[axis];
        flat_index /= shape[axis];
    }
    return offset;
}
// Device-side widening conversion to float, specialized per element type so
// the radix/top-k machinery can compare all supported dtypes uniformly.
template <typename Tdata>
__device__ __forceinline__ float to_float(Tdata v);

template <>
__device__ __forceinline__ float to_float<float>(float v) {
    return v;
}

template <>
__device__ __forceinline__ float to_float<half>(half v) {
    return __half2float(v);
}

// Per-backend bfloat16 alias: each vendor toolchain names the type differently.
#if defined(ENABLE_MOORE_API)
using bf16_t = __mt_bfloat16;
#elif defined(ENABLE_METAX_API)
using bf16_t = __hpcc_bfloat16;
#else
// CUDA / NVIDIA / ILUVATAR
using bf16_t = __nv_bfloat16;
#endif

template <>
__device__ __forceinline__ float to_float<bf16_t>(bf16_t v) {
    return __bfloat162float(v);
}
// float -> ordered uint32: a monotone mapping so that unsigned comparison of
// the results matches float comparison of the inputs (negatives have their
// bits flipped entirely; non-negatives get the sign bit set).
__device__ __forceinline__ uint32_t float_to_uint_ordered(float value) {
    // BUGFIX: was `*reinterpret_cast<uint32_t *>(&value)`, which violates
    // strict aliasing; memcpy is the well-defined bit-exact conversion and
    // compiles to the same single move.
    uint32_t bits;
    memcpy(&bits, &value, sizeof(bits));
    uint32_t mask = (uint32_t)(-((int32_t)bits >> 31)) | 0x80000000u;
    return bits ^ mask;
}
// One thread per (row, element): copy each element of the slice along `dim`
// into the dense scratch buffers, encoding the value as an order-preserving
// uint key and recording its index within the dimension.
template <typename Tdata>
__global__ void gather_rowwise(
    const Tdata *input,
    uint32_t *cur_vals,
    int32_t *cur_idx,
    size_t rows,
    size_t n,
    size_t ndim,
    size_t dim,
    const size_t *shape,
    const ptrdiff_t *strides) {
    const size_t row = blockIdx.y;
    const size_t col = threadIdx.x + blockIdx.x * blockDim.x;
    if (row >= rows || col >= n) {
        return;
    }
    const size_t row_base = baseOffsetExcludingDim(row, ndim, shape, strides, dim);
    const size_t src = row_base + col * strides[dim];
    const size_t dst = row * n + col;
    cur_vals[dst] = float_to_uint_ordered(to_float<Tdata>(input[src]));
    cur_idx[dst] = col;
}
// One thread per row: reset the per-row radix-select bookkeeping.
__global__ void init_row_state(
    int32_t *cur_n,
    int32_t *rem_k,
    int32_t *out_pos,
    size_t rows,
    size_t n,
    size_t k) {
    const int32_t row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= rows) {
        return;
    }
    cur_n[row] = n;   // all n candidates are initially in play
    rem_k[row] = k;   // k selections still to make
    out_pos[row] = 0; // next free slot in the selection buffer
}
// Reset both per-row partition counters before each bit-partition pass.
__global__ void zero_row_counters(
    int32_t *ones_count,
    int32_t *zeros_count,
    size_t rows) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= rows) {
        return;
    }
    ones_count[row] = 0;
    zeros_count[row] = 0;
}
// Partition each row's surviving candidates by one bit of their ordered key.
// Grid: (ceil(n / BLOCK_SIZE), rows). Each block stages its hits in shared
// memory, reserves a contiguous range in the global ones/zeros buffers with a
// single per-block atomicAdd, then copies the staged items out. Note that the
// order of items *within* a bucket is not deterministic across launches
// (atomicAdd ordering); only bucket membership is.
template <size_t BLOCK_SIZE>
__global__ void partition_rowwise(
    const uint32_t *cur_vals,
    int32_t *cur_idx,
    uint32_t *ones_vals,
    int32_t *ones_idx,
    uint32_t *zeros_vals,
    int32_t *zeros_idx,
    const int32_t *cur_n,
    size_t rows,
    size_t n,
    int32_t bit_pos,
    bool largest,
    int32_t *ones_count,
    int32_t *zeros_count) {
    int32_t row = blockIdx.y;
    if (row >= rows) {
        return;
    }
    // Per-block staging areas; at most BLOCK_SIZE items are processed per block.
    __shared__ uint32_t sh1_vals[BLOCK_SIZE];
    __shared__ int32_t sh1_idx[BLOCK_SIZE];
    __shared__ uint32_t sh0_vals[BLOCK_SIZE];
    __shared__ int32_t sh0_idx[BLOCK_SIZE];
    __shared__ int sh1_n, sh0_n; // staged item counts for the 1/0 buckets
    __shared__ int32_t base1, base0; // this block's reserved offsets in the global buckets
    int32_t tid = threadIdx.x;
    if (tid == 0) {
        sh1_n = 0;
        sh0_n = 0;
    }
    __syncthreads();
    int32_t i = blockIdx.x * blockDim.x + tid;
    int32_t cn = cur_n[row]; // only the first cur_n[row] entries are live candidates
    if (i < cn) {
        int32_t off = row * n + i;
        int32_t idx = cur_idx[off];
        uint32_t key = cur_vals[off];
        // For smallest-k, invert the key so "bit set" always means "keep first".
        uint32_t cmp_key = largest ? key : ~key;
        int32_t b = (cmp_key >> bit_pos) & 1;
        if (b) {
            int32_t p = atomicAdd(&sh1_n, 1);
            sh1_vals[p] = key;
            sh1_idx[p] = idx;
        } else {
            int32_t p = atomicAdd(&sh0_n, 1);
            sh0_vals[p] = key;
            sh0_idx[p] = idx;
        }
    }
    __syncthreads();
    // One atomic per block reserves a contiguous output range per bucket.
    if (tid == 0) {
        base1 = atomicAdd(&ones_count[row], sh1_n);
        base0 = atomicAdd(&zeros_count[row], sh0_n);
    }
    __syncthreads();
    // Flush the staged items into the globally reserved ranges.
    for (int32_t j = tid; j < sh1_n; j += blockDim.x) {
        int32_t o = row * n + base1 + j;
        ones_vals[o] = sh1_vals[j];
        ones_idx[o] = sh1_idx[j];
    }
    for (int32_t j = tid; j < sh0_n; j += blockDim.x) {
        int32_t o = row * n + base0 + j;
        zeros_vals[o] = sh0_vals[j];
        zeros_idx[o] = sh0_idx[j];
    }
}
// After a bit-partition pass, decide per row whether the answer lies entirely
// in the "ones" bucket (keep it as the new candidate set) or whether the
// whole ones bucket belongs to the selection (emit it, continue with zeros).
// Launched with one block per row, so the early returns and __syncthreads()
// are uniform across each block.
template <size_t BLOCK_SIZE>
__global__ void decide_and_compact(
    uint32_t *cur_vals,
    int32_t *cur_idx,
    const uint32_t *ones_vals,
    const int32_t *ones_idx,
    const uint32_t *zeros_vals,
    const int32_t *zeros_idx,
    const int32_t *ones_count,
    const int32_t *zeros_count,
    int32_t *cur_n,
    int32_t *rem_k,
    int32_t *out_pos,
    uint32_t *sel_vals,
    int32_t *sel_idx,
    size_t rows,
    size_t n,
    size_t k) {
    int32_t row = blockIdx.x;
    if (row >= rows) {
        return;
    }
    int32_t tid = threadIdx.x;
    int32_t rem = rem_k[row];
    if (rem <= 0) {
        // Row already has its k selections; nothing left to refine.
        return;
    }
    int32_t oc = ones_count[row];
    int32_t zc = zeros_count[row];
    int32_t pos = out_pos[row];
    // If the ones bucket alone still holds >= rem candidates, it becomes the
    // new candidate set; otherwise every "one" is a guaranteed selection.
    bool keep_ones = (oc >= rem);
    if (!keep_ones) {
        // Emit the whole ones bucket into the row's selection buffer.
        for (int32_t j = tid; j < oc; j += blockDim.x) {
            if (pos + j < k) {
                int32_t o = row * n + j;
                sel_vals[row * k + pos + j] = ones_vals[o];
                sel_idx[row * k + pos + j] = ones_idx[o];
            }
        }
    }
    __syncthreads();
    // A single thread updates the row state (read by all threads below).
    if (tid == 0) {
        if (keep_ones) {
            cur_n[row] = oc;
        } else {
            out_pos[row] = pos + oc;
            rem_k[row] = rem - oc;
            cur_n[row] = zc;
        }
    }
    __syncthreads();
    // Compact the surviving bucket back into the cur_* candidate buffers.
    int32_t new_n = cur_n[row];
    for (int32_t j = tid; j < new_n; j += blockDim.x) {
        int32_t o = row * n + j;
        cur_vals[o] = keep_ones ? ones_vals[o] : zeros_vals[o];
        cur_idx[o] = keep_ones ? ones_idx[o] : zeros_idx[o];
    }
}
// One block per row: after all bit passes, flush whatever candidates are
// still pending into the row's selection buffer, capped by the number of
// selection slots that remain.
template <size_t BLOCK_SIZE>
__global__ void take_remaining(
    const uint32_t *cur_vals,
    const int32_t *cur_idx,
    const int32_t *cur_n,
    const int32_t *rem_k,
    const int32_t *out_pos,
    uint32_t *sel_vals,
    int32_t *sel_idx,
    size_t rows,
    size_t n,
    size_t k) {
    const int32_t row = blockIdx.x;
    const int32_t tid = threadIdx.x;
    if (row >= rows) {
        return;
    }
    const int32_t pos = out_pos[row];
    const int32_t avail = cur_n[row];
    const int32_t wanted = rem_k[row];
    const int32_t take = (wanted < avail) ? wanted : avail;
    for (int32_t t = tid; t < take; t += blockDim.x) {
        if (pos + t < k) {
            const int32_t dst = row * k + pos + t;
            sel_vals[dst] = cur_vals[row * n + t];
            sel_idx[dst] = cur_idx[row * n + t];
        }
    }
}
// Write the selected values/indices back to the (possibly strided) output
// tensors. Grid: (ceil(k / BLOCK_SIZE), rows); one thread per (row, j).
template <typename Tdata>
__global__ void scatter_to_output(
    const Tdata *input,
    const int32_t *sel_idx,
    Tdata *values_out,
    int32_t *indices_out,
    size_t rows,
    size_t k,
    size_t ndim,
    size_t dim,
    const size_t *input_shape,
    const ptrdiff_t *input_strides,
    const size_t *output_shape,
    const ptrdiff_t *output_strides) {
    const size_t row = blockIdx.y;
    const size_t j = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= rows || j >= k) {
        return;
    }
    // Fix: accumulate offsets in size_t. baseOffsetExcludingDim returns
    // size_t; the previous int32_t locals truncated base offsets for tensors
    // with more than ~2^31 elements.
    const size_t output_base = baseOffsetExcludingDim(row, ndim, output_shape, output_strides, dim);
    const size_t output_off = output_base + j * output_strides[dim];
    const size_t input_base = baseOffsetExcludingDim(row, ndim, input_shape, input_strides, dim);
    const size_t input_off = input_base + sel_idx[row * k + j] * input_strides[dim];
    values_out[output_off] = input[input_off];
    indices_out[output_off] = sel_idx[row * k + j];
}
}
// namespace op::topk::cuda
#endif // __TOPK_CUDA_KERNEL_CUH__
src/infiniop/ops/topk/info.h
deleted
100644 → 0
View file @
6ab911c3
#ifndef __TOPK_INFO_H__
#define __TOPK_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <algorithm>
#include <cstddef>
#include <vector>
namespace
op
::
topk
{
// Validated, device-agnostic metadata for a TopK launch: dtype, shapes,
// strides and derived iteration counts. Backends cache an instance in their
// Descriptor and read it at calculate() time.
class TopKInfo {
    // Private: instances are only produced by the create() factory below.
    TopKInfo() = default;

public:
    infiniDtype_t dtype; // element dtype of the input tensor
    std::vector<size_t> input_shape;
    std::vector<size_t> output_shape;      // shape of the values output
    std::vector<ptrdiff_t> input_strides;
    std::vector<ptrdiff_t> output_strides; // strides of the values output
    size_t k;
    size_t dim;      // dimension along which top-k is taken
    bool largest;    // true: take largest k; false: take smallest k
    bool sorted;     // true: results sorted along dim
    size_t ndim;
    size_t dim_elements;
    // processed dim elements
    size_t n_iteration;
    // total number of topk iteration

    // Build a TopKInfo from the tensor descriptors. n_iteration is the
    // product of all input extents except `dim` (i.e. the number of slices
    // top-k runs over). NOTE(review): indices_output_desc is accepted but
    // not validated here — verify against callers.
    static utils::Result<TopKInfo> create(
        infiniopTensorDescriptor_t values_output_desc,
        infiniopTensorDescriptor_t indices_output_desc,
        infiniopTensorDescriptor_t input_desc,
        size_t k,
        size_t dim,
        bool largest,
        bool sorted) {
        auto input_shape = input_desc->shape();
        auto input_strides = input_desc->strides();
        size_t input_ndim = input_desc->ndim();
        size_t dim_elements = input_shape[dim];
        size_t n_iteration = 1;
        for (size_t i = 0; i < input_ndim; i++) {
            if (i != dim) {
                n_iteration *= input_shape[i];
            }
        }
        return utils::Result<TopKInfo>(TopKInfo{
            input_desc->dtype(),
            input_desc->shape(),
            values_output_desc->shape(),
            input_desc->strides(),
            values_output_desc->strides(),
            k,
            dim,
            largest,
            sorted,
            input_ndim,
            dim_elements,
            n_iteration});
    }
};
}
// namespace op::topk
#endif
src/infiniop/ops/topk/metax/topk_metax.h
deleted
100644 → 0
View file @
6ab911c3
#ifndef __TOPK_METAX_H__
#define __TOPK_METAX_H__

#include "../topk_desc.h"

// Instantiate the TopK operator descriptor for the Metax backend via the
// shared DESCRIPTOR macro from topk_desc.h.
DESCRIPTOR(metax);

#endif // __TOPK_METAX_H__
src/infiniop/ops/topk/metax/topk_metax.maca
deleted
100644 → 0
View file @
6ab911c3
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topk_metax.h"
#include <cub/block/block_radix_sort.cuh>
#include <cub/cub.cuh>
namespace op::topk::metax {
// Backend-private state: shared handle internals (device limits etc.)
// obtained from the Metax runtime handle at create() time.
struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};

// Descriptor owns its Opaque; release it on destruction.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validate the operation via TopKInfo, precompute the workspace size, and
// allocate the Metax descriptor. The workspace layout computed here must stay
// in sync with the offsets taken in launchKernel:
//   shapes/strides copies, 3x value + 3x index scratch buffers (n*rows),
//   selection buffers (rows*k, doubled when sorted), and 5 per-row counters.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t values_output_desc,
    infiniopTensorDescriptor_t indices_output_desc,
    infiniopTensorDescriptor_t input_desc,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted) {
    auto result = TopKInfo::create(values_output_desc, indices_output_desc, input_desc, k, dim, largest, sorted);
    CHECK_RESULT(result);
    auto info = result.take();
    size_t workspace_size = 0;
    // Device-side copies of shapes and strides.
    workspace_size += (input_desc->ndim() + values_output_desc->ndim()) * (sizeof(size_t) + sizeof(ptrdiff_t));
    size_t dim_elements = input_desc->shape()[dim];
    size_t n_iteration = 1; // number of slices along every axis except `dim`
    for (size_t i = 0; i < input_desc->ndim(); i++) {
        if (i != dim) {
            n_iteration *= input_desc->shape()[i];
        }
    }
    size_t total = n_iteration * dim_elements;
    workspace_size += 3 * total * sizeof(uint32_t); // cur/ones/zeros values
    workspace_size += 3 * total * sizeof(int32_t);  // cur/ones/zeros indices
    workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t)); // selection
    if (sorted) {
        // Extra buffers for the sorted copy of the selection.
        workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t));
    }
    workspace_size += 5 * n_iteration * sizeof(int32_t); // per-row state/counters
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
        info, workspace_size, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
namespace {
// Orchestrate the full TopK pipeline on the Metax stream: carve workspace,
// gather the input into ordered-uint keys, run 32 bit-partition passes
// (radix select), optionally sort the selection with CUB, and scatter
// values/indices to the strided outputs.
// Fix vs. original: the two hcMalloc(&d_temp_storage, ...) calls were the
// only unchecked runtime calls in this function; they are now wrapped in
// CHECK_METAX like every sibling call.
template <size_t BLOCK_SIZE, int32_t SORT_ITEMS_PER_THREAD, typename Tdata>
infiniStatus_t launchKernel(
    const TopKInfo &info,
    Tdata *values_output, int32_t *indices_output, const Tdata *input,
    size_t k, size_t dim, bool largest, bool sorted,
    hcStream_t stream, void *workspace, size_t workspace_size) {
    if (dim >= info.ndim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    if (k == 0) {
        return INFINI_STATUS_SUCCESS; // nothing to select
    }
    if (k > info.dim_elements) {
        return INFINI_STATUS_BAD_PARAM;
    }
    size_t input_ndim = info.ndim;
    size_t output_ndim = input_ndim;
    size_t n_iteration = info.n_iteration;
    size_t dim_elements = info.dim_elements;
    // Carve the workspace; offsets must mirror the sizing in create().
    unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
    size_t workspace_offset = 0;
    size_t *input_shape_hc = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
    size_t *output_shape_hc = input_shape_hc + input_ndim;
    workspace_offset += (input_ndim + output_ndim) * sizeof(size_t);
    ptrdiff_t *input_strides_hc = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
    ptrdiff_t *output_strides_hc = input_strides_hc + input_ndim;
    workspace_offset += (input_ndim + output_ndim) * sizeof(ptrdiff_t);
    CHECK_METAX(hcMemcpyAsync(input_shape_hc, info.input_shape.data(), input_ndim * sizeof(size_t), hcMemcpyHostToDevice, stream));
    CHECK_METAX(hcMemcpyAsync(output_shape_hc, info.output_shape.data(), output_ndim * sizeof(size_t), hcMemcpyHostToDevice, stream));
    CHECK_METAX(hcMemcpyAsync(input_strides_hc, info.input_strides.data(), input_ndim * sizeof(ptrdiff_t), hcMemcpyHostToDevice, stream));
    CHECK_METAX(hcMemcpyAsync(output_strides_hc, info.output_strides.data(), output_ndim * sizeof(ptrdiff_t), hcMemcpyHostToDevice, stream));
    const int32_t total = n_iteration * dim_elements;
    uint32_t *cur_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(uint32_t);
    uint32_t *ones_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(uint32_t);
    uint32_t *zeros_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(uint32_t);
    int32_t *cur_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(int32_t);
    int32_t *ones_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(int32_t);
    int32_t *zeros_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(int32_t);
    uint32_t *sel_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * k * sizeof(uint32_t);
    int32_t *sel_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * k * sizeof(int32_t);
    uint32_t *sel_sorted_vals = nullptr;
    int32_t *sel_sorted_idx = nullptr;
    if (sorted) {
        sel_sorted_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
        workspace_offset += n_iteration * k * sizeof(uint32_t);
        sel_sorted_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
        workspace_offset += n_iteration * k * sizeof(int32_t);
    }
    int32_t *cur_n = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *rem_k = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *out_pos = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *ones_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *zeros_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    // init per-row state
    {
        size_t threads = 256;
        size_t blocks = (n_iteration + threads - 1) / threads;
        op::topk::cuda::init_row_state<<<blocks, threads, 0, stream>>>(cur_n, rem_k, out_pos, n_iteration, dim_elements, k);
    }
    // gather input -> cur
    {
        dim3 block(BLOCK_SIZE);
        dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
        op::topk::cuda::gather_rowwise<Tdata><<<grid, block, 0, stream>>>(
            input, cur_vals, cur_idx,
            n_iteration, dim_elements,
            input_ndim, dim,
            input_shape_hc, input_strides_hc);
    }
    // radix select/filter: one partition + compact pass per key bit, MSB first
    for (int bit = 31; bit >= 0; --bit) {
        {
            size_t threads = 256;
            size_t blocks = (n_iteration + threads - 1) / threads;
            op::topk::cuda::zero_row_counters<<<blocks, threads, 0, stream>>>(ones_count, zeros_count, n_iteration);
        }
        {
            dim3 block(BLOCK_SIZE);
            dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
            op::topk::cuda::partition_rowwise<BLOCK_SIZE><<<grid, block, 0, stream>>>(
                cur_vals, cur_idx,
                ones_vals, ones_idx,
                zeros_vals, zeros_idx,
                cur_n, n_iteration, dim_elements,
                bit, largest,
                ones_count, zeros_count);
        }
        {
            op::topk::cuda::decide_and_compact<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
                cur_vals, cur_idx,
                ones_vals, ones_idx,
                zeros_vals, zeros_idx,
                ones_count, zeros_count,
                cur_n, rem_k, out_pos,
                sel_vals, sel_idx,
                n_iteration, dim_elements, k);
        }
    }
    // append remaining candidates to fill the selection
    op::topk::cuda::take_remaining<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
        cur_vals, cur_idx,
        cur_n, rem_k, out_pos,
        sel_vals, sel_idx,
        n_iteration, dim_elements, k);
    // sort (CUB segmented radix sort over the per-row selections)
    const int32_t *final_idx = sel_idx;
    if (sorted) {
        std::vector<int> h_offsets(n_iteration + 1);
        for (size_t i = 0; i <= n_iteration; i++) {
            h_offsets[i] = i * k;
        }
        int *d_offsets;
        CHECK_METAX(hcMalloc(&d_offsets, (n_iteration + 1) * sizeof(int)));
        CHECK_METAX(hcMemcpy(d_offsets, h_offsets.data(), (n_iteration + 1) * sizeof(int), hcMemcpyHostToDevice));
        void *d_temp_storage = nullptr;
        size_t temp_storage_bytes = 0;
        if (!largest) {
            // First call sizes the temp storage; second call performs the sort.
            cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                                                     n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
            CHECK_METAX(hcMalloc(&d_temp_storage, temp_storage_bytes));
            cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                                                     n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
        } else {
            cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                                                               n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
            CHECK_METAX(hcMalloc(&d_temp_storage, temp_storage_bytes));
            cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                                                               n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
        }
        CHECK_METAX(hcFree(d_offsets));
        CHECK_METAX(hcFree(d_temp_storage));
        final_idx = sel_sorted_idx;
    }
    // scatter to output (strided write)
    {
        dim3 block(BLOCK_SIZE);
        dim3 grid((k + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
        op::topk::cuda::scatter_to_output<Tdata><<<grid, block, 0, stream>>>(
            input, final_idx,
            values_output, indices_output,
            n_iteration, k,
            input_ndim, dim,
            input_shape_hc, input_strides_hc,
            output_shape_hc, output_strides_hc);
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace
// Public entry point: dispatch launchKernel by dtype and block size.
// `stream_` is the backend stream handle; workspace layout is the one sized
// in create().
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *values_output,
    void *indices_output,
    const void *input,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted,
    void *stream_) const {
    hcStream_t stream = (hcStream_t)stream_;
    constexpr int ITEMS = 4;
// Instantiate launchKernel for a block size and concrete element type.
#define CALCULATE_TOPK(BLOCK_SIZE, Tdata) \
    launchKernel<BLOCK_SIZE, ITEMS, Tdata>( \
        _info, \
        (Tdata *)values_output, (int32_t *)indices_output, (const Tdata *)input, \
        k, dim, largest, sorted, \
        stream, workspace, workspace_size)
// Dtype dispatch: bf16 / f16 / f32 are supported; anything else is rejected.
#define CALCULATE_TOPK_WITH_BLOCK_SIZE(BLOCK_SIZE) \
    { \
        if (_info.dtype == INFINI_DTYPE_BF16) \
            return CALCULATE_TOPK(BLOCK_SIZE, __hpcc_bfloat16); \
        else if (_info.dtype == INFINI_DTYPE_F16) \
            return CALCULATE_TOPK(BLOCK_SIZE, half); \
        else if (_info.dtype == INFINI_DTYPE_F32) \
            return CALCULATE_TOPK(BLOCK_SIZE, float); \
        else \
            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
    }
    // The kernels assume 256 threads per block; refuse smaller devices.
    if (_opaque->internal->maxThreadsPerBlock() >= 256) {
        CALCULATE_TOPK_WITH_BLOCK_SIZE(256)
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::topk::metax
src/infiniop/ops/topk/moore/topk_moore.h
deleted
100644 → 0
View file @
6ab911c3
#ifndef __TOPK_MOORE_H__
#define __TOPK_MOORE_H__

#include "../topk_desc.h"

// Instantiate the TopK operator descriptor for the Moore backend via the
// shared DESCRIPTOR macro from topk_desc.h.
DESCRIPTOR(moore);

#endif // __TOPK_MOORE_H__
src/infiniop/ops/topk/moore/topk_moore.mu
deleted
100644 → 0
View file @
6ab911c3
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topk_moore.h"
#include <cub/block/block_radix_sort.cuh>
#include <cub/cub.cuh>
namespace op::topk::moore {
// Backend-private state: shared handle internals (device limits etc.)
// obtained from the Moore runtime handle at create() time.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

// Descriptor owns its Opaque; release it on destruction.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validate the operation via TopKInfo, precompute the workspace size, and
// allocate the Moore descriptor. The workspace layout computed here must stay
// in sync with the offsets taken in launchKernel:
//   shapes/strides copies, 3x value + 3x index scratch buffers (n*rows),
//   selection buffers (rows*k, doubled when sorted), and 5 per-row counters.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t values_output_desc,
    infiniopTensorDescriptor_t indices_output_desc,
    infiniopTensorDescriptor_t input_desc,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted) {
    auto result = TopKInfo::create(values_output_desc, indices_output_desc, input_desc, k, dim, largest, sorted);
    CHECK_RESULT(result);
    auto info = result.take();
    size_t workspace_size = 0;
    // Device-side copies of shapes and strides.
    workspace_size += (input_desc->ndim() + values_output_desc->ndim()) * (sizeof(size_t) + sizeof(ptrdiff_t));
    size_t dim_elements = input_desc->shape()[dim];
    size_t n_iteration = 1; // number of slices along every axis except `dim`
    for (size_t i = 0; i < input_desc->ndim(); i++) {
        if (i != dim) {
            n_iteration *= input_desc->shape()[i];
        }
    }
    size_t total = n_iteration * dim_elements;
    workspace_size += 3 * total * sizeof(uint32_t); // cur/ones/zeros values
    workspace_size += 3 * total * sizeof(int32_t);  // cur/ones/zeros indices
    workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t)); // selection
    if (sorted) {
        // Extra buffers for the sorted copy of the selection.
        workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t));
    }
    workspace_size += 5 * n_iteration * sizeof(int32_t); // per-row state/counters
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
        info, workspace_size, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
namespace {
// Orchestrate the full TopK pipeline on the Moore (MUSA) stream: carve
// workspace, gather the input into ordered-uint keys, run 32 bit-partition
// passes (radix select), optionally sort the selection with CUB, and scatter
// values/indices to the strided outputs.
// Fix vs. original: the two musaMalloc(&d_temp_storage, ...) calls were the
// only unchecked runtime calls in this function; they are now wrapped in
// CHECK_MOORE like every sibling call.
template <size_t BLOCK_SIZE, int32_t SORT_ITEMS_PER_THREAD, typename Tdata>
infiniStatus_t launchKernel(
    const TopKInfo &info,
    Tdata *values_output, int32_t *indices_output, const Tdata *input,
    size_t k, size_t dim, bool largest, bool sorted,
    musaStream_t stream, void *workspace, size_t workspace_size) {
    if (dim >= info.ndim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    if (k == 0) {
        return INFINI_STATUS_SUCCESS; // nothing to select
    }
    if (k > info.dim_elements) {
        return INFINI_STATUS_BAD_PARAM;
    }
    size_t input_ndim = info.ndim;
    size_t output_ndim = input_ndim;
    size_t n_iteration = info.n_iteration;
    size_t dim_elements = info.dim_elements;
    // Carve the workspace; offsets must mirror the sizing in create().
    unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
    size_t workspace_offset = 0;
    size_t *input_shape_musa = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
    size_t *output_shape_musa = input_shape_musa + input_ndim;
    workspace_offset += (input_ndim + output_ndim) * sizeof(size_t);
    ptrdiff_t *input_strides_musa = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
    ptrdiff_t *output_strides_musa = input_strides_musa + input_ndim;
    workspace_offset += (input_ndim + output_ndim) * sizeof(ptrdiff_t);
    CHECK_MOORE(musaMemcpyAsync(input_shape_musa, info.input_shape.data(), input_ndim * sizeof(size_t), musaMemcpyHostToDevice, stream));
    CHECK_MOORE(musaMemcpyAsync(output_shape_musa, info.output_shape.data(), output_ndim * sizeof(size_t), musaMemcpyHostToDevice, stream));
    CHECK_MOORE(musaMemcpyAsync(input_strides_musa, info.input_strides.data(), input_ndim * sizeof(ptrdiff_t), musaMemcpyHostToDevice, stream));
    CHECK_MOORE(musaMemcpyAsync(output_strides_musa, info.output_strides.data(), output_ndim * sizeof(ptrdiff_t), musaMemcpyHostToDevice, stream));
    const int32_t total = n_iteration * dim_elements;
    uint32_t *cur_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(uint32_t);
    uint32_t *ones_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(uint32_t);
    uint32_t *zeros_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(uint32_t);
    int32_t *cur_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(int32_t);
    int32_t *ones_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(int32_t);
    int32_t *zeros_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(int32_t);
    uint32_t *sel_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * k * sizeof(uint32_t);
    int32_t *sel_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * k * sizeof(int32_t);
    uint32_t *sel_sorted_vals = nullptr;
    int32_t *sel_sorted_idx = nullptr;
    if (sorted) {
        sel_sorted_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
        workspace_offset += n_iteration * k * sizeof(uint32_t);
        sel_sorted_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
        workspace_offset += n_iteration * k * sizeof(int32_t);
    }
    int32_t *cur_n = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *rem_k = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *out_pos = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *ones_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *zeros_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    // init per-row state
    {
        size_t threads = 256;
        size_t blocks = (n_iteration + threads - 1) / threads;
        op::topk::cuda::init_row_state<<<blocks, threads, 0, stream>>>(cur_n, rem_k, out_pos, n_iteration, dim_elements, k);
    }
    // gather input -> cur
    {
        dim3 block(BLOCK_SIZE);
        dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
        op::topk::cuda::gather_rowwise<Tdata><<<grid, block, 0, stream>>>(
            input, cur_vals, cur_idx,
            n_iteration, dim_elements,
            input_ndim, dim,
            input_shape_musa, input_strides_musa);
    }
    // radix select/filter: one partition + compact pass per key bit, MSB first
    for (int bit = 31; bit >= 0; --bit) {
        {
            size_t threads = 256;
            size_t blocks = (n_iteration + threads - 1) / threads;
            op::topk::cuda::zero_row_counters<<<blocks, threads, 0, stream>>>(ones_count, zeros_count, n_iteration);
        }
        {
            dim3 block(BLOCK_SIZE);
            dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
            op::topk::cuda::partition_rowwise<BLOCK_SIZE><<<grid, block, 0, stream>>>(
                cur_vals, cur_idx,
                ones_vals, ones_idx,
                zeros_vals, zeros_idx,
                cur_n, n_iteration, dim_elements,
                bit, largest,
                ones_count, zeros_count);
        }
        {
            op::topk::cuda::decide_and_compact<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
                cur_vals, cur_idx,
                ones_vals, ones_idx,
                zeros_vals, zeros_idx,
                ones_count, zeros_count,
                cur_n, rem_k, out_pos,
                sel_vals, sel_idx,
                n_iteration, dim_elements, k);
        }
    }
    // append remaining candidates to fill the selection
    op::topk::cuda::take_remaining<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
        cur_vals, cur_idx,
        cur_n, rem_k, out_pos,
        sel_vals, sel_idx,
        n_iteration, dim_elements, k);
    // sort (CUB segmented radix sort over the per-row selections)
    const int32_t *final_idx = sel_idx;
    if (sorted) {
        std::vector<int> h_offsets(n_iteration + 1);
        for (size_t i = 0; i <= n_iteration; i++) {
            h_offsets[i] = i * k;
        }
        int *d_offsets;
        CHECK_MOORE(musaMalloc(&d_offsets, (n_iteration + 1) * sizeof(int)));
        CHECK_MOORE(musaMemcpy(d_offsets, h_offsets.data(), (n_iteration + 1) * sizeof(int), musaMemcpyHostToDevice));
        void *d_temp_storage = nullptr;
        size_t temp_storage_bytes = 0;
        if (!largest) {
            // First call sizes the temp storage; second call performs the sort.
            cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                                                     n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
            CHECK_MOORE(musaMalloc(&d_temp_storage, temp_storage_bytes));
            cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                                                     n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
        } else {
            cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                                                               n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
            CHECK_MOORE(musaMalloc(&d_temp_storage, temp_storage_bytes));
            cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                                                               n_iteration * k, n_iteration, d_offsets, d_offsets + 1, 0, sizeof(uint32_t) * 8, stream);
        }
        CHECK_MOORE(musaFree(d_offsets));
        CHECK_MOORE(musaFree(d_temp_storage));
        final_idx = sel_sorted_idx;
    }
    // scatter to output (strided write)
    {
        dim3 block(BLOCK_SIZE);
        dim3 grid((k + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
        op::topk::cuda::scatter_to_output<Tdata><<<grid, block, 0, stream>>>(
            input, final_idx,
            values_output, indices_output,
            n_iteration, k,
            input_ndim, dim,
            input_shape_musa, input_strides_musa,
            output_shape_musa, output_strides_musa);
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace
// Public entry point: dispatch launchKernel by dtype and block size.
// `stream_` is the backend stream handle; workspace layout is the one sized
// in create().
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *values_output,
    void *indices_output,
    const void *input,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted,
    void *stream_) const {
    musaStream_t stream = (musaStream_t)stream_;
    constexpr int ITEMS = 4;
// Instantiate launchKernel for a block size and concrete element type.
#define CALCULATE_TOPK(BLOCK_SIZE, Tdata) \
    launchKernel<BLOCK_SIZE, ITEMS, Tdata>( \
        _info, \
        (Tdata *)values_output, (int32_t *)indices_output, (const Tdata *)input, \
        k, dim, largest, sorted, \
        stream, workspace, workspace_size)
// Dtype dispatch: bf16 / f16 / f32 are supported; anything else is rejected.
#define CALCULATE_TOPK_WITH_BLOCK_SIZE(BLOCK_SIZE) \
    { \
        if (_info.dtype == INFINI_DTYPE_BF16) \
            return CALCULATE_TOPK(BLOCK_SIZE, __mt_bfloat16); \
        else if (_info.dtype == INFINI_DTYPE_F16) \
            return CALCULATE_TOPK(BLOCK_SIZE, half); \
        else if (_info.dtype == INFINI_DTYPE_F32) \
            return CALCULATE_TOPK(BLOCK_SIZE, float); \
        else \
            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
    }
    // The kernels assume 256 threads per block; refuse smaller devices.
    if (_opaque->internal->maxThreadsPerBlock() >= 256) {
        CALCULATE_TOPK_WITH_BLOCK_SIZE(256)
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::topk::moore
src/infiniop/ops/topk/nvidia/topk_nvidia.cu
deleted
100644 → 0
View file @
6ab911c3
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "topk_nvidia.cuh"
#include <cub/block/block_radix_sort.cuh>
#include <cub/cub.cuh>
namespace
op
::
topk
::
nvidia
{
// Backend-private state: shared handle internals (device limits etc.)
// obtained from the NVIDIA runtime handle at create() time.
struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

// Descriptor owns its Opaque; release it on destruction.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validate the operation via TopKInfo, precompute the workspace size, and
// allocate the NVIDIA descriptor. The workspace layout computed here must
// stay in sync with the offsets taken in launchKernel:
//   shapes/strides copies, 3x value + 3x index scratch buffers (n*rows),
//   selection buffers (rows*k, doubled when sorted), and 5 per-row counters.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t values_output_desc,
    infiniopTensorDescriptor_t indices_output_desc,
    infiniopTensorDescriptor_t input_desc,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted) {
    auto result = TopKInfo::create(values_output_desc, indices_output_desc, input_desc, k, dim, largest, sorted);
    CHECK_RESULT(result);
    auto info = result.take();
    size_t workspace_size = 0;
    // Device-side copies of shapes and strides.
    workspace_size += (input_desc->ndim() + values_output_desc->ndim()) * (sizeof(size_t) + sizeof(ptrdiff_t));
    // Size the temporary scratch buffers.
    size_t dim_elements = input_desc->shape()[dim];
    size_t n_iteration = 1; // number of slices along every axis except `dim`
    for (size_t i = 0; i < input_desc->ndim(); i++) {
        if (i != dim) {
            n_iteration *= input_desc->shape()[i];
        }
    }
    size_t total = n_iteration * dim_elements;
    workspace_size += 3 * total * sizeof(uint32_t); // cur/ones/zeros values
    workspace_size += 3 * total * sizeof(int32_t);  // cur/ones/zeros indices
    workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t)); // selection
    if (sorted) {
        // Extra buffers for the sorted copy of the selection.
        workspace_size += n_iteration * k * (sizeof(uint32_t) + sizeof(int32_t));
    }
    workspace_size += 5 * n_iteration * sizeof(int32_t); // per-row state/counters
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        info, workspace_size, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
namespace {

/// @brief Runs the full radix-select top-k pipeline on the GPU.
///
/// Pipeline: (1) copy shape/stride metadata into the workspace, (2) gather the
/// strided input into a dense row-major scratch buffer as sortable uint32 keys,
/// (3) 32-round MSB-first radix select that narrows each row to its top-k
/// candidates, (4) append any remaining candidates, (5) optionally sort the k
/// selected entries per row with CUB segmented radix sort, (6) scatter values
/// and indices to the strided outputs.
///
/// @tparam BLOCK_SIZE             CUDA thread-block size for the row kernels.
/// @tparam SORT_ITEMS_PER_THREAD  Kept for interface compatibility; unused here
///                                (sorting is delegated to CUB device-wide sort).
/// @param info            Validated tensor metadata from TopKInfo::create.
/// @param values_output   Strided output buffer for the selected values.
/// @param indices_output  Strided output buffer for the selected indices.
/// @param input           Strided input tensor.
/// @param k               Number of elements to select per row.
/// @param dim             Reduction axis.
/// @param largest         True selects maxima, false minima.
/// @param sorted          True sorts each row's k results.
/// @param stream          CUDA stream all work is enqueued on.
/// @param workspace       Device scratch sized by Descriptor::create.
/// @param workspace_size  Size of `workspace` in bytes (not re-validated here).
/// @return INFINI_STATUS_BAD_PARAM on invalid dim/k, otherwise the first CUDA
///         error observed, or INFINI_STATUS_SUCCESS.
template <size_t BLOCK_SIZE, int32_t SORT_ITEMS_PER_THREAD, typename Tdata>
infiniStatus_t launchKernel(const TopKInfo &info, Tdata *values_output, int32_t *indices_output,
                            const Tdata *input, size_t k, size_t dim, bool largest, bool sorted,
                            cudaStream_t stream, void *workspace, size_t workspace_size) {
    if (dim >= info.ndim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    if (k == 0) {
        return INFINI_STATUS_SUCCESS;
    }
    if (k > info.dim_elements) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // The carving below must mirror the layout summed up in Descriptor::create;
    // workspace_size is trusted rather than re-checked.
    (void)workspace_size;

    size_t input_ndim = info.ndim;
    size_t output_ndim = input_ndim;
    size_t n_iteration = info.n_iteration;
    size_t dim_elements = info.dim_elements;

    unsigned char *workspace_ptr = reinterpret_cast<unsigned char *>(workspace);
    size_t workspace_offset = 0;

    // --- shape / stride metadata (host -> device, async on `stream`) ---
    size_t *input_shape_cuda = reinterpret_cast<size_t *>(workspace_ptr + workspace_offset);
    size_t *output_shape_cuda = input_shape_cuda + input_ndim;
    workspace_offset += (input_ndim + output_ndim) * sizeof(size_t);
    ptrdiff_t *input_strides_cuda = reinterpret_cast<ptrdiff_t *>(workspace_ptr + workspace_offset);
    ptrdiff_t *output_strides_cuda = input_strides_cuda + input_ndim;
    workspace_offset += (input_ndim + output_ndim) * sizeof(ptrdiff_t);
    CHECK_CUDA(cudaMemcpyAsync(input_shape_cuda, info.input_shape.data(),
                               input_ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream));
    CHECK_CUDA(cudaMemcpyAsync(output_shape_cuda, info.output_shape.data(),
                               output_ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream));
    CHECK_CUDA(cudaMemcpyAsync(input_strides_cuda, info.input_strides.data(),
                               input_ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream));
    CHECK_CUDA(cudaMemcpyAsync(output_strides_cuda, info.output_strides.data(),
                               output_ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream));

    // FIX: was `const int32_t total` — rows * cols can exceed INT32_MAX for
    // large tensors, which would overflow (UB) before the size_t offset math.
    const size_t total = n_iteration * dim_elements;

    // --- per-element scratch: current / ones-partition / zeros-partition ---
    uint32_t *cur_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(uint32_t);
    uint32_t *ones_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(uint32_t);
    uint32_t *zeros_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(uint32_t);
    int32_t *cur_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(int32_t);
    int32_t *ones_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(int32_t);
    int32_t *zeros_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += total * sizeof(int32_t);

    // --- selected top-k values / indices per row ---
    uint32_t *sel_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * k * sizeof(uint32_t);
    int32_t *sel_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * k * sizeof(int32_t);
    uint32_t *sel_sorted_vals = nullptr;
    int32_t *sel_sorted_idx = nullptr;
    if (sorted) {
        sel_sorted_vals = reinterpret_cast<uint32_t *>(workspace_ptr + workspace_offset);
        workspace_offset += n_iteration * k * sizeof(uint32_t);
        sel_sorted_idx = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
        workspace_offset += n_iteration * k * sizeof(int32_t);
    }

    // --- per-row counters driving the radix-select rounds ---
    int32_t *cur_n = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *rem_k = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *out_pos = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *ones_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);
    int32_t *zeros_count = reinterpret_cast<int32_t *>(workspace_ptr + workspace_offset);
    workspace_offset += n_iteration * sizeof(int32_t);

    // init per-row state (cur_n = dim_elements, rem_k = k, out_pos = 0)
    {
        size_t threads = 256;
        size_t blocks = (n_iteration + threads - 1) / threads;
        op::topk::cuda::init_row_state<<<blocks, threads, 0, stream>>>(
            cur_n, rem_k, out_pos, n_iteration, dim_elements, k);
    }
    // gather strided input -> dense (cur_vals, cur_idx)
    {
        dim3 block(BLOCK_SIZE);
        dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
        op::topk::cuda::gather_rowwise<Tdata><<<grid, block, 0, stream>>>(
            input, cur_vals, cur_idx, n_iteration, dim_elements,
            input_ndim, dim, input_shape_cuda, input_strides_cuda);
    }
    // MSB-first radix select: one partition/compact round per bit.
    for (int bit = 31; bit >= 0; --bit) {
        {
            size_t threads = 256;
            size_t blocks = (n_iteration + threads - 1) / threads;
            op::topk::cuda::zero_row_counters<<<blocks, threads, 0, stream>>>(
                ones_count, zeros_count, n_iteration);
        }
        {
            dim3 block(BLOCK_SIZE);
            dim3 grid((dim_elements + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
            op::topk::cuda::partition_rowwise<BLOCK_SIZE><<<grid, block, 0, stream>>>(
                cur_vals, cur_idx, ones_vals, ones_idx, zeros_vals, zeros_idx,
                cur_n, n_iteration, dim_elements, bit, largest, ones_count, zeros_count);
        }
        {
            op::topk::cuda::decide_and_compact<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
                cur_vals, cur_idx, ones_vals, ones_idx, zeros_vals, zeros_idx,
                ones_count, zeros_count, cur_n, rem_k, out_pos, sel_vals, sel_idx,
                n_iteration, dim_elements, k);
        }
    }
    // append whatever candidates remain after the last round
    op::topk::cuda::take_remaining<BLOCK_SIZE><<<n_iteration, BLOCK_SIZE, 0, stream>>>(
        cur_vals, cur_idx, cur_n, rem_k, out_pos, sel_vals, sel_idx,
        n_iteration, dim_elements, k);

    // optional per-row sort of the k selected entries (CUB segmented radix sort)
    const int32_t *final_idx = sel_idx;
    if (sorted) {
        std::vector<int> h_offsets(n_iteration + 1);
        for (size_t i = 0; i <= n_iteration; i++) {
            h_offsets[i] = static_cast<int>(i * k); // segment i covers [i*k, (i+1)*k)
        }
        int *d_offsets = nullptr;
        // FIX: these CUDA/CUB calls previously discarded their error codes;
        // every call is now routed through CHECK_CUDA like the rest of the
        // function. (On the error path the early return leaks the temporary
        // allocations, but the operator is failing fatally at that point.)
        CHECK_CUDA(cudaMalloc(&d_offsets, (n_iteration + 1) * sizeof(int)));
        CHECK_CUDA(cudaMemcpy(d_offsets, h_offsets.data(),
                              (n_iteration + 1) * sizeof(int), cudaMemcpyHostToDevice));
        void *d_temp_storage = nullptr;
        size_t temp_storage_bytes = 0;
        if (!largest) {
            // First call only queries temp_storage_bytes; second does the sort.
            CHECK_CUDA(cub::DeviceSegmentedRadixSort::SortPairs(
                d_temp_storage, temp_storage_bytes,
                sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                n_iteration * k, n_iteration, d_offsets, d_offsets + 1,
                0, sizeof(uint32_t) * 8, stream));
            CHECK_CUDA(cudaMalloc(&d_temp_storage, temp_storage_bytes));
            CHECK_CUDA(cub::DeviceSegmentedRadixSort::SortPairs(
                d_temp_storage, temp_storage_bytes,
                sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                n_iteration * k, n_iteration, d_offsets, d_offsets + 1,
                0, sizeof(uint32_t) * 8, stream));
        } else {
            CHECK_CUDA(cub::DeviceSegmentedRadixSort::SortPairsDescending(
                d_temp_storage, temp_storage_bytes,
                sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                n_iteration * k, n_iteration, d_offsets, d_offsets + 1,
                0, sizeof(uint32_t) * 8, stream));
            CHECK_CUDA(cudaMalloc(&d_temp_storage, temp_storage_bytes));
            CHECK_CUDA(cub::DeviceSegmentedRadixSort::SortPairsDescending(
                d_temp_storage, temp_storage_bytes,
                sel_vals, sel_sorted_vals, sel_idx, sel_sorted_idx,
                n_iteration * k, n_iteration, d_offsets, d_offsets + 1,
                0, sizeof(uint32_t) * 8, stream));
        }
        CHECK_CUDA(cudaFree(d_offsets));
        CHECK_CUDA(cudaFree(d_temp_storage));
        final_idx = sel_sorted_idx;
    }
    // scatter selected values + indices into the strided outputs
    {
        dim3 block(BLOCK_SIZE);
        dim3 grid((k + BLOCK_SIZE - 1) / BLOCK_SIZE, n_iteration);
        op::topk::cuda::scatter_to_output<Tdata><<<grid, block, 0, stream>>>(
            input, final_idx, values_output, indices_output, n_iteration, k,
            input_ndim, dim, input_shape_cuda, input_strides_cuda,
            output_shape_cuda, output_strides_cuda);
    }
    CHECK_CUDA(cudaGetLastError());
    return INFINI_STATUS_SUCCESS;
}
} // namespace
/// @brief Dispatches the top-k computation by dtype and block size.
/// @param workspace       Device scratch sized by workspaceSize().
/// @param workspace_size  Size of `workspace` in bytes.
/// @param values_output   Output buffer for selected values.
/// @param indices_output  Output buffer for selected indices (int32).
/// @param input           Input tensor data.
/// @param k, dim, largest, sorted  TopK parameters (see create()).
/// @param stream_         CUDA stream (as an opaque void*).
/// @return INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtypes,
///         INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED for tiny devices,
///         otherwise the result of launchKernel.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *values_output,
    void *indices_output,
    const void *input,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted,
    void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;
    constexpr int ITEMS = 4; // forwarded as SORT_ITEMS_PER_THREAD

#define CALCULATE_TOPK(BLOCK_SIZE, Tdata)                                        \
    launchKernel<BLOCK_SIZE, ITEMS, Tdata>(                                      \
        _info,                                                                   \
        (Tdata *)values_output, (int32_t *)indices_output, (const Tdata *)input, \
        k, dim, largest, sorted,                                                 \
        stream, workspace, workspace_size)
#define CALCULATE_TOPK_WITH_BLOCK_SIZE(BLOCK_SIZE)                  \
    {                                                               \
        if (_info.dtype == INFINI_DTYPE_BF16)                       \
            return CALCULATE_TOPK(BLOCK_SIZE, __nv_bfloat16);       \
        else if (_info.dtype == INFINI_DTYPE_F16)                   \
            return CALCULATE_TOPK(BLOCK_SIZE, half);                \
        else if (_info.dtype == INFINI_DTYPE_F32)                   \
            return CALCULATE_TOPK(BLOCK_SIZE, float);               \
        else                                                        \
            return INFINI_STATUS_BAD_TENSOR_DTYPE;                  \
    }

    if (_opaque->internal->maxThreadsPerBlock() >= 256) {
        CALCULATE_TOPK_WITH_BLOCK_SIZE(256)
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
// FIX: the helper macros were never #undef'd and leaked into the rest of the
// translation unit; the dispatch macros elsewhere in this operator all clean
// up after themselves.
#undef CALCULATE_TOPK_WITH_BLOCK_SIZE
#undef CALCULATE_TOPK
    // Unreachable: every branch above returns. Kept so compilers that do not
    // prove exhaustiveness don't warn about a missing return.
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::topk::nvidia
src/infiniop/ops/topk/nvidia/topk_nvidia.cuh
deleted
100644 → 0
View file @
6ab911c3
#ifndef __TOPK_NVIDIA_H__
#define __TOPK_NVIDIA_H__

#include "../topk_desc.h"

// Declares op::topk::nvidia::Descriptor via the shared DESCRIPTOR macro
// from topk_desc.h (same pattern as the other device backends).
DESCRIPTOR(nvidia);

#endif // __TOPK_NVIDIA_H__
src/infiniop/ops/topk/operator.cc
deleted
100644 → 0
View file @
6ab911c3
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/topk.h"
#include <vector>
#ifdef ENABLE_CPU_API
#include "cpu/topk_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#include "nvidia/topk_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/topk_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/topk_kunlun.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/topk_moore.h"
#endif
/// @brief C ABI entry point: creates a TopK descriptor for the handle's device.
/// Dispatches to the backend selected at compile time by the ENABLE_*_API flags.
/// @return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED if no backend matches.
__INFINI_C infiniStatus_t infiniopCreateTopKDescriptor(
    infiniopHandle_t handle,
    infiniopTopKDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t values_output_desc,
    infiniopTensorDescriptor_t indices_output_desc,
    infiniopTensorDescriptor_t input_desc,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted) {
    // One case per enabled backend; each forwards to that backend's create().
#define TOPK_CREATE_CASE(CASE, NAMESPACE)                                   \
    case CASE:                                                              \
        return op::topk::NAMESPACE::Descriptor::create(                     \
            handle,                                                         \
            reinterpret_cast<op::topk::NAMESPACE::Descriptor **>(desc_ptr), \
            values_output_desc,                                             \
            indices_output_desc,                                            \
            input_desc,                                                     \
            k,                                                              \
            dim,                                                            \
            largest,                                                        \
            sorted)

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        TOPK_CREATE_CASE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        TOPK_CREATE_CASE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        TOPK_CREATE_CASE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        TOPK_CREATE_CASE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
        TOPK_CREATE_CASE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
        TOPK_CREATE_CASE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
        TOPK_CREATE_CASE(INFINI_DEVICE_MOORE, moore);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef TOPK_CREATE_CASE
}
// C ABI entry point: reports the device workspace (in bytes) required by
// infiniopTopK for this descriptor. Dispatches on the descriptor's device type
// to the matching backend's workspaceSize().
__INFINI_C infiniStatus_t infiniopGetTopKWorkspaceSize(infiniopTopKDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE)                                                             \
    case CASE:                                                                           \
        *size = reinterpret_cast<op::topk::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS
    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
        GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
        GET(INFINI_DEVICE_MOORE, moore);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef GET
    // NOTE(review): unreachable — every switch path returns above. The sibling
    // dispatchers in this file omit this trailing return; presumably it is a
    // missing-return-warning silencer — confirm before removing.
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
/// @brief C ABI entry point: executes TopK with a previously created descriptor.
/// Forwards all arguments to the backend's calculate() selected by the
/// descriptor's device type.
/// @return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED if no backend matches.
__INFINI_C infiniStatus_t infiniopTopK(
    infiniopTopKDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *values_output,
    void *indices_output,
    const void *input,
    size_t k,
    size_t dim,
    bool largest,
    bool sorted,
    void *stream) {
    // One case per enabled backend; each forwards to that backend's calculate().
#define TOPK_CALC_CASE(CASE, NAMESPACE)                                                                           \
    case CASE:                                                                                                    \
        return reinterpret_cast<const op::topk::NAMESPACE::Descriptor *>(desc)                                    \
            ->calculate(workspace, workspace_size, values_output, indices_output, input, k, dim, largest, sorted, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        TOPK_CALC_CASE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        TOPK_CALC_CASE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        TOPK_CALC_CASE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        TOPK_CALC_CASE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
        TOPK_CALC_CASE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
        TOPK_CALC_CASE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
        TOPK_CALC_CASE(INFINI_DEVICE_MOORE, moore);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef TOPK_CALC_CASE
}
/// @brief C ABI entry point: destroys a TopK descriptor, dispatching the delete
/// to the backend that created it.
/// @return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED if no backend matches.
__INFINI_C infiniStatus_t infiniopDestroyTopKDescriptor(infiniopTopKDescriptor_t desc) {
    // One case per enabled backend; each deletes through the concrete type so
    // the right destructor (and Opaque cleanup) runs.
#define TOPK_DELETE_CASE(CASE, NAMESPACE)                                      \
    case CASE:                                                                 \
        delete reinterpret_cast<const op::topk::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        TOPK_DELETE_CASE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        TOPK_DELETE_CASE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        TOPK_DELETE_CASE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        TOPK_DELETE_CASE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
        TOPK_DELETE_CASE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
        TOPK_DELETE_CASE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_MOORE_API
        TOPK_DELETE_CASE(INFINI_DEVICE_MOORE, moore);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef TOPK_DELETE_CASE
}
Prev
1
…
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment