sglang / Commits / 8a828666

Commit 8a828666 (unverified)
Authored May 07, 2025 by Jinyan Chen; committed by GitHub on May 06, 2025

Add DeepEP to CI PR Test (#5655)

Co-authored-by: Jinyan Chen <jinyanc@nvidia.com>

Parent: aff584fa

Showing 10 changed files with 1607 additions and 3 deletions (+1607, -3)
.github/workflows/pr-test.yml                        +3    -3
.github/workflows/release-docker-deepep.yml          +36   -0
python/sglang/test/test_deepep_utils.py              +219  -0
python/sglang/test/test_utils.py                     +1    -0
scripts/ci_install_dependency_8_gpu.sh               +122  -0
test/srt/run_suite.py                                +3    -0
test/srt/test_deepep_internode.py                    +445  -0
test/srt/test_deepep_intranode.py                    +379  -0
test/srt/test_deepep_low_latency.py                  +325  -0
test/srt/test_moe_deepep_eval_accuracy_large.py      +74   -0
.github/workflows/pr-test.yml

@@ -97,7 +97,7 @@ jobs:
       - name: Install dependencies
         run: |
-          bash scripts/ci_install_dependency.sh
+          bash scripts/ci_install_dependency_8_gpu.sh

       - name: Run test
         timeout-minutes: 40
...
@@ -259,9 +259,9 @@ jobs:
   finish:
     if: always()
     needs: [
-        unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
+        unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, unit-test-backend-8-gpu,
         performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
-        accuracy-test-1-gpu, accuracy-test-2-gpu
+        accuracy-test-1-gpu, accuracy-test-2-gpu,
       ]
     runs-on: ubuntu-latest
     steps:
...
.github/workflows/release-docker-deepep.yml (new file, mode 100644)

name: Build DeepEP Docker Image

on:
  workflow_dispatch:
  schedule:
    - cron: '0 0 * * *'

jobs:
  build-dev:
    if: ${{ github.repository == 'sgl-project/sglang' }}
    runs-on: ubuntu-22.04
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Free disk space
        uses: jlumbroso/free-disk-space@main
        with:
          tool-cache: false
          docker-images: false
          android: true
          dotnet: true
          haskell: true
          large-packages: true
          swap-storage: false

      - name: Login to Docker Hub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Build and Push DeepEP Image
        run: |
          docker build . -f docker/Dockerfile.deepep -t lmsysorg/sglang:deepep --no-cache
          docker push lmsysorg/sglang:deepep
python/sglang/test/test_deepep_utils.py (new file, mode 100644)

# Copy from deepseek-ai/DeepEP/tests/test_utils.py
import os
import sys
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist


def init_dist(local_rank: int, num_local_ranks: int):
    # NOTES: you may rewrite this function with your own cluster settings
    ip = os.getenv("MASTER_ADDR", "127.0.0.1")
    port = int(os.getenv("MASTER_PORT", "8361"))
    num_nodes = int(os.getenv("WORLD_SIZE", 1))
    node_rank = int(os.getenv("RANK", 0))
    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{ip}:{port}",
        world_size=num_nodes * num_local_ranks,
        rank=node_rank * num_local_ranks + local_rank,
    )
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.cuda.set_device(local_rank)

    return (
        dist.get_rank(),
        dist.get_world_size(),
        dist.new_group(list(range(num_local_ranks * num_nodes))),
    )


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double() + 1, y.double() + 1
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return (1 - sim).item()


def per_token_cast_to_fp8(x: torch.Tensor):
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (
        x_amax / 448.0
    ).view(m, -1)


def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
    x_scales = x_scales.view(x_fp8.size(0), -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)


def inplace_unique(x: torch.Tensor, num_slots: int):
    assert x.dim() == 2
    mask = x < 0
    x_padded = x.masked_fill(mask, num_slots)
    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
    bin_count = bin_count[:, :num_slots]
    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
    x[:, :].fill_(-1)
    valid_len = min(num_slots, x.size(1))
    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]


def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
    return (scores * mask).view(num_tokens, num_experts)


def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    cache.zero_()

    # Testing
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for i in range(num_tests):
        # Record
        start_events[i].record()
        fn()
        end_events[i].record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]
    return np.average(times), np.min(times), np.max(times)


class empty_suppress:
    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass


class suppress_stdout_stderr:
    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()


def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: Optional[str] = None,
    barrier_comm_profiling: bool = False,
):
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
        ) as prof:
            for i in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = (
        prof.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    for name in kernel_names:
        assert (
            sum([name in line for line in prof_lines]) == 1
        ), f"Errors of the kernel {name} in the profiling table"

    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_times.append(float(time_str.replace(unit, "")) / scale)
                        break
                break
    return tuple(kernel_times) if is_tupled else kernel_times[0]


def hash_tensor(t: torch.Tensor):
    return t.view(torch.int64).sum().item()
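
The helpers above are what the new DeepEP tests import. A minimal usage sketch, not part of the commit, assuming a CUDA GPU and a PyTorch build with float8_e4m3fn support:

# Hypothetical usage sketch of the utilities defined above (assumes CUDA is available).
import torch

from sglang.test.test_deepep_utils import bench, calc_diff, per_token_cast_back, per_token_cast_to_fp8

x = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda")

# Round-trip through the per-token FP8 cast and measure the reconstruction error.
x_fp8, x_scales = per_token_cast_to_fp8(x)
x_back = per_token_cast_back(x_fp8, x_scales)
print(f"relative diff after FP8 round-trip: {calc_diff(x, x_back):.3e}")

# Time the cast with the CUDA-event benchmark helper (returns avg/min/max seconds).
avg_t, min_t, max_t = bench(lambda: per_token_cast_to_fp8(x))
print(f"cast time: avg={avg_t * 1e6:.2f} us, min={min_t * 1e6:.2f} us, max={max_t * 1e6:.2f} us")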
python/sglang/test/test_utils.py

@@ -66,6 +66,7 @@ DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
 )
 DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
 DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
+DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-V3-0324"
 DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST = (
     "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
 )
...
scripts/ci_install_dependency_8_gpu.sh (new file, mode 100755)

#!/bin/bash
# Install the dependency in CI.
set -euxo pipefail

export GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/
export NVSHMEM_DIR=/opt/nvshmem/install
export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
export PATH="${NVSHMEM_DIR}/bin:$PATH"
export CUDA_HOME=/usr/local/cuda

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
bash "${SCRIPT_DIR}/killall_sglang.sh"

# Clean up existing installations
pip uninstall -y flashinfer flashinfer_python sgl-kernel sglang vllm deepep || true
pip cache purge
rm -rf /root/.cache/flashinfer
if [ -d "lmms-eval" ]; then
    rm -rf lmms-eval
fi
rm -rf /root/.cache/deepep
rm -rf /usr/local/lib/python3.10/dist-packages/flashinfer*
rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
rm -rf /usr/local/lib/python3.10/dist-packages/deepep*
dpkg -r gdrcopy gdrcopy-tests libgdrapi gdrdrv-dkms || true
rm -rf /opt/gdrcopy
rm -rf /usr/local/lib/libgdrapi*
rm -rf /usr/local/include/gdrapi.h
rm -rf /opt/nvshmem
rm -rf /usr/local/lib/libnvshmem*
rm -rf /usr/local/include/nvshmem*

# Update pip
pip install --upgrade pip

# Install sgl-kernel
pip install sgl-kernel==0.1.1 --no-cache-dir

# Install the main package
pip install -e "python[all]"

# Install additional dependencies
pip install torch_memory_saver
pip install transformers==4.51.0 sentence_transformers accelerate peft pandas datasets timm torchaudio==2.6.0

# For compling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12

# For lmms_evals evaluating MMMU
git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
pip install -e lmms-eval/

# Install FlashMLA for attention backend tests
pip install git+https://github.com/deepseek-ai/FlashMLA.git

# Install system dependencies
# apt-get update && apt-get install -y libibverbs-dev infiniband-diags libmlx5-1 rdma-core openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 rdma-core-dev infiniband-diags-dev libibverbs-dev libibverbs-utils librdmacm-dev librdmacm-utils ibverbs-utils rdma-core-utils
apt install curl wget git sudo libibverbs-dev -y
apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py
wget https://github.com/Kitware/CMake/releases/download/v3.27.4/cmake-3.27.4-linux-x86_64.sh
chmod +x cmake-3.27.4-linux-x86_64.sh
./cmake-3.27.4-linux-x86_64.sh --skip-license --prefix=/usr/local
rm cmake-3.27.4-linux-x86_64.sh

# Install GDRCopy
mkdir -p /opt/gdrcopy
mkdir -p /opt/nvshmem
cd /opt/gdrcopy
git clone https://github.com/NVIDIA/gdrcopy.git .
git checkout v2.4.4
apt update
apt install -y nvidia-dkms-535
apt install -y build-essential devscripts debhelper fakeroot pkg-config dkms
apt install -y check libsubunit0 libsubunit-dev
cd packages
CUDA=/usr/local/cuda ./build-deb-packages.sh
dpkg -i gdrdrv-dkms_*.deb
dpkg -i libgdrapi_*.deb
dpkg -i gdrcopy-tests_*.deb
dpkg -i gdrcopy_*.deb

if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then
    ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so
fi
apt-get update && apt-get install -y libfabric-dev

# Clone DeepEP
git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep

# Install NVSHMEM
cd /opt/nvshmem
wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
tar -xf nvshmem_src_3.2.5-1.txz
mv nvshmem_src nvshmem
cd nvshmem
git apply /root/.cache/deepep/third-party/nvshmem.patch
NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_MPI_SUPPORT=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90
cd build
make -j$(nproc) install

# Install DeepEP
cd /root/.cache/deepep && python3 setup.py install

# Verify configuration
echo "=== NCCL Configuration ==="
nvidia-smi topo -m
nvidia-smi nvlink -s

echo "=== Verify GDRCOPY ==="
gdrcopy_copybw

echo "=== Verify NVSHMEM ==="
nvshmem-info -a
# /opt/nvshmem/bin/perftest/device/pt-to-pt/shmem_put_bw
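
After this script runs, a quick way to confirm the DeepEP build is importable is a short Python check. This is a hypothetical sketch, not part of the commit:

# Hypothetical post-install sanity check for the CI image produced by the script above.
import torch

import deep_ep  # built from /root/.cache/deepep against NVSHMEM/GDRCopy by the script above

print("deep_ep loaded from:", deep_ep.__file__)
print("CUDA available:", torch.cuda.is_available(), "| device count:", torch.cuda.device_count())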
test/srt/run_suite.py

@@ -96,6 +96,9 @@ suites = {
         TestFile("test_verl_engine.py", 64),
     ],
     "per-commit-8-gpu": [
+        TestFile("test_deepep_intranode.py", 50),
+        TestFile("test_deepep_low_latency.py", 50),
+        TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
         TestFile("test_local_attn.py", 250),
         TestFile("test_full_deepseek_v3.py", 250),
         TestFile("test_fa3.py", 30),
...
test/srt/test_deepep_internode.py (new file, mode 100644)

# Copy from deepseek-ai/DeepEP/tests/test_internode.py
import os
import time

# noinspection PyUnresolvedReferences
import deep_ep

# Test compatibility with low latency functions
import test_deepep_low_latency
import torch
import torch.distributed as dist

from sglang.test.test_deepep_utils import (
    bench,
    calc_diff,
    create_grouped_scores,
    init_dist,
    inplace_unique,
    per_token_cast_back,
    per_token_cast_to_fp8,
)


def test_main(
    num_sms: int,
    local_rank: int,
    num_local_ranks: int,
    num_ranks: int,
    num_nodes: int,
    rank: int,
    buffer: deep_ep.Buffer,
    group: dist.ProcessGroup,
):
    # Settings
    num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
        4096,
        7168,
        min(num_nodes, 4),
        8,
        (256 // num_ranks) * num_ranks,
    )
    assert num_experts % num_ranks == 0 and num_local_ranks == 8
    if local_rank == 0:
        print(
            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
            flush=True,
        )

    # Random data
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    x_e4m3 = per_token_cast_to_fp8(x)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
    group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
    group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
    masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
    topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
    topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
    topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda")
    rank_idx = topk_idx // (num_experts // num_ranks)
    rank_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rank_idx, num_ranks)
    rdma_rank_idx = rank_idx // num_local_ranks
    rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
    inplace_unique(rdma_rank_idx, num_nodes)

    # RDMA dispatch counts
    rdma_idx = topk_idx // (num_experts // num_nodes)
    rdma_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rdma_idx, num_nodes)
    num_rdma_token_sent = rdma_idx.ne(-1).sum().item()

    # Expert meta
    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
    for i in range(num_experts):
        num_tokens_per_expert[i] = (topk_idx == i).sum()
    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
    dist.all_reduce(gbl_num_tokens_per_expert, group=group)

    # Rank layout meta
    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
    num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
    token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device="cuda")
    for i in range(num_ranks):
        num_tokens_per_rank[i] = (rank_idx == i).sum()
        token_sel = (rank_idx == i).max(dim=-1)[0]
        count = token_sel.sum().item()
        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
        tokens[:count] = torch.sort(tokens[:count])[0]
        token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device="cuda")
    for i in range(num_nodes):
        num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
    is_token_in_rank = token_idx_in_rank >= 0
    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
    dist.all_reduce(gbl_num_tokens_per_rank, group=group)

    (
        ref_num_tokens_per_rank,
        ref_num_tokens_per_rdma_rank,
        ref_num_tokens_per_expert,
        ref_is_token_in_rank,
        _,
    ) = buffer.get_dispatch_layout(topk_idx, num_experts)
    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
    assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
    if local_rank == 0:
        print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
        print("", flush=True)
    group.barrier()
    time.sleep(1)

    # Config
    rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
    config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)

    # Test dispatch
    # noinspection PyShadowingNames
    def check_data(check_x, recv_gbl_rank_prefix_sum):
        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
        check_start = 0
        for i in range(num_ranks):
            check_end = recv_gbl_rank_prefix_sum[i].item()
            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
            check_start = check_end

    for previous_mode in (False, True):
        for async_mode in (False, True):
            for current_x in (x_pure_rand, x, x_e4m3):
                for with_topk in (False, True):
                    if local_rank == 0:
                        print(
                            f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
                            flush=True,
                            end="",
                        )
                    dispatch_args = {
                        "x": current_x,
                        "num_tokens_per_rank": num_tokens_per_rank,
                        "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
                        "is_token_in_rank": is_token_in_rank,
                        "num_tokens_per_expert": num_tokens_per_expert,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        dispatch_args.update(
                            {
                                "topk_idx": topk_idx,
                                "topk_weights": (
                                    topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
                                ),
                            }
                        )
                    if previous_mode:
                        dispatch_args.update({"previous_event": buffer.capture()})
                    (
                        recv_x,
                        recv_topk_idx,
                        recv_topk_weights,
                        recv_num_tokens_per_expert_list,
                        handle,
                        event,
                    ) = buffer.dispatch(**dispatch_args)
                    event.current_stream_wait() if async_mode else ()
                    recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x

                    # Checks
                    recv_gbl_rank_prefix_sum = handle[-4]
                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
                        0
                    ), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
                    assert (
                        gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
                        == recv_num_tokens_per_expert_list
                    )
                    if current_x is not x_pure_rand:
                        check_data(recv_x, recv_gbl_rank_prefix_sum)
                    if with_topk:
                        # Check `topk_idx`
                        assert (
                            recv_topk_idx.eq(-1)
                            | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))
                        ).sum().item() == recv_topk_idx.numel()
                        for i, count in enumerate(recv_num_tokens_per_expert_list):
                            assert recv_topk_idx.eq(i).sum().item() == count

                        # Check `topk_weights`
                        if current_x is not x_pure_rand:
                            recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(
                                dim=1, keepdim=True
                            ).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
                            check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)

                    # Test cached dispatch (must without top-k staffs)
                    if not with_topk:
                        dispatch_args = {
                            "x": current_x,
                            "handle": handle,
                            "config": config,
                            "async_finish": async_mode,
                        }
                        if previous_mode:
                            dispatch_args.update({"previous_event": buffer.capture()})
                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
                        event.current_stream_wait() if async_mode else ()
                        recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
                        if current_x is not x_pure_rand:
                            check_data(recv_x, recv_gbl_rank_prefix_sum)

                    # Test combine
                    combine_args = {
                        "x": recv_x,
                        "handle": handle,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        combine_args.update({"topk_weights": recv_topk_weights})
                    if previous_mode:
                        dispatch_args.update({"previous_event": buffer.capture()})
                    combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
                    event.current_stream_wait() if async_mode else ()
                    check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
                    ref_x = x_pure_rand if current_x is x_pure_rand else x
                    assert calc_diff(check_x, ref_x) < 5e-6
                    if with_topk:
                        check_topk_weights = (
                            combined_topk_weights
                            if (current_x is x_pure_rand)
                            else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
                        )
                        ref_topk_weights = (
                            topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
                        )
                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9

                    # For later tuning
                    dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
                    combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes

                    if local_rank == 0:
                        print(" passed", flush=True)
    if local_rank == 0:
        print("", flush=True)

    # Tune dispatch performance
    best_dispatch_results = None
    fp8_factor = (1 + 4 / 128) / 2
    for current_x in (x_e4m3, x):
        best_time, best_results = 1e10, None
        rdma_send_bytes = (
            (dispatch_bf16_rdma_send_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_rdma_send_bytes
        )
        nvl_recv_bytes = (
            (dispatch_bf16_nvl_recv_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_nvl_recv_bytes
        )
        for nvl_chunk_size in range(4, 33, 4):
            for rdma_chunk_size in range(4, 33, 4):
                config = deep_ep.Config(
                    num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size
                )
                tune_args = {"x": current_x, "handle": handle, "config": config}
                t = bench(lambda: buffer.dispatch(**tune_args))[0]
                if t < best_time:
                    best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
                if local_rank == 0:
                    print(
                        f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                        flush=True,
                    )
        if local_rank == 0:
            print(
                f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
                flush=True,
            )
            print("", flush=True)

        if isinstance(current_x, tuple):
            # Gather FP8 the best config from rank 0
            best_dispatch_results = torch.tensor(
                [best_results[0], best_results[1], best_results[2]],
                dtype=torch.int32,
                device="cuda",
            )
            all_best_fp8_results_list = [
                torch.zeros_like(best_dispatch_results)
                for _ in range(torch.distributed.get_world_size())
            ]
            dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
            best_dispatch_results = all_best_fp8_results_list[0].tolist()
    dispatch_config = deep_ep.Config(
        best_dispatch_results[0],
        best_dispatch_results[1],
        nvl_buffer_size,
        best_dispatch_results[2],
        rdma_buffer_size,
    )

    dispatch_args = {
        "x": x,
        "num_tokens_per_rank": num_tokens_per_rank,
        "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
        "is_token_in_rank": is_token_in_rank,
        "num_tokens_per_expert": num_tokens_per_expert,
        "config": dispatch_config if dispatch_config is not None else config,
    }
    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)

    # Tune combine performance
    best_time, best_results = 1e10, None
    for nvl_chunk_size in range(1, 5, 1):
        for rdma_chunk_size in range(8, 33, 4):
            config = deep_ep.Config(
                num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size
            )
            tune_args = {"x": recv_x, "handle": handle, "config": config}
            t = bench(lambda: buffer.combine(**tune_args))[0]
            if local_rank == 0:
                print(
                    f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                    flush=True,
                )
            if t < best_time:
                best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)

    if local_rank == 0:
        print(
            f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
            flush=True,
        )
        print("", flush=True)


# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    num_nodes = int(os.getenv("WORLD_SIZE", 1))
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    test_ll_compatibility = False
    if test_ll_compatibility:
        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9

    buffer = deep_ep.Buffer(
        group,
        int(1e9),
        int(1e9),
        low_latency_mode=test_ll_compatibility,
        num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
    )
    assert num_local_ranks == 8 and num_ranks > 8
    torch.manual_seed(rank)

    for i in (24,):
        test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
        if local_rank == 0:
            print("", flush=True)

    # Test compatibility with low latency functions
    if test_ll_compatibility:
        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
        test_deepep_low_latency.test_main(
            ll_num_tokens,
            ll_hidden,
            ll_num_experts,
            ll_num_topk,
            rank,
            num_ranks,
            group,
            buffer,
            seed=1,
        )


if __name__ == "__main__":
    num_processes = 8
    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
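
Since init_dist() reads MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and RANK from the environment (WORLD_SIZE is the node count, RANK the node index) and the script spawns eight local processes per node, a two-node launch might look like the following hypothetical sketch, which is not part of the commit; the address and node count are assumptions:

# Hypothetical two-node launch sketch for the internode test above.
import os
import subprocess

env = dict(
    os.environ,
    MASTER_ADDR="10.0.0.1",  # assumed head-node address
    MASTER_PORT="8361",
    WORLD_SIZE="2",          # number of nodes
    RANK="0",                # 0 on the first node, 1 on the second
)
# Each node runs the same command; only RANK differs.
subprocess.run(["python3", "test/srt/test_deepep_internode.py"], env=env, check=True)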
test/srt/test_deepep_intranode.py (new file, mode 100644)

# Copy from deepseek-ai/DeepEP/tests/test_intranode.py
import os
import time

# noinspection PyUnresolvedReferences
import deep_ep

# Test compatibility with low latency functions
import test_deepep_low_latency
import torch
import torch.distributed as dist

from sglang.test.test_deepep_utils import (
    bench,
    calc_diff,
    init_dist,
    inplace_unique,
    per_token_cast_back,
    per_token_cast_to_fp8,
)


def test_main(
    num_sms: int,
    local_rank: int,
    num_ranks: int,
    rank: int,
    buffer: deep_ep.Buffer,
    group: dist.ProcessGroup,
):
    # Settings
    num_tokens, hidden, num_topk, num_experts = (
        4096,
        7168,
        8,
        (256 // num_ranks) * num_ranks,
    )
    assert num_experts % num_ranks == 0
    if local_rank == 0:
        print(f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}", flush=True)

    # Random data
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    x_e4m3 = per_token_cast_to_fp8(x)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
    topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
    topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda")
    rank_idx = topk_idx // (num_experts // num_ranks)
    rank_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rank_idx, num_ranks)

    # Expert meta
    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
    for i in range(num_experts):
        num_tokens_per_expert[i] = (topk_idx == i).sum()
    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
    dist.all_reduce(gbl_num_tokens_per_expert, group=group)

    # Rank layout meta
    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
    token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device="cuda")
    for i in range(num_ranks):
        num_tokens_per_rank[i] = (rank_idx == i).sum()
        token_sel = (rank_idx == i).max(dim=-1)[0]
        count = token_sel.sum().item()
        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
        tokens[:count] = torch.sort(tokens[:count])[0]
        token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device="cuda")
    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
    is_token_in_rank = token_idx_in_rank >= 0
    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
    dist.all_reduce(gbl_num_tokens_per_rank, group=group)

    ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = (
        buffer.get_dispatch_layout(topk_idx, num_experts)
    )
    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
    if local_rank == 0:
        print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
        print("", flush=True)
    group.barrier()
    time.sleep(1)

    # Config
    nvl_buffer_size = 256
    config = deep_ep.Config(num_sms, 8, nvl_buffer_size)

    # Test dispatch
    # noinspection PyShadowingNames
    def check_data(check_x, rank_prefix_matrix):
        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
        check_start = 0
        for i in range(num_ranks):
            check_end = rank_prefix_matrix[i][rank].item()
            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
            check_start = check_end

    for previous_mode in (False, True):
        for async_mode in (False, True):
            for current_x in (x_pure_rand, x, x_e4m3):
                for with_topk in (False, True):
                    if local_rank == 0:
                        print(
                            f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
                            flush=True,
                            end="",
                        )
                    dispatch_args = {
                        "x": current_x,
                        "num_tokens_per_rank": num_tokens_per_rank,
                        "is_token_in_rank": is_token_in_rank,
                        "num_tokens_per_expert": num_tokens_per_expert,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        dispatch_args.update(
                            {
                                "topk_idx": topk_idx,
                                "topk_weights": (
                                    topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
                                ),
                            }
                        )
                    if previous_mode:
                        dispatch_args.update({"previous_event": buffer.capture()})
                    (
                        recv_x,
                        recv_topk_idx,
                        recv_topk_weights,
                        recv_num_tokens_per_expert_list,
                        handle,
                        event,
                    ) = buffer.dispatch(**dispatch_args)
                    event.current_stream_wait() if async_mode else ()
                    recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x

                    # Checks
                    rank_prefix_matrix = handle[0]
                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
                        0
                    ), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
                    assert (
                        gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
                        == recv_num_tokens_per_expert_list
                    )
                    if current_x is not x_pure_rand:
                        check_data(recv_x, rank_prefix_matrix)
                    if with_topk:
                        # Check `topk_idx`
                        assert (
                            recv_topk_idx.eq(-1)
                            | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))
                        ).sum().item() == recv_topk_idx.numel()
                        for i, count in enumerate(recv_num_tokens_per_expert_list):
                            assert recv_topk_idx.eq(i).sum().item() == count

                        # Check `topk_weights`
                        if current_x is not x_pure_rand:
                            recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(
                                dim=1, keepdim=True
                            ).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
                            check_data(recv_topk_weights, rank_prefix_matrix)

                    # Test cached dispatch (must without top-k staffs)
                    if not with_topk:
                        dispatch_args = {
                            "x": current_x,
                            "handle": handle,
                            "config": config,
                            "async_finish": async_mode,
                        }
                        if previous_mode:
                            dispatch_args.update({"previous_event": buffer.capture()})
                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
                        event.current_stream_wait() if async_mode else ()
                        recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
                        if current_x is not x_pure_rand:
                            check_data(recv_x, rank_prefix_matrix)

                    # Test combine
                    combine_args = {
                        "x": recv_x,
                        "handle": handle,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        combine_args.update({"topk_weights": recv_topk_weights})
                    if previous_mode:
                        dispatch_args.update({"previous_event": buffer.capture()})
                    combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
                    event.current_stream_wait() if async_mode else ()
                    check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
                    ref_x = x_pure_rand if current_x is x_pure_rand else x
                    assert calc_diff(check_x, ref_x) < 5e-6
                    if with_topk:
                        check_topk_weights = (
                            combined_topk_weights
                            if (current_x is x_pure_rand)
                            else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
                        )
                        ref_topk_weights = (
                            topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
                        )
                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9

                    # For later tuning
                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes

                    if local_rank == 0:
                        print(" passed", flush=True)
    if local_rank == 0:
        print("", flush=True)

    # Tune dispatch performance
    best_dispatch_results = None
    fp8_factor = (1 + 4 / 128) / 2
    for current_x in (x_e4m3, x):
        best_time, best_results = 1e10, None
        nvl_recv_bytes = (
            (dispatch_bf16_nvl_recv_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_nvl_recv_bytes
        )
        for nvl_chunk_size in range(4, 33, 4):
            config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
            tune_args = {"x": current_x, "handle": handle, "config": config}
            t = bench(lambda: buffer.dispatch(**tune_args))[0]
            if t < best_time:
                best_time, best_results = t, (num_sms, nvl_chunk_size)
            if local_rank == 0:
                print(
                    f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                    flush=True,
                )
        if local_rank == 0:
            print(
                f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
                flush=True,
            )
            print("", flush=True)

        if isinstance(current_x, tuple):
            # Gather FP8 the best config from rank 0
            best_dispatch_results = torch.tensor(
                [best_results[0], best_results[1]], dtype=torch.int32, device="cuda"
            )
            all_best_fp8_results_list = [
                torch.zeros_like(best_dispatch_results)
                for _ in range(torch.distributed.get_world_size())
            ]
            dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
            best_dispatch_results = all_best_fp8_results_list[0].tolist()
    dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size)

    dispatch_args = {
        "x": x,
        "num_tokens_per_rank": num_tokens_per_rank,
        "is_token_in_rank": is_token_in_rank,
        "num_tokens_per_expert": num_tokens_per_expert,
        "config": dispatch_config if dispatch_config is not None else config,
    }
    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)

    # Tune combine performance
    best_time, best_results = 1e10, None
    for nvl_chunk_size in range(1, 7, 1):
        config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
        tune_args = {"x": recv_x, "handle": handle, "config": config}
        t = bench(lambda: buffer.combine(**tune_args))[0]
        if local_rank == 0:
            print(
                f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                flush=True,
            )
        if t < best_time:
            best_time, best_results = t, (num_sms, nvl_chunk_size)

    if local_rank == 0:
        print(
            f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
            flush=True,
        )
        print("", flush=True)


# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    test_ll_compatibility, num_rdma_bytes = False, 0
    if test_ll_compatibility:
        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            ll_num_tokens, ll_hidden, num_ranks, ll_num_experts
        )

    buffer = deep_ep.Buffer(
        group,
        int(1e9),
        num_rdma_bytes,
        low_latency_mode=test_ll_compatibility,
        num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
    )
    torch.manual_seed(rank)

    for i in (24,):
        test_main(i, local_rank, num_ranks, rank, buffer, group)
        if local_rank == 0:
            print("", flush=True)

    # Test compatibility with low latency functions
    if test_ll_compatibility:
        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
        test_deepep_low_latency.test_main(
            ll_num_tokens,
            ll_hidden,
            ll_num_experts,
            ll_num_topk,
            rank,
            num_ranks,
            group,
            buffer,
            seed=1,
        )


if __name__ == "__main__":
    num_processes = 8
    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
test/srt/test_deepep_low_latency.py (new file, mode 100644)

# Copy from deepseek-ai/DeepEP/tests/test_low_latency.py
import random
from functools import partial

import deep_ep
import torch
import torch.distributed as dist

from sglang.test.test_deepep_utils import (
    bench,
    bench_kineto,
    calc_diff,
    hash_tensor,
    init_dist,
    per_token_cast_back,
)


def test_main(
    num_tokens: int,
    hidden: int,
    num_experts: int,
    num_topk: int,
    rank: int,
    num_ranks: int,
    group: dist.ProcessGroup,
    buffer: deep_ep.Buffer,
    seed: int = 0,
):
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)

    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: the integers greater than 256 exceeds the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, "Too many ranks (exceeding test precision limit)"

    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (rank - rank_offset)
    x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device="cuda").abs()

    # Randomly mask some positions
    for i in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    # Check dispatch correctness
    do_check = True
    hash_value, num_times = 0, 0
    for return_recv_hook in (False, True):
        for dispatch_use_fp8 in (False, True):
            num_times += 1
            for i in range((num_times % 2) + 1):
                packed_recv_x, packed_recv_count, handle, event, hook = (
                    buffer.low_latency_dispatch(
                        x,
                        topk_idx,
                        num_tokens,
                        num_experts,
                        use_fp8=dispatch_use_fp8,
                        async_finish=not return_recv_hook,
                        return_recv_hook=return_recv_hook,
                    )
                )
                hook() if return_recv_hook else event.current_stream_wait()
            packed_recv_x = (
                (packed_recv_x[0], packed_recv_x[1].contiguous())
                if dispatch_use_fp8
                else packed_recv_x
            )
            simulated_gemm_x = (
                per_token_cast_back(
                    packed_recv_x[0].view(-1, hidden),
                    packed_recv_x[1].view(-1, hidden // 128),
                ).view(packed_recv_x[0].shape)
                if dispatch_use_fp8
                else packed_recv_x.clone()
            )
            all_topk_idx = torch.empty(
                (num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda"
            )
            dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
            for i in range(num_local_experts if do_check else 0):
                expert_id = rank * num_local_experts + i
                recv_x = (
                    per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
                    if dispatch_use_fp8
                    else packed_recv_x[i]
                )
                recv_count, recv_src_info, recv_layout_range = (
                    packed_recv_count[i],
                    handle[0][i],
                    handle[1][i],
                )

                # Check expert indices
                int_mask = (2**32) - 1
                num_valid_tokens = recv_count.item()
                assert (
                    num_valid_tokens == (recv_layout_range & int_mask).sum().item()
                ), f"{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()"
                assert (
                    num_valid_tokens == (all_topk_idx == expert_id).sum().item()
                ), f"{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}"

                # Check received data
                recv_x = recv_x[:num_valid_tokens]
                recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                recv_src_info = recv_src_info[:num_valid_tokens]
                assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
                assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
                for j in range(num_ranks):
                    begin_idx, count = (recv_layout_range[j] >> 32).item(), (
                        recv_layout_range[j] & int_mask
                    ).item()
                    assert (recv_x_amin == j - rank_offset).sum().item() == (
                        all_topk_idx[j] == expert_id
                    ).sum().item()
                    assert (recv_x[begin_idx : begin_idx + count][:-128] - j).sum().item() == 0
                if dispatch_use_fp8:
                    hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                    hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                else:
                    hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

            # Check combine correctness
            for zero_copy in (False, True):
                if zero_copy:
                    buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
                out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
                combined_x, event, hook = buffer.low_latency_combine(
                    simulated_gemm_x,
                    topk_idx,
                    topk_weights,
                    handle,
                    async_finish=not return_recv_hook,
                    zero_copy=zero_copy,
                    return_recv_hook=return_recv_hook,
                    out=out,
                )
                hook() if return_recv_hook else event.current_stream_wait()
                if do_check:
                    diff = calc_diff(
                        x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1),
                        combined_x,
                    )
                    assert torch.isnan(combined_x).sum().item() == 0
                    assert diff < 1e-5, f"Error: {diff=}, {zero_copy=}"
                    hash_value ^= hash_tensor(combined_x)

    def create_test_cast_with_outliers(num_outliers):
        tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
        tmp /= tmp.abs().amax(dim=1).view(-1, 1)
        assert tmp.abs().amax().item() <= 1

        # Create some amax outliers
        for i in range(num_outliers):
            tmp[random.randint(0, num_tokens - 1)] *= 1e3
        return tmp

    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # noinspection PyShadowingNames
    def test_func(zero_copy: bool, return_recv_hook: bool):
        recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
            x,
            topk_idx,
            num_tokens,
            num_experts,
            async_finish=False,
            return_recv_hook=return_recv_hook,
        )
        large_gemm_with_hook(hook) if return_recv_hook else None
        if zero_copy:
            buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
        combined_x, event, hook = buffer.low_latency_combine(
            simulated_gemm_x,
            topk_idx,
            topk_weights,
            handle,
            zero_copy=zero_copy,
            return_recv_hook=return_recv_hook,
        )
        large_gemm_with_hook(hook) if return_recv_hook else None

    # Calculate bandwidth
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
    for i in range(num_tokens):
        num_selections = (topk_idx[i] != -1).sum().item()
        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
        num_combine_comm_bytes += num_bf16_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
    print(
        f"[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, "
        f"avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us",
        flush=True,
    )

    # Separate profiling
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(
            partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
            kernel_names=("dispatch", "combine"),
            barrier_comm_profiling=True,
            suppress_kineto_output=True,
        )
        if not return_recv_hook:
            print(
                f"[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | "
                f"Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us",
                flush=True,
            )
        else:
            print(
                f"[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | "
                f"Combine send/recv time: {combine_t * 2 * 1e6:.2f} us",
                flush=True,
            )
    return hash_value


# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        num_tokens, hidden, num_ranks, num_experts
    )
    if local_rank == 0:
        print(f"Allocating buffer size: {num_rdma_bytes / 1e6} MB ...", flush=True)
    buffer = deep_ep.Buffer(
        group,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=num_experts // num_ranks,
    )
    test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)

    do_pressure_test = False
    for seed in range(int(1e9) if do_pressure_test else 0):
        if local_rank == 0:
            print(f"Testing with seed {seed} ...", flush=True)
        ref_hash = test_main(
            num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed
        )
        for i in range(20):
            assert (
                test_main(
                    num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed
                )
                == ref_hash
            ), f"Error: seed={seed}"


if __name__ == "__main__":
    # TODO: you may modify NUMA binding for less CPU overhead
    num_processes = 8
    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
test/srt/test_moe_deepep_eval_accuracy_large.py (new file, mode 100644)

"""
Usage:
python -m unittest test_moe_deepep_eval_accuracy_large.TestMoEDeepEPEvalAccuracyLarge.test_mmlu
"""

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestMoEDeepEPEvalAccuracyLarge(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "8",
                "--enable-deepep-moe",
                "--cuda-graph-max-bs",
                "128",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=8,
            data_path=None,
            num_questions=200,
            parallel=64,
            max_new_tokens=512,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Eval accuracy of GSM8K: {metrics=}")
        self.assertGreater(metrics["accuracy"], 0.93)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        print(f"Eval accuracy of MMLU: {metrics=}")
        self.assertGreater(metrics["score"], 0.87)


if __name__ == "__main__":
    unittest.main()