Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
ac4aae48
Unverified
Commit
ac4aae48
authored
Dec 01, 2025
by
Shijie
Committed by
GitHub
Dec 01, 2025
Browse files
Merge branch 'main' into dev_topkrouter
parents
a15aa367
2f3f4076
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
316 additions
and
88 deletions
+316
-88
include/infinicore/ops/matmul.hpp
include/infinicore/ops/matmul.hpp
+2
-2
python/infinicore/__init__.py
python/infinicore/__init__.py
+2
-0
python/infinicore/device.py
python/infinicore/device.py
+16
-10
python/infinicore/nn/functional/rope.py
python/infinicore/nn/functional/rope.py
+2
-2
python/infinicore/ops/matmul.py
python/infinicore/ops/matmul.py
+3
-3
python/infinicore/tensor.py
python/infinicore/tensor.py
+27
-18
src/infinicore/nn/embedding.cc
src/infinicore/nn/embedding.cc
+3
-3
src/infinicore/nn/linear.cc
src/infinicore/nn/linear.cc
+2
-2
src/infinicore/nn/parameter.cc
src/infinicore/nn/parameter.cc
+8
-2
src/infinicore/nn/rmsnorm.cc
src/infinicore/nn/rmsnorm.cc
+0
-4
src/infinicore/nn/rope.cc
src/infinicore/nn/rope.cc
+4
-22
src/infinicore/ops/matmul/matmul.cc
src/infinicore/ops/matmul/matmul.cc
+4
-4
src/infinicore/pybind11/context.hpp
src/infinicore/pybind11/context.hpp
+1
-1
src/infinicore/pybind11/infinicore.cc
src/infinicore/pybind11/infinicore.cc
+3
-0
src/infinicore/pybind11/nn.hpp
src/infinicore/pybind11/nn.hpp
+15
-0
src/infinicore/pybind11/ops/matmul.hpp
src/infinicore/pybind11/ops/matmul.hpp
+2
-0
src/infinicore/pybind11/ops/rope.hpp
src/infinicore/pybind11/ops/rope.hpp
+0
-5
src/infinicore/tensor/copy.cc
src/infinicore/tensor/copy.cc
+27
-8
src/infinicore/utils.hpp
src/infinicore/utils.hpp
+6
-2
src/infiniop/ops/topkrouter/kunlun/kernel.h
src/infiniop/ops/topkrouter/kunlun/kernel.h
+189
-0
No files found.
include/infinicore/ops/matmul.hpp
View file @
ac4aae48
...
...
@@ -5,7 +5,7 @@
namespace
infinicore
::
op
{
Tensor
matmul
(
Tensor
a
,
Tensor
b
);
void
matmul_
(
Tensor
c
,
Tensor
a
,
Tensor
b
);
Tensor
matmul
(
Tensor
a
,
Tensor
b
,
float
alpha
=
1.0
f
);
void
matmul_
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
=
1.0
f
);
}
// namespace infinicore::op
python/infinicore/__init__.py
View file @
ac4aae48
import
contextlib
import
infinicore.context
as
context
import
infinicore.nn
as
nn
# Import context functions
...
...
@@ -60,6 +61,7 @@ from infinicore.tensor import (
__all__
=
[
# Modules.
"context"
,
"nn"
,
# Classes.
"device"
,
...
...
python/infinicore/device.py
View file @
ac4aae48
...
...
@@ -2,16 +2,20 @@ from infinicore.lib import _infinicore
class
device
:
def
__init__
(
self
,
type
=
None
,
index
=
None
):
if
type
is
None
:
type
=
"cpu"
# Public attributes describing the device
type
:
str
index
:
int
_underlying
:
_infinicore
.
Device
def
__init__
(
self
,
type
=
None
,
index
=
None
):
if
isinstance
(
type
,
device
):
self
.
type
=
type
.
type
self
.
index
=
type
.
index
return
if
type
is
None
:
type
=
"cpu"
if
":"
in
type
:
if
index
is
not
None
:
raise
ValueError
(
...
...
@@ -22,12 +26,14 @@ class device:
index
=
int
(
index
)
self
.
type
=
type
self
.
index
=
index
_type
,
_index
=
device
.
_to_infinicore_device
(
type
,
index
if
index
else
0
)
self
.
_underlying
=
_infinicore
.
Device
(
_type
,
_index
)
self
.
index
=
index
if
index
else
0
def
__getattr__
(
self
,
name
):
# Lazily construct and cache an attribute.
# such as, self._underlying .
_type
,
_index
=
device
.
_to_infinicore_device
(
self
.
type
,
self
.
index
)
setattr
(
self
,
name
,
_infinicore
.
Device
(
_type
,
_index
))
return
getattr
(
self
,
name
)
def
__repr__
(
self
):
return
f
"device(type='
{
self
.
type
}
'
{
f
', index=
{
self
.
index
}
' if self.index is not None else ''
}
)"
...
...
python/infinicore/nn/functional/rope.py
View file @
ac4aae48
...
...
@@ -5,8 +5,8 @@ from infinicore.tensor import Tensor
class
RopeAlgo
:
r
"""Different types of RoPE algorithms."""
GPT_J
=
_infinicore
.
Algo
.
GPT_J
GPT_NEOX
=
_infinicore
.
Algo
.
GPT_NEOX
GPT_J
=
_infinicore
.
RoPE
Algo
.
GPT_J
GPT_NEOX
=
_infinicore
.
RoPE
Algo
.
GPT_NEOX
def
rope
(
...
...
python/infinicore/ops/matmul.py
View file @
ac4aae48
...
...
@@ -2,10 +2,10 @@ from infinicore.lib import _infinicore
from
infinicore.tensor
import
Tensor
def
matmul
(
input
,
other
,
*
,
out
=
None
):
def
matmul
(
input
,
other
,
*
,
alpha
=
1.0
,
out
=
None
):
if
out
is
None
:
return
Tensor
(
_infinicore
.
matmul
(
input
.
_underlying
,
other
.
_underlying
))
return
Tensor
(
_infinicore
.
matmul
(
input
.
_underlying
,
other
.
_underlying
,
alpha
))
_infinicore
.
matmul_
(
out
.
_underlying
,
input
.
_underlying
,
other
.
_underlying
)
_infinicore
.
matmul_
(
out
.
_underlying
,
input
.
_underlying
,
other
.
_underlying
,
alpha
)
return
out
python/infinicore/tensor.py
View file @
ac4aae48
...
...
@@ -14,30 +14,35 @@ from .utils import (
class
Tensor
:
# Public attributes describing the device
_underlying
:
_infinicore
.
Tensor
_torch_ref
:
"torch.Tensor"
# noqa: F821
shape
:
list
[
int
]
dtype
:
infinicore
.
dtype
device
:
infinicore
.
device
def
__init__
(
self
,
underlying
,
*
,
_torch_ref
=
None
):
"""An internal method. Please do not use this directly."""
self
.
_underlying
=
underlying
self
.
_dtype
=
infinicore
.
dtype
(
self
.
_underlying
.
dtype
)
self
.
_device
=
infinicore
.
device
.
_from_infinicore_device
(
self
.
_underlying
.
device
)
self
.
_torch_ref
=
_torch_ref
@
property
def
shape
(
self
):
return
self
.
_underlying
.
shape
@
property
def
dtype
(
self
):
return
self
.
_dtype
@
property
def
device
(
self
):
return
self
.
_device
def
__getattr__
(
self
,
name
):
# Lazily construct and cache an attribute.
# such as, self.shape, self.dtype, self.device .
if
name
==
"shape"
:
setattr
(
self
,
name
,
getattr
(
self
.
_underlying
,
name
))
elif
name
==
"dtype"
:
setattr
(
self
,
name
,
infinicore
.
dtype
(
getattr
(
self
.
_underlying
,
name
)))
elif
name
==
"device"
:
setattr
(
self
,
name
,
infinicore
.
device
.
_from_infinicore_device
(
getattr
(
self
.
_underlying
,
name
)
),
)
return
getattr
(
self
,
name
)
@
property
def
ndim
(
self
):
...
...
@@ -101,6 +106,10 @@ class Tensor:
def
__add__
(
self
,
other
):
return
infinicore
.
add
(
self
,
other
)
def
__iadd__
(
self
,
other
):
infinicore
.
add
(
self
,
other
,
out
=
self
)
return
self
def
__matmul__
(
self
,
other
):
return
infinicore
.
matmul
(
self
,
other
)
...
...
src/infinicore/nn/embedding.cc
View file @
ac4aae48
...
...
@@ -36,9 +36,9 @@ Embedding::Embedding(size_t num_embeddings,
// This would require a slice operation
}
spdlog
::
debug
(
"Created Embedding module: num_embeddings={}, embedding_dim={}, dtype={}, padding_idx={}"
,
num_embeddings
,
embedding_dim
,
static_cast
<
int
>
(
dtype_
),
padding_idx_
.
has_value
()
?
std
::
to_string
(
padding_idx_
.
value
())
:
"None"
);
SPDLOG_DEBUG
(
"Created Embedding module: num_embeddings={}, embedding_dim={}, dtype={}, padding_idx={}"
,
num_embeddings
,
embedding_dim
,
static_cast
<
int
>
(
dtype_
),
padding_idx_
.
has_value
()
?
std
::
to_string
(
padding_idx_
.
value
())
:
"None"
);
}
Tensor
Embedding
::
forward
(
const
Tensor
&
indices
)
const
{
...
...
src/infinicore/nn/linear.cc
View file @
ac4aae48
...
...
@@ -22,8 +22,8 @@ Linear::Linear(size_t in_features, size_t out_features, bool bias, const DataTyp
bias_
=
Parameter
();
// Default constructed empty parameter
}
spdlog
::
debug
(
"Created Linear module: in_features={}, out_features={}, bias={}, dtype={}"
,
in_features
,
out_features
,
bias
,
static_cast
<
int
>
(
dtype_
));
SPDLOG_DEBUG
(
"Created Linear module: in_features={}, out_features={}, bias={}, dtype={}"
,
in_features
,
out_features
,
bias
,
static_cast
<
int
>
(
dtype_
));
}
Tensor
Linear
::
compute_linear
(
Tensor
&
input
)
const
{
...
...
src/infinicore/nn/parameter.cc
View file @
ac4aae48
...
...
@@ -19,7 +19,13 @@ Parameter::Parameter(
void
Parameter
::
load_blob
(
const
void
*
data
)
{
auto
buffer
=
Tensor
::
empty
(
impl_
->
shape
(),
impl_
->
dtype
(),
Device
(
Device
::
Type
::
CPU
,
0
),
true
);
std
::
memcpy
(
buffer
->
data
(),
data
,
buffer
->
nbytes
());
infinicore
::
context
::
memcpyH2D
(
impl_
->
data
(),
buffer
->
data
(),
buffer
->
nbytes
());
infinicore
::
context
::
syncStream
();
// If parameter is on CPU, use direct memcpy; otherwise use H2D
if
(
impl_
->
device
().
getType
()
==
Device
::
Type
::
CPU
)
{
infinicore
::
context
::
memcpyH2H
(
impl_
->
data
(),
buffer
->
data
(),
buffer
->
nbytes
());
}
else
{
infinicore
::
context
::
memcpyH2D
(
impl_
->
data
(),
buffer
->
data
(),
buffer
->
nbytes
());
infinicore
::
context
::
syncStream
();
}
}
}
// namespace infinicore::nn
src/infinicore/nn/rmsnorm.cc
View file @
ac4aae48
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/ops.hpp"
#include <cmath>
#include <spdlog/spdlog.h>
#include <stdexcept>
namespace
infinicore
::
nn
{
...
...
@@ -19,9 +18,6 @@ RMSNorm::RMSNorm(size_t normalized_shape, double eps, const DataType &dtype, con
// Initialize weight to ones (standard practice for RMSNorm)
auto
ones_tensor
=
Tensor
::
ones
({
normalized_shape
},
dtype_
,
device
);
weight_
->
copy_from
(
ones_tensor
);
spdlog
::
debug
(
"Created RMSNorm module: normalized_shape={}, eps={}, dtype={}"
,
normalized_shape
,
eps
,
static_cast
<
int
>
(
dtype_
));
}
Tensor
RMSNorm
::
forward
(
const
Tensor
&
x
)
const
{
...
...
src/infinicore/nn/rope.cc
View file @
ac4aae48
...
...
@@ -4,7 +4,6 @@
#include <algorithm>
#include <cmath>
#include <functional>
#include <spdlog/spdlog.h>
#include <stdexcept>
namespace
infinicore
::
nn
{
...
...
@@ -20,7 +19,6 @@ RoPE::RoPE(size_t head_dim,
theta_
(
theta
),
algo_
(
algo
),
dtype_
(
dtype
)
{
if
(
head_dim
%
2
!=
0
)
{
throw
std
::
invalid_argument
(
"head_dim must be even for RoPE, got "
+
std
::
to_string
(
head_dim
));
}
...
...
@@ -29,9 +27,6 @@ RoPE::RoPE(size_t head_dim,
// Initialize cache tables
initialize_cache
();
spdlog
::
debug
(
"Created RoPE module: head_dim={}, max_seq_len={}, theta={}, algo={}, dtype={}"
,
head_dim
,
max_seq_len
,
theta
,
static_cast
<
int
>
(
algo
),
static_cast
<
int
>
(
dtype_
));
}
void
RoPE
::
initialize_cache
()
{
...
...
@@ -42,9 +37,8 @@ void RoPE::initialize_cache() {
INFINICORE_NN_BUFFER_INIT
(
cos_cache
,
({
max_seq_len_
,
cache_dim
},
dtype_
,
device_
));
// Pre-compute sin and cos values
// The frequency calculation differs based on algorithm:
// - GPT_J: pairs are (2j, 2j+1) for cache entry j, frequency for dimension 2j is theta^(-2j/head_dim)
// - GPT_NEOX: pairs are (j, j+head_dim/2) for cache entry j, frequency for dimension j is theta^(-j/head_dim)
// Frequency generation always uses GPT-J style (theta^(-2j/head_dim)).
// The rotation algorithm (algo_) controls how dimensions are paired in the kernel.
// Compute on CPU first, then copy to device
auto
cpu_device
=
Device
(
Device
::
Type
::
CPU
,
0
);
...
...
@@ -55,20 +49,8 @@ void RoPE::initialize_cache() {
for
(
size_t
pos
=
0
;
pos
<
max_seq_len_
;
pos
++
)
{
for
(
size_t
j
=
0
;
j
<
cache_dim
;
j
++
)
{
// Compute inverse frequency based on algorithm
double
inv_freq
;
if
(
algo_
==
Algo
::
GPT_J
)
{
// GPT_J: pairs are (2j, 2j+1) for cache entry j
// Frequency for pair j: theta^(-2j/head_dim)
inv_freq
=
1.0
/
std
::
pow
(
theta_
,
2.0
*
static_cast
<
double
>
(
j
)
/
static_cast
<
double
>
(
head_dim_
));
}
else
if
(
algo_
==
Algo
::
GPT_NEOX
)
{
// GPT_NEOX: pairs are (j, j+head_dim/2) for cache entry j
// Frequency for pair j (corresponding to dimension j): theta^(-j/head_dim)
inv_freq
=
1.0
/
std
::
pow
(
theta_
,
static_cast
<
double
>
(
j
)
/
static_cast
<
double
>
(
head_dim_
));
}
else
{
throw
std
::
runtime_error
(
"Unsupported RoPE algorithm: "
+
std
::
to_string
(
static_cast
<
int
>
(
algo_
)));
}
// GPT-J style inverse frequency: theta^(-2j/head_dim)
double
inv_freq
=
1.0
/
std
::
pow
(
theta_
,
2.0
*
static_cast
<
double
>
(
j
)
/
static_cast
<
double
>
(
head_dim_
));
// Compute angle: position * inverse_frequency
double
angle
=
static_cast
<
double
>
(
pos
)
*
inv_freq
;
...
...
src/infinicore/ops/matmul/matmul.cc
View file @
ac4aae48
...
...
@@ -3,11 +3,11 @@
namespace
infinicore
::
op
{
Tensor
matmul
(
Tensor
a
,
Tensor
b
)
{
return
gemm
(
a
,
b
,
1.0
f
,
0.0
f
);
Tensor
matmul
(
Tensor
a
,
Tensor
b
,
float
alpha
)
{
return
gemm
(
a
,
b
,
alpha
,
0.0
f
);
}
void
matmul_
(
Tensor
c
,
Tensor
a
,
Tensor
b
)
{
Gemm
::
execute
(
c
,
a
,
b
,
1.0
f
,
0.0
f
);
void
matmul_
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
)
{
Gemm
::
execute
(
c
,
a
,
b
,
alpha
,
0.0
f
);
}
}
// namespace infinicore::op
src/infinicore/pybind11/context.hpp
View file @
ac4aae48
...
...
@@ -26,4 +26,4 @@ inline void bind(py::module &m) {
m
.
def
(
"sync_device"
,
&
syncDevice
,
"Synchronize the current device"
);
}
}
// namespace infinicore::context
\ No newline at end of file
}
// namespace infinicore::context
src/infinicore/pybind11/infinicore.cc
View file @
ac4aae48
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "../utils.hpp"
#include "context.hpp"
#include "device.hpp"
#include "device_event.hpp"
#include "dtype.hpp"
#include "nn.hpp"
#include "ops.hpp"
#include "tensor.hpp"
...
...
@@ -17,6 +19,7 @@ PYBIND11_MODULE(_infinicore, m) {
dtype
::
bind
(
m
);
ops
::
bind
(
m
);
tensor
::
bind
(
m
);
pybind11_nn
::
bind
(
m
);
}
}
// namespace infinicore
src/infinicore/pybind11/nn.hpp
0 → 100644
View file @
ac4aae48
#pragma once
#include <pybind11/pybind11.h>
#include "nn/rope.hpp"
namespace
py
=
pybind11
;
namespace
infinicore
::
pybind11_nn
{
// Registers the nn-layer Python bindings on module `m`. At present this only
// forwards to bind_rope, which is declared outside this view.
// NOTE(review): presumably provided by "nn/rope.hpp" included above — confirm.
inline
void
bind
(
py
::
module
&
m
)
{
bind_rope
(
m
);
}
}
// namespace infinicore::pybind11_nn
src/infinicore/pybind11/ops/matmul.hpp
View file @
ac4aae48
...
...
@@ -13,6 +13,7 @@ inline void bind_matmul(py::module &m) {
&
op
::
matmul
,
py
::
arg
(
"a"
),
py
::
arg
(
"b"
),
py
::
arg
(
"alpha"
)
=
1.0
f
,
R"doc(Matrix multiplication of two tensors.)doc"
);
m
.
def
(
"matmul_"
,
...
...
@@ -20,6 +21,7 @@ inline void bind_matmul(py::module &m) {
py
::
arg
(
"c"
),
py
::
arg
(
"a"
),
py
::
arg
(
"b"
),
py
::
arg
(
"alpha"
)
=
1.0
f
,
R"doc(In-place matrix multiplication.)doc"
);
}
...
...
src/infinicore/pybind11/ops/rope.hpp
View file @
ac4aae48
...
...
@@ -9,11 +9,6 @@ namespace py = pybind11;
namespace
infinicore
::
ops
{
inline
void
bind_rope
(
py
::
module
&
m
)
{
py
::
enum_
<
infinicore
::
nn
::
RoPE
::
Algo
>
(
m
,
"Algo"
)
.
value
(
"GPT_J"
,
infinicore
::
nn
::
RoPE
::
Algo
::
GPT_J
)
.
value
(
"GPT_NEOX"
,
infinicore
::
nn
::
RoPE
::
Algo
::
GPT_NEOX
);
m
.
def
(
"rope"
,
&
op
::
rope
,
py
::
arg
(
"x"
),
...
...
src/infinicore/tensor/copy.cc
View file @
ac4aae48
...
...
@@ -3,14 +3,15 @@
#include "infinicore/ops.hpp"
#include "infinicore/tensor.hpp"
#include <spdlog/spdlog.h>
#include <algorithm>
#include <cstring>
#include <iostream>
namespace
infinicore
{
Tensor
TensorImpl
::
to
(
Device
device
)
const
{
if
(
device
==
data_
.
memory
->
device
())
{
return
Tensor
(
const_cast
<
TensorImpl
*>
(
this
)
->
shared_from_this
());
}
else
{
std
::
shared_ptr
<
TensorImpl
>
_t
=
empty
(
meta_
.
shape
,
meta_
.
dtype
,
device
,
true
);
std
::
shared_ptr
<
TensorImpl
>
_t
=
empty
(
meta_
.
shape
,
meta_
.
dtype
,
device
);
_t
->
copy_from
(
Tensor
(
const_cast
<
TensorImpl
*>
(
this
)
->
shared_from_this
()));
return
Tensor
(
_t
);
}
...
...
@@ -20,26 +21,44 @@ void TensorImpl::copy_from(Tensor src) {
if
(
src
->
shape
()
!=
this
->
shape
())
{
throw
std
::
runtime_error
(
"Cannot copy from tensor with different shape"
);
}
if
(
this
->
device
().
getType
()
==
src
->
device
().
getType
())
{
op
::
rearrange_
(
Tensor
(
const_cast
<
TensorImpl
*>
(
this
)
->
shared_from_this
()),
src
);
if
(
this
->
device
()
==
src
->
device
())
{
// If both tensors are contiguous, use direct memcpy (much faster and avoids rearrange issues)
if
(
this
->
is_contiguous
()
&&
src
->
is_contiguous
())
{
// Use nbytes() to get the actual tensor size
size_t
copy_size
=
std
::
min
(
this
->
nbytes
(),
src
->
nbytes
());
// For CPU-to-CPU copies, use regular memcpy. For device-to-device, use D2D memcpy
if
(
this
->
device
().
getType
()
==
Device
::
Type
::
CPU
)
{
context
::
memcpyH2H
(
this
->
data
(),
src
->
data
(),
copy_size
);
}
else
{
context
::
memcpyD2D
(
this
->
data
(),
src
->
data
(),
copy_size
);
}
}
else
{
op
::
rearrange_
(
Tensor
(
const_cast
<
TensorImpl
*>
(
this
)
->
shared_from_this
()),
src
);
}
}
else
{
if
(
!
src
->
is_contiguous
())
{
src
=
src
->
contiguous
();
}
// Use nbytes() to get the actual tensor size, not the full memory size
size_t
copy_size
=
std
::
min
(
this
->
nbytes
(),
src
->
nbytes
());
if
(
this
->
device
().
getType
()
==
Device
::
Type
::
CPU
)
{
if
(
this
->
is_contiguous
())
{
context
::
memcpyD2H
(
this
->
data
(),
src
->
data
(),
this
->
data_
.
memory
->
size
()
);
context
::
memcpyD2H
(
this
->
data
(),
src
->
data
(),
copy_
size
);
}
else
{
auto
local_src
=
Tensor
::
empty
(
this
->
shape
(),
this
->
dtype
(),
this
->
device
());
context
::
memcpyD2H
(
local_src
->
data
(),
src
->
data
(),
this
->
data_
.
memory
->
size
());
op
::
rearrange_
(
Tensor
(
const_cast
<
TensorImpl
*>
(
this
)
->
shared_from_this
()),
local_src
);
}
}
else
if
(
src
->
device
().
getType
()
==
Device
::
Type
::
CPU
)
{
if
(
this
->
is_contiguous
())
{
context
::
memcpyH2D
(
this
->
data
(),
src
->
data
(),
this
->
data_
.
memory
->
size
()
);
context
::
memcpyH2D
(
this
->
data
(),
src
->
data
(),
copy_
size
);
}
else
{
auto
local_src
=
Tensor
::
empty
(
this
->
shape
(),
this
->
dtype
(),
this
->
device
());
context
::
memcpyH2D
(
local_src
->
data
(),
src
->
data
(),
this
->
data_
.
memory
->
size
()
);
context
::
memcpyH2D
(
local_src
->
data
(),
src
->
data
(),
copy_
size
);
op
::
rearrange_
(
Tensor
(
const_cast
<
TensorImpl
*>
(
this
)
->
shared_from_this
()),
local_src
);
}
}
...
...
src/infinicore/utils.hpp
View file @
ac4aae48
...
...
@@ -13,6 +13,10 @@ inline struct SpdlogInitializer {
}
else
{
spdlog
::
cfg
::
load_env_levels
(
"INFINICORE_LOG_LEVEL"
);
}
// Set pattern for logging
// Using SPDLOG_* macros enables source location support (%s and %#)
// Format: [timestamp] [level] [file:line] message
spdlog
::
set_pattern
(
"[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] [%s:%#] %v"
);
}
}
spdlog_initializer
;
...
...
@@ -21,9 +25,9 @@ inline struct SpdlogInitializer {
#define INFINICORE_CHECK_ERROR(call) \
do { \
spdlog::debug
("Entering `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
SPDLOG_DEBUG
("Entering `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`.");
\
infiniStatus_t ret = (call); \
spdlog::debug
("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
SPDLOG_DEBUG
("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`.");
\
if (ret != INFINI_STATUS_SUCCESS) { \
throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \
} \
...
...
src/infiniop/ops/topkrouter/kunlun/kernel.h
0 → 100644
View file @
ac4aae48
#ifndef __TOPKROUTER_KUNLUN_KERNEL_H__
#define __TOPKROUTER_KUNLUN_KERNEL_H__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../sort/kunlun/heap.h"
#include <float.h>
using
namespace
device
::
kunlun
::
kernel
;
/**
 * Returns e^x as a float for a scalar input of type T.
 *
 * The argument is first widened to float: float is used as-is, bfloat16_t
 * goes through __bfloat162float and half through __half2float. Any other T
 * falls back to an explicit static_cast<float>.
 *
 * @tparam T  Element type of the input value.
 * @param  x  Exponent.
 * @return exp(x) evaluated on the float-converted input.
 */
template <typename T>
inline __device__ float expf_(T x) {
    float data;
    if constexpr (std::is_same_v<T, float>) {
        data = x;
    } else if constexpr (std::is_same_v<T, bfloat16_t>) {
        data = __bfloat162float(x);
    } else if constexpr (std::is_same_v<T, half>) {
        data = __half2float(x);
    } else {
        // Without this branch `data` was read uninitialized for every other T.
        // Convert explicitly so unsupported-but-convertible types work and
        // non-convertible types fail at compile time instead of silently
        // producing garbage.
        data = static_cast<float>(x);
    }
    return exp(data);
}
/**
 * Logistic sigmoid of x, returned as float: 1 / (1 + e^(-x)).
 *
 * The negation is applied to the value in its own type T; the exponential is
 * delegated to expf_<T>.
 *
 * @tparam T  Element type of the input value.
 * @param  x  Input value.
 * @return sigmoid(x) as a float.
 */
template <typename T>
inline __device__ float sigmoidf_(T x) {
    const float exp_neg_x = expf_<T>(-x);
    return 1.0f / (1.0f + exp_neg_x);
}
/**
 * Sorts the first n entries of `x` in place, applying the same permutation
 * to `idx`.
 *
 * Builds a min-heap over (x, idx) and then heap-sorts it, with a local-memory
 * fence (mfence_lm) after each phase. Both heap helpers are declared outside
 * this file.
 * NOTE(review): the resulting order is presumably descending, as the name
 * says and as the caller (which consumes the first `topk` entries) expects —
 * confirm against sort/kunlun/heap.h.
 *
 * @tparam T    Value type being sorted.
 * @tparam TID  Index type carried alongside each value.
 * @param  x    Values to sort (modified in place).
 * @param  idx  Indices paired with `x` (modified in place).
 * @param  n    Number of elements to sort.
 */
template <typename T, typename TID>
inline __device__ void descending_sort(T *x, TID *idx, int32_t n) {
    // Phase 1: arrange the pairs as a min-heap.
    make_lm_min_heap(x, idx, n);
    mfence_lm();

    // Phase 2: heap-sort the pairs in place.
    sort_lm_min_heap(x, idx, n);
    mfence_lm();
}
/**
 * Group-limited top-k expert routing; each cluster handles one token.
 *
 * For token `cluster_id()` the kernel:
 *   1. loads that token's n_experts logits and the correction bias into
 *      shared memory;
 *   2. computes scores[i] = sigmoid(logit[i]) and a biased copy
 *      scores[i] + bias[i];
 *   3. splits the experts into N_GROUPS contiguous groups of
 *      n_experts / N_GROUPS and, per group, sums its TOPK_PER_GROUP largest
 *      biased scores;
 *   4. keeps the TOPK_GROUP groups with the largest sums;
 *   5. sorts the biased scores of the kept groups and takes the first `topk`
 *      experts;
 *   6. writes, for each chosen expert, routed_scaling_factor * score / sum,
 *      where `score` is the unbiased sigmoid score and `sum` is the total of
 *      the chosen unbiased scores (plus 1e-9 to avoid division by zero).
 *
 * Preconditions visible from the code (none are checked here):
 *   - n_experts <= MAX_EXPERTS (sizes of the shared buffers);
 *   - topk <= (n_experts / N_GROUPS) * TOPK_GROUP (size of the sorted buffers);
 *   - the cluster has at least max(N_GROUPS, TOPK_GROUP) cores and the
 *     strided score loop assumes BLOCK_THREADS cores per cluster.
 *
 * @tparam T               Element type of `input`.
 * @tparam BLOCK_THREADS   Stride used when distributing experts over cores.
 * @tparam MAX_EXPERTS     Capacity of the per-token shared buffers.
 * @tparam N_GROUPS        Number of expert groups.
 * @tparam TOPK_GROUP      Number of groups kept.
 * @tparam TOPK_PER_GROUP  Number of scores summed to rank a group.
 */
template <typename T,
          int32_t BLOCK_THREADS = 64,
          int32_t MAX_EXPERTS = 256,
          int32_t N_GROUPS = 8,
          int32_t TOPK_GROUP = 4,
          int32_t TOPK_PER_GROUP = 2>
__global__ void topkrouter_kernel(
    float *values_topk,             // output weights, shape [N, topk]
    int32_t *indices_topk,          // output expert indices, shape [N, topk]
    const T *input,                 // input logits, shape [N, n_experts]
    const float *d_correction_bias, // correction bias, shape [n_experts]
    const float routed_scaling_factor,
    const int32_t N,         // N tokens
    const int32_t n_experts, // n_experts <= MAX_EXPERTS
    const int32_t topk) {

    // One token per cluster; clusters beyond the token count have nothing to do.
    const int32_t block_idx = cluster_id();
    if (block_idx >= N) {
        return;
    }
    const int32_t thread_idx = core_id();
    const int32_t GROUP_SIZE = n_experts / N_GROUPS; // 32 in DeepSeek-V3

    __shared__ T input_shm[MAX_EXPERTS]; // logits of this cluster's token
    __shared__ float correction_bias_sm[MAX_EXPERTS];

    // Copy data into SM
    if (thread_idx == 0) {
        GM2SM_ASYNC(input + block_idx * n_experts, input_shm, n_experts * sizeof(T));
        GM2SM_ASYNC(d_correction_bias, correction_bias_sm, n_experts * sizeof(float));
    }
    sync_cluster();

    // Calculate sigmoid scores and add bias
    __shared__ float scores[MAX_EXPERTS];
    __shared__ float scores_with_bias_shm[MAX_EXPERTS];
    for (int32_t i = thread_idx; i < n_experts; i += BLOCK_THREADS) {
        float v = sigmoidf_<T>(input_shm[i]);
        scores[i] = v;
        scores_with_bias_shm[i] = v + correction_bias_sm[i];
    }
    sync_cluster();

    // Split the experts into N_GROUPS groups; for each group compute the sum
    // of its TOPK_PER_GROUP largest biased scores.
    __shared__ float values_grouped_topk_shm[N_GROUPS];
    if (thread_idx < N_GROUPS) {
        int32_t base = thread_idx * GROUP_SIZE;
        float tmp[TOPK_PER_GROUP];
        // Start from -FLT_MAX so that any real score displaces the placeholder.
#pragma unroll
        for (int32_t k = 0; k < TOPK_PER_GROUP; ++k) {
            tmp[k] = -FLT_MAX;
        }

        // Maintain a descending queue of TOPK_PER_GROUP entries.
        for (int32_t i = 0; i < GROUP_SIZE; ++i) {
            float val = scores_with_bias_shm[base + i];
            // Insert into the queue, shifting smaller entries down.
            if (val > tmp[TOPK_PER_GROUP - 1]) {
                int pos = TOPK_PER_GROUP - 1;
                while (pos > 0 && val > tmp[pos - 1]) {
                    tmp[pos] = tmp[pos - 1];
                    --pos;
                }
                tmp[pos] = val;
            }
        }

        float group_sum = 0.f;
        for (int32_t k = 0; k < TOPK_PER_GROUP; ++k) {
            group_sum += tmp[k];
        }
        values_grouped_topk_shm[thread_idx] = group_sum;
    }
    sync_cluster();

    // Select TOPK_GROUP in N_GROUPS according to sum of TOPK_PER_GROUP values in each group
    __shared__ int32_t indices_group[TOPK_GROUP];
    if (thread_idx == 0) {
        float values_group[TOPK_GROUP];
        int32_t indices_tmp[TOPK_GROUP];
        // Initialize values to -FLT_MAX and indices to -1.
#pragma unroll
        for (int32_t k = 0; k < TOPK_GROUP; ++k) {
            values_group[k] = -FLT_MAX;
            indices_tmp[k] = -1;
        }

        // Same insertion scheme as above, carrying the group index along.
        for (int32_t i = 0; i < N_GROUPS; i++) {
            float val = values_grouped_topk_shm[i];
            if (val > values_group[TOPK_GROUP - 1]) {
                int32_t pos = TOPK_GROUP - 1;
                while (pos > 0 && val > values_group[pos - 1]) {
                    values_group[pos] = values_group[pos - 1];
                    indices_tmp[pos] = indices_tmp[pos - 1];
                    pos--;
                }
                values_group[pos] = val;
                indices_tmp[pos] = i;
            }
        }

        // Publish the selected group ids to shared memory.
#pragma unroll
        for (int32_t k = 0; k < TOPK_GROUP; ++k) {
            indices_group[k] = indices_tmp[k];
        }
    }
    sync_cluster();

    // Gather the selected groups' biased scores (values_group_select) and
    // their original expert indices (indices_group_select).
    __shared__ float values_group_select[MAX_EXPERTS];
    __shared__ int32_t indices_group_select[MAX_EXPERTS];
    if (thread_idx < TOPK_GROUP) {
        int32_t group_id = indices_group[thread_idx];
        // Per-core staging buffer for one group's scores.
        float local_buffer[GROUP_SIZE];
        // Copy every score of the selected group into local_buffer.
        __builtin_memcpy(local_buffer, scores_with_bias_shm + group_id * GROUP_SIZE, GROUP_SIZE * sizeof(float));
        mfence_lm();
        // Write back into the compacted shared buffer, one slot per selected group.
        __builtin_memcpy(values_group_select + thread_idx * GROUP_SIZE, local_buffer, GROUP_SIZE * sizeof(float));
        // Record the original expert index of every copied score.
        for (int32_t i = 0; i < GROUP_SIZE; i++) {
            indices_group_select[thread_idx * GROUP_SIZE + i] = group_id * GROUP_SIZE + i;
        }
    }
    sync_cluster();

    // Global topk and copy to GM
    if (thread_idx == 0) {
        int32_t len = GROUP_SIZE * TOPK_GROUP;
        float values[len];
        int32_t indices[len];
        // COPY to LM
        __builtin_memcpy(values, values_group_select, len * sizeof(float));
        __builtin_memcpy(indices, indices_group_select, len * sizeof(int32_t));
        mfence_lm();

        // Sort
        descending_sort<float, int32_t>(values, indices, len);

        // Last scaling: normalize the unbiased scores of the chosen experts.
        // The 1e-9 seed keeps the division below well-defined.
        float sum = 1e-9f;
        for (int32_t k = 0; k < topk; k++) {
            int32_t idx = indices[k];
            sum += scores[idx];
        }
        for (int32_t k = 0; k < topk; k++) {
            int32_t idx = indices[k];
            values[k] = routed_scaling_factor * scores[idx] / sum;
        }
        mfence_lm();

        // COPY to GM
        // Each cluster owns row `block_idx` of the [N, topk] outputs. Without
        // the row offset every token's result landed in row 0.
        LM2GM_ASYNC(values, values_topk + block_idx * topk, topk * sizeof(float));
        LM2GM_ASYNC(indices, indices_topk + block_idx * topk, topk * sizeof(int32_t));
    }
    sync_cluster();
}
#endif // __TOPKROUTER_KUNLUN_KERNEL_H__
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment