Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
a15aa367
Commit
a15aa367
authored
Nov 28, 2025
by
Zhao Shijie
Committed by
zhaoshijie
Nov 29, 2025
Browse files
issue/563 Add metax support for topkrouter
parent
cad2d45a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
120 additions
and
4 deletions
+120
-4
src/infiniop/ops/topkrouter/cuda/kernel.cuh
src/infiniop/ops/topkrouter/cuda/kernel.cuh
+4
-4
src/infiniop/ops/topkrouter/metax/topkrouter_metax.h
src/infiniop/ops/topkrouter/metax/topkrouter_metax.h
+8
-0
src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca
src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca
+93
-0
src/infiniop/ops/topkrouter/operator.cc
src/infiniop/ops/topkrouter/operator.cc
+15
-0
No files found.
src/infiniop/ops/topkrouter/cuda/kernel.cuh
View file @
a15aa367
...
@@ -6,16 +6,16 @@
...
@@ -6,16 +6,16 @@
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include <cub/cub.cuh>
#include <cuda_bf16.h>
//
#include <cuda_bf16.h>
#include <cuda_fp16.h>
//
#include <cuda_fp16.h>
#include <cuda_runtime.h>
//
#include <cuda_runtime.h>
template
<
typename
T
>
template
<
typename
T
>
inline
__device__
float
exp_func
(
T
x
)
{
inline
__device__
float
exp_func
(
T
x
)
{
float
data
;
float
data
;
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
data
=
x
;
data
=
x
;
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
__nv
_bfloat16
>
)
{
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
cuda
_bfloat16
>
)
{
data
=
__bfloat162float
(
x
);
data
=
__bfloat162float
(
x
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
data
=
__half2float
(
x
);
data
=
__half2float
(
x
);
...
...
src/infiniop/ops/topkrouter/metax/topkrouter_metax.h
0 → 100644
View file @
a15aa367
#ifndef __TOPKROUTER_METAX_H__
#define __TOPKROUTER_METAX_H__
#include "../topkrouter.h"
DESCRIPTOR
(
metax
)
#endif
src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca
0 → 100644
View file @
a15aa367
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topkrouter_metax.h"
#include <cub/block/block_reduce.cuh>
namespace op::topkrouter::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t correction_bias_desc) {
auto result = TopkrouterInfo::create(x_desc);
CHECK_RESULT(result);
auto info = result.take();
if (info.x_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
namespace {
template <int BLOCK_SIZE = 128>
infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
hcStream_t stream) {
const int block_threads = BLOCK_SIZE;
dim3 blocks(N);
dim3 threads(block_threads);
if (xtype == INFINI_DTYPE_F32) {
topkrouter_kernel<float, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
} else if (xtype == INFINI_DTYPE_F16) {
topkrouter_kernel<half, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
} else if (xtype == INFINI_DTYPE_BF16) {
topkrouter_kernel<cuda_bfloat16, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
}; // namespace
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
float *values,
int *indices,
const void *x,
const float *correction_bias,
const float routed_scaling_factor,
const size_t topk,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
size_t N = _info.N;
size_t width = _info.width; // 256
// size_t n_routed_experts = 256;
// size_t n_group = 8;
// size_t topk_group = 4;
auto cuda_stream = reinterpret_cast<hcStream_t>(stream);
if (256 == width) {
launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
} else {
return INFINI_STATUS_BAD_PARAM;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::topkrouter::metax
src/infiniop/ops/topkrouter/operator.cc
View file @
a15aa367
...
@@ -8,6 +8,9 @@
...
@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/topkrouter_nvidia.cuh"
#include "nvidia/topkrouter_nvidia.cuh"
#endif
#endif
#ifdef ENABLE_METAX_API
#include "metax/topkrouter_metax.h"
#endif
__C
infiniStatus_t
infiniopCreateTopkrouterDescriptor
(
infiniopHandle_t
handle
,
infiniopTopkrouterDescriptor_t
*
desc_ptr
,
__C
infiniStatus_t
infiniopCreateTopkrouterDescriptor
(
infiniopHandle_t
handle
,
infiniopTopkrouterDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
x_desc
,
infiniopTensorDescriptor_t
x_desc
,
...
@@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
...
@@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#endif
#endif
#ifdef ENABLE_QY_API
#ifdef ENABLE_QY_API
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
}
}
...
@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
...
@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#endif
#endif
#ifdef ENABLE_QY_API
#ifdef ENABLE_QY_API
GET
(
INFINI_DEVICE_QY
,
nvidia
);
GET
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
}
}
...
@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
...
@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#endif
#endif
#ifdef ENABLE_QY_API
#ifdef ENABLE_QY_API
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
}
}
...
@@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
...
@@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#endif
#endif
#ifdef ENABLE_QY_API
#ifdef ENABLE_QY_API
DESTROY
(
INFINI_DEVICE_QY
,
nvidia
);
DESTROY
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
DESTROY
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment