change / sglang · commit 8e09b370 (unverified)
Authored Apr 18, 2025 by Xiaoyu Zhang; committed via GitHub Apr 17, 2025
Parent commit: 53dcf388

Sgl kernel fused_moe_gate support n_shared_experts (#5440)
Showing 5 changed files with 140 additions and 38 deletions (+140 / -38):
- sgl-kernel/csrc/common_extension.cc (+2 / -1)
- sgl-kernel/csrc/moe/moe_fused_gate.cu (+81 / -28)
- sgl-kernel/include/sgl_kernel_ops.h (+8 / -2)
- sgl-kernel/python/sgl_kernel/moe.py (+18 / -2)
- sgl-kernel/tests/test_moe_fused_gate.py (+31 / -5)
sgl-kernel/csrc/common_extension.cc

```diff
@@ -146,7 +146,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

   m.def(
-      "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
+      "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
+      "n_share_experts_fusion, float routed_scaling_factor) -> "
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
```
sgl-kernel/csrc/moe/moe_fused_gate.cu

```diff
@@ -57,6 +57,8 @@ __device__ void moe_fused_gate_impl(
     int64_t num_rows,
     int64_t topk_group,
     int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor,
     Params params) {
   int tidx = threadIdx.x;
   int64_t thread_row =
@@ -65,6 +67,9 @@ __device__ void moe_fused_gate_impl(
     return;
   }

+  // Calculate topk_excluding_share_expert_fusion from topk
+  int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
+
   // Cast pointers to type T:
   auto* input_ptr = reinterpret_cast<T*>(input);
   auto* bias_ptr = reinterpret_cast<T*>(bias);
@@ -163,7 +168,7 @@ __device__ void moe_fused_gate_impl(
   ////////////////////// Topk //////////////////////
   float output_sum = 0.0f;
-  for (int k_idx = 0; k_idx < topk; ++k_idx) {
+  for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) {
     // local argmax
     T max_val = bias_chunk[0];
     int expert = first_elt_read_by_thread;
@@ -181,7 +186,7 @@ __device__ void moe_fused_gate_impl(
       max_val = static_cast<T>(-FLT_MAX);
     }
     // argmax reduce
 #pragma unroll
     for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
       T other_max =
@@ -195,36 +200,46 @@ __device__ void moe_fused_gate_impl(
       }
     }
-    if (k_idx < topk) {
     int thread_to_clear_in_group = expert / params.VPT;
     int64_t idx = topk * thread_row + k_idx;

     if (thread_group_idx == thread_to_clear_in_group) {
       int expert_to_clear_in_thread = expert % params.VPT;

       // clear the max value in the thread
       bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);

       // store output
       output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
       indices_ptr[idx] = static_cast<int32_t>(expert);
     }

-    // accumulate sum
+    // accumulate sum for all elements
     if (thread_group_idx == 0) {
       output_sum += output_ptr[idx];
     }
-    }

     __syncthreads();
   }

+  if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
+    int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
+
+    // Use round-robin to select expert
+    int64_t expert_offset = thread_row % n_share_experts_fusion;
+    indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
+
+    // Set the weight to the sum of all weights divided by routed_scaling_factor
+    output_ptr[last_idx] = output_sum / routed_scaling_factor;
+  }
+  __syncthreads();
+
   ////////////////////// Rescale Output //////////////////////
   if (thread_group_idx == 0) {
 #pragma unroll
     for (int ii = 0; ii < topk; ++ii) {
       int64_t const idx = topk * thread_row + ii;
-      output_ptr[idx] = static_cast<float>(static_cast<T>(output_ptr[idx]) / static_cast<T>(output_sum));
+      output_ptr[idx] = output_ptr[idx] / output_sum;
     }
   }
 }
@@ -257,9 +272,21 @@ __global__ void moe_fused_gate_kernel(
     int32_t* indices_ptr,
     int64_t num_rows,
     int64_t topk_group,
-    int64_t topk) {
+    int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
-  moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
+  moe_fused_gate_impl<T>(
+      input,
+      bias,
+      output_ptr,
+      indices_ptr,
+      num_rows,
+      topk_group,
+      topk,
+      n_share_experts_fusion,
+      routed_scaling_factor,
+      params);
 }

 // Macro to compute compile-time constants and launch the kernel.
@@ -277,7 +304,9 @@ __global__ void moe_fused_gate_kernel(
       indices.data_ptr<int32_t>(), \
       num_rows,                    \
       topk_group,                  \
-      topk);                       \
+      topk,                        \
+      n_share_experts_fusion,      \
+      routed_scaling_factor);      \
   dispatched = true;               \
   } while (0)
@@ -303,7 +332,9 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t num_experts,
     int64_t num_expert_group,
     int64_t topk_group,
-    int64_t topk) {
+    int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;             // e.g, for deepseek v3, this is 256
   params.VPT = num_experts / num_expert_group;  // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -312,14 +343,30 @@ __global__ void moe_fused_gate_kernel_dynamic(
   params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group);  // WARP_SIZE is fixed as 32
   params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
-  moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
+  moe_fused_gate_impl<T>(
+      input,
+      bias,
+      output_ptr,
+      indices_ptr,
+      num_rows,
+      topk_group,
+      topk,
+      n_share_experts_fusion,
+      routed_scaling_factor,
+      params);
 }

 //------------------------------------------------------------------------------
 // Host Launcher Function
 //------------------------------------------------------------------------------
 std::vector<at::Tensor>
-moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
+moe_fused_gate(
+    at::Tensor& input,
+    at::Tensor& bias,
+    int64_t num_expert_group,
+    int64_t topk_group,
+    int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -416,7 +463,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
         num_experts,
         num_expert_group,
         topk_group,
-        topk);
+        topk,
+        n_share_experts_fusion,
+        routed_scaling_factor);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -427,7 +476,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
         num_experts,
         num_expert_group,
         topk_group,
-        topk);
+        topk,
+        n_share_experts_fusion,
+        routed_scaling_factor);
   } else if (input.scalar_type() == at::kFloat) {
     moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -438,7 +489,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
         num_experts,
         num_expert_group,
         topk_group,
-        topk);
+        topk,
+        n_share_experts_fusion,
+        routed_scaling_factor);
   } else {
     TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
   }
```
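In summary, when n_share_experts_fusion > 0 the device function fills only topk - 1 slots from the routed top-k search, appends one shared expert chosen round-robin by row index (its index lies past the routed range, in [NUM_EXPERTS, NUM_EXPERTS + n_share_experts_fusion)), gives that slot the weight output_sum / routed_scaling_factor, and then divides every slot by output_sum. The sketch below restates the per-row arithmetic in plain PyTorch for readability only; the group-limited selection, sigmoid scoring, and warp-level reduction of the real kernel are omitted, and the names fused_gate_row_ref and row_idx are made up for this example.

```python
import torch


def fused_gate_row_ref(scores_row, bias_row, topk, num_experts,
                       n_share_experts_fusion, routed_scaling_factor, row_idx):
    # Only topk - 1 slots come from routed experts when fusion is enabled.
    topk_excluding = topk - (1 if n_share_experts_fusion > 0 else 0)

    # Simplified routed selection: argmax over (scores + bias), weights taken
    # from the unbiased scores (the real kernel also restricts to top groups).
    _, routed_ids = torch.topk(scores_row + bias_row, topk_excluding)
    routed_weights = scores_row[routed_ids]
    output_sum = routed_weights.sum()

    weights = torch.zeros(topk)
    ids = torch.zeros(topk, dtype=torch.int32)
    weights[:topk_excluding] = routed_weights
    ids[:topk_excluding] = routed_ids.to(torch.int32)

    if n_share_experts_fusion > 0:
        # Round-robin shared expert, indexed past the routed expert range.
        ids[-1] = num_experts + row_idx % n_share_experts_fusion
        # Pre-normalization weight of the fused shared expert.
        weights[-1] = output_sum / routed_scaling_factor

    # Rescale: routed weights now sum to 1 and the shared slot becomes
    # exactly 1 / routed_scaling_factor.
    return weights / output_sum, ids
```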
sgl-kernel/include/sgl_kernel_ops.h

```diff
@@ -200,8 +200,14 @@ void topk_softmax(
     torch::Tensor& token_expert_indices,
     torch::Tensor& gating_output);

 std::vector<at::Tensor>
-moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk);
+moe_fused_gate(
+    at::Tensor& input,
+    at::Tensor& bias,
+    int64_t num_expert_group,
+    int64_t topk_group,
+    int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor);

 /*
  * From csrc/speculative
```
sgl-kernel/python/sgl_kernel/moe.py

```diff
@@ -34,13 +34,29 @@ def topk_softmax(
     )


-def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk):
+def moe_fused_gate(
+    input_tensor,
+    bias,
+    num_expert_group,
+    topk_group,
+    topk,
+    n_share_experts_fusion=0,
+    routed_scaling_factor=0,
+):
     # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
     # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
     # as the group weight to select exerpt groups and then select topk experts within the selected groups
     # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
     # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
     # for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
+    # n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
+    # routed_scaling_factor: if > 0, the last expert will be scaled by this factor
     return torch.ops.sgl_kernel.moe_fused_gate.default(
-        input_tensor, bias, num_expert_group, topk_group, topk
+        input_tensor,
+        bias,
+        num_expert_group,
+        topk_group,
+        topk,
+        n_share_experts_fusion,
+        routed_scaling_factor,
     )
```
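Because the two new arguments default to 0, existing callers of the Python wrapper are unaffected. A minimal usage sketch is below; it assumes moe_fused_gate is importable from the sgl_kernel package (as the test file uses it), and the sizes are taken from the test's (512, 16, 8, 16) parametrization purely for illustration.

```python
import torch
from sgl_kernel import moe_fused_gate  # import path assumed from the package layout

num_experts, num_expert_group, topk_group = 512, 16, 8
topk_routed = 16
topk = topk_routed + 1  # as in the test, the fused shared expert occupies one top-k slot

logits = torch.rand(16, num_experts, dtype=torch.float16, device="cuda")
bias = torch.rand(num_experts, dtype=torch.float16, device="cuda")

weights, indices = moe_fused_gate(
    logits,
    bias,
    num_expert_group=num_expert_group,
    topk_group=topk_group,
    topk=topk,
    n_share_experts_fusion=1,    # last slot becomes the round-robin shared expert
    routed_scaling_factor=2.5,   # shared-slot weight ends up at 1 / 2.5 after rescaling
)
# indices[:, -1] should lie in [num_experts, num_experts + n_share_experts_fusion)
```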
sgl-kernel/tests/test_moe_fused_gate.py

```diff
@@ -19,13 +19,15 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
         (512, 16, 8, 16),
     ],
 )
-def test_moe_fused_gate_combined(seq_length, dtype, params):
+@pytest.mark.parametrize("n_share_experts_fusion", [0, 1, 8, 16])
+def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusion):
     num_experts, num_expert_group, topk_group, topk = params
     torch.manual_seed(seq_length)

     tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
     scores = tensor.clone()
     bias = torch.rand(num_experts).to(dtype).cuda()
+    topk = topk + min(1, n_share_experts_fusion)

     output, indices = moe_fused_gate(
         tensor,
@@ -33,6 +35,8 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
         num_expert_group=num_expert_group,
         topk_group=topk_group,
         topk=topk,
+        n_share_experts_fusion=n_share_experts_fusion,
+        routed_scaling_factor=2.5,
     )
     ref_output, ref_indices = biased_grouped_topk(
         scores,
@@ -43,8 +47,30 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
         num_expert_group=num_expert_group,
         topk_group=topk_group,
         compiled=False,
+        n_share_experts_fusion=n_share_experts_fusion,
     )

+    # When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
+    if n_share_experts_fusion > 0:
+        original_indices = indices.clone()
+        original_ref_indices = ref_indices.clone()
+
+        indices = indices[:, :-1]
+        ref_indices = ref_indices[:, :-1]
+
+        valid_min = num_experts
+        valid_max = num_experts + n_share_experts_fusion
+
+        shared_indices = original_indices[:, -1]
+        shared_ref_indices = original_ref_indices[:, -1]
+        if shared_indices is not None:
+            assert torch.all(
+                (shared_indices >= valid_min) & (shared_indices < valid_max)
+            ), f"Shared expert indices out of range: found values outside [{valid_min}, {valid_max})"
+        if shared_ref_indices is not None:
+            assert torch.all(
+                (shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max)
+            ), f"Shared expert reference indices out of range: found values outside [{valid_min}, {valid_max})"
+
     idx_check = torch.allclose(
         ref_indices.sort()[0].to(torch.int32),
         indices.sort()[0].to(torch.int32),
@@ -54,17 +80,17 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
     output_check = torch.allclose(
         ref_output.sort()[0].to(torch.float32),
         output.sort()[0].to(torch.float32),
-        rtol=1e-04,
-        atol=1e-05,
+        rtol=1e-02,
+        atol=1e-03,
     )

     assert idx_check, (
         f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
-        f"params {params}"
+        f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
     )
     assert output_check, (
         f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
-        f"params {params}"
+        f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
     )
```
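The test passes routed_scaling_factor=2.5 but only compares against the biased_grouped_topk reference; the value the kernel produces for the fused shared slot can also be worked out directly from the rescale step shown in moe_fused_gate.cu: the slot is written as output_sum / routed_scaling_factor and then divided by output_sum, so it lands at 1 / routed_scaling_factor regardless of the routed weights. A toy numeric check with assumed weights (plain PyTorch, not part of the commit):

```python
import torch

# Assumed routed weights for one row, with routed_scaling_factor = 2.5 as in the test.
routed = torch.tensor([0.30, 0.20, 0.10])
routed_scaling_factor = 2.5

output_sum = routed.sum()                    # 0.60
shared = output_sum / routed_scaling_factor  # 0.24, written into the last top-k slot
weights = torch.cat([routed, shared.reshape(1)]) / output_sum

print(weights)      # tensor([0.5000, 0.3333, 0.1667, 0.4000])
print(weights[-1])  # 0.4000 == 1 / routed_scaling_factor
```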