sglang / Commits / e3e0bc50

[Feature] SPMD for SGLang + Verl (#3852)

Unverified commit e3e0bc50, authored Mar 01, 2025 by fzyzcjy, committed by GitHub on Feb 28, 2025.
Parent: bac414ab

Showing 19 changed files with 890 additions and 202 deletions (+890, -202).
.github/workflows/pr-test.yml                                 +6    -0
examples/runtime/engine/offline_batch_inference_torchrun.py   +81   -0
python/sglang/srt/entrypoints/engine.py                       +15   -4
python/sglang/srt/entrypoints/verl_engine.py                  +145  -0
python/sglang/srt/managers/data_parallel_controller.py        +6    -2
python/sglang/srt/managers/io_struct.py                       +2    -0
python/sglang/srt/managers/scheduler.py                       +3    -2
python/sglang/srt/managers/tp_worker.py                       +4    -1
python/sglang/srt/model_executor/model_runner.py              +42   -3
python/sglang/srt/models/gemma.py                             +0    -6
python/sglang/srt/models/gemma2.py                            +0    -7
python/sglang/srt/server_args.py                              +8    -0
python/sglang/srt/utils.py                                    +0    -1
python/sglang/test/runners.py                                 +231  -133
python/sglang/test/test_programs.py                           +1    -1
test/lang/test_srt_backend.py                                 +1    -1
test/srt/models/test_generation_models.py                     +20   -41
test/srt/test_update_weights_from_tensor.py                   +28   -0
test/srt/test_verl_engine.py                                  +297  -0
.github/workflows/pr-test.yml

@@ -149,6 +149,12 @@ jobs:
           cd test/srt
           python3 test_update_weights_from_distributed.py

+      - name: Test VerlEngine
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_verl_engine.py
+
       - name: Test expert parallelism (EP=2)
         timeout-minutes: 10
         run: |
examples/runtime/engine/offline_batch_inference_torchrun.py (new file)

import datetime
import os
import sys

from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine


def run():
    """
    Example command:
    ```
    torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
    ```
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    def _log(text):
        t = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"[{t}] [rank={rank}] {text}")

    _log(
        f'start {local_rank=} {rank=} {world_size=} {sys.argv=} '
        f'{os.environ.get("CUDA_VISIBLE_DEVICES")}'
    )

    tp_size = 4
    dp_size = 2
    assert world_size == tp_size * dp_size

    device_mesh_kwargs = dict(
        mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
    )
    device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
    _log(f"{device_mesh_cpu=}")

    tp_rank = device_mesh_cpu.get_local_rank("tp")
    dp_rank = device_mesh_cpu.get_local_rank("dp")
    _log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")

    model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
    # model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9  # test large models
    # model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8

    for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
        if k in os.environ:
            del os.environ[k]

    fragment = VerlEngine(
        model_path=model_name,
        mem_fraction_static=mem_fraction_static,
        device_mesh_cpu=device_mesh_cpu["tp"],
        base_gpu_id=dp_rank,
        gpu_id_step=dp_size,
        port=30000,
        # for DeepSeek-V2-Lite + DP Attention
        # enable_dp_attention=True, port=30000 + dp_rank * 100,
    )
    _log(f"{fragment=}")

    prompt_all = [
        ["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
        ["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
    ]
    prompt = prompt_all[dp_rank]

    output = fragment.generate(
        prompt=prompt,
        sampling_params=dict(max_new_tokens=16, temperature=0.0),
    )
    _log(f"{prompt=} {output=}")

    fragment.shutdown()
    _log(f"End script")


if __name__ == "__main__":
    run()
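The example above relies on the GPU-assignment formula this commit introduces in the scheduler launch path, gpu_id = base_gpu_id + (tp_rank % tp_size_per_node) * gpu_id_step. The short sketch below is not part of the commit; it simply evaluates that formula for the example's tp_size=4, dp_size=2 settings, assuming a single 8-GPU node (so tp_size_per_node == tp_size), to show how the two data-parallel replicas end up interleaved across GPUs.

# Illustration only (not from the commit): evaluate the new GPU-id formula for
# the settings used in the example above.
tp_size, dp_size = 4, 2

for dp_rank in range(dp_size):
    base_gpu_id, gpu_id_step = dp_rank, dp_size  # as passed to VerlEngine above
    gpu_ids = [base_gpu_id + tp_rank * gpu_id_step for tp_rank in range(tp_size)]
    print(f"dp_rank={dp_rank}: GPUs {gpu_ids}")

# Expected output:
#   dp_rank=0: GPUs [0, 2, 4, 6]
#   dp_rank=1: GPUs [1, 3, 5, 7]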
python/sglang/srt/entrypoints/engine.py

@@ -271,10 +271,18 @@ class Engine:
                 self.tokenizer_manager.update_weights_from_distributed(obj, None)
             )

-    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
-        """Update weights from distributed source."""
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]],
+        load_format: Optional[str] = None,
+        flush_cache: bool = True,
+    ):
+        """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
+        to avoid duplicated operations such as clearing cache."""
         obj = UpdateWeightsFromTensorReqInput(
-            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
+            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
+            load_format=load_format,
+            flush_cache=flush_cache,
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(

@@ -384,7 +392,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
     )
     for tp_rank in tp_rank_range:
         reader, writer = mp.Pipe(duplex=False)
-        gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
+        gpu_id = (
+            server_args.base_gpu_id
+            + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+        )
         proc = mp.Process(
             target=run_scheduler_process,
             args=(server_args, port_args, gpu_id, tp_rank, None, writer),
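For context, here is a hedged usage sketch of the extended `Engine.update_weights_from_tensor` signature; it is not part of the commit. Weights are pushed in several calls and the cache is flushed only after the last one, the same pattern `VerlEngine.update_weights_from_tensor` (next file) uses via `flush_cache=tensor_index == len(named_tensors) - 1`. The model name and chunk size are arbitrary choices, and it assumes the HF state-dict keys are the names `model.load_weights()` expects.

# Sketch only: chunked weight update against an in-process engine.
import sglang as sgl
from transformers import AutoModelForCausalLM

MODEL = "meta-llama/Llama-3.2-1B-Instruct"  # arbitrary small model for illustration
engine = sgl.Engine(model_path=MODEL)
hf_model = AutoModelForCausalLM.from_pretrained(MODEL)

named_tensors = list(hf_model.state_dict().items())
chunk_size = 16  # arbitrary
chunks = [
    named_tensors[i : i + chunk_size] for i in range(0, len(named_tensors), chunk_size)
]
for i, chunk in enumerate(chunks):
    engine.update_weights_from_tensor(
        chunk,
        load_format=None,                    # default path: model.load_weights(...)
        flush_cache=(i == len(chunks) - 1),  # flush the cache only after the last chunk
    )
engine.shutdown()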
python/sglang/srt/entrypoints/verl_engine.py (new file)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor

from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.server import Engine
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj


class VerlEngine:
    def __init__(
        self,
        device_mesh_cpu: DeviceMesh,
        nnodes: int = 1,
        **kwargs,
    ):
        self._device_mesh_cpu = device_mesh_cpu
        self._tp_rank = device_mesh_cpu.get_local_rank()
        self._tp_size = device_mesh_cpu.size()
        tp_size_per_node = self._tp_size // nnodes
        node_rank = self._tp_rank // tp_size_per_node
        first_rank_in_node = self._tp_rank % tp_size_per_node == 0

        if first_rank_in_node:
            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
            self._engine = Engine(
                **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
            )
        else:
            self._engine = None

        dist.barrier(group=self._device_mesh_cpu.get_group())

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be a file name, a url, or base64 encoded string.
        # See also python/sglang/srt/utils.py:load_image.
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[str], str]] = None,
    ) -> Dict:
        """
        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.
        """
        if self._tp_rank == 0:
            output = self._engine.generate(
                prompt=prompt,
                sampling_params=sampling_params,
                input_ids=input_ids,
                image_data=image_data,
                return_logprob=return_logprob,
                logprob_start_len=logprob_start_len,
                top_logprobs_num=top_logprobs_num,
                lora_path=lora_path,
                custom_logit_processor=custom_logit_processor,
            )
        else:
            output = None

        # Most naive implementation, can extract tensor and send via gloo if too slow
        [output] = broadcast_pyobj(
            data=[output],
            rank=self._tp_rank,
            dist_group=self._device_mesh_cpu.get_group(),
            src=self._device_mesh_cpu.mesh[0].item(),
        )

        return output

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, torch.Tensor]],
        load_format: Optional[str] = None,
    ):
        # Most naive implementation, can optimize a lot if it is bottleneck
        for tensor_index, (name, tensor) in enumerate(named_tensors):
            serialized_tensor = MultiprocessingSerializer.serialize(
                _preprocess_tensor_for_update_weights(tensor)
            )

            if self._tp_rank == 0:
                gathered_serialized_tensors = [None for _ in range(self._tp_size)]
            else:
                gathered_serialized_tensors = None
            dist.gather_object(
                obj=serialized_tensor,
                object_gather_list=gathered_serialized_tensors,
                dst=self._device_mesh_cpu.mesh.tolist()[0],
                group=self._device_mesh_cpu.get_group(),
            )

            if self._tp_rank == 0:
                self._engine.update_weights_from_tensor(
                    named_tensors=[
                        (
                            name,
                            LocalSerializedTensor(values=gathered_serialized_tensors),
                        )
                    ],
                    load_format=load_format,
                    flush_cache=tensor_index == len(named_tensors) - 1,
                )

    def release_memory_occupation(self):
        if self._tp_rank == 0:
            self._engine.release_memory_occupation()

    def resume_memory_occupation(self):
        if self._tp_rank == 0:
            self._engine.resume_memory_occupation()

    def shutdown(self):
        if self._engine is not None:
            self._engine.shutdown()


def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
    if isinstance(tensor, DTensor):
        return tensor.full_tensor()
    return tensor
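To make the SPMD contract of `VerlEngine` concrete, here is a hedged, TP-only driver sketch that is not part of the commit: every rank constructs the engine over a CPU device mesh and calls `generate()`; internally only TP rank 0 drives the underlying `Engine`, and the result is broadcast so all ranks return the same output dict. It assumes a `torchrun --nproc_per_node=2` launch on a 2-GPU node; the model name and file name are arbitrary examples.

# Sketch: run with `torchrun --nproc_per_node=2 verl_engine_demo.py` (file name arbitrary).
import os

from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine

tp_size = int(os.environ["WORLD_SIZE"])
device_mesh_cpu = init_device_mesh(
    "cpu", mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"]
)

engine = VerlEngine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    mem_fraction_static=0.3,
    device_mesh_cpu=device_mesh_cpu["tp"],
)

# Every rank makes the same call; only tp_rank 0 actually runs the Engine and
# the output is broadcast back to all ranks.
output = engine.generate(
    prompt="The capital of France is",
    sampling_params=dict(max_new_tokens=8, temperature=0.0),
)
if int(os.environ["RANK"]) == 0:
    print(output)

engine.shutdown()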
python/sglang/srt/managers/data_parallel_controller.py

@@ -121,7 +121,7 @@ class DataParallelController:
                 args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
             threads.append(thread)
-            base_gpu_id += server_args.tp_size
+            base_gpu_id += server_args.tp_size * server_args.gpu_id_step

         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:

@@ -177,7 +177,11 @@ class DataParallelController:
             rank_port_args.nccl_port = port_args.nccl_port

             reader, writer = mp.Pipe(duplex=False)
-            gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
+            gpu_id = (
+                server_args.base_gpu_id
+                + base_gpu_id
+                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+            )
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
python/sglang/srt/managers/io_struct.py

@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput:
 @dataclass
 class UpdateWeightsFromTensorReqInput:
     serialized_named_tensors: bytes  # indeed Dict[str, torch.Tensor]
+    load_format: Optional[str]
+    flush_cache: bool


 @dataclass
python/sglang/srt/managers/scheduler.py

@@ -1760,8 +1760,9 @@ class Scheduler:
         success, message = self.tp_worker.update_weights_from_tensor(recv_req)
         # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
         if success:
-            flash_cache_success = self.flush_cache()
-            assert flash_cache_success, "Cache flush failed after updating weights"
+            if recv_req.flush_cache:
+                flash_cache_success = self.flush_cache()
+                assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromTensorReqOutput(success, message)
python/sglang/srt/managers/tp_worker.py

@@ -205,7 +205,10 @@ class TpModelWorker:
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
-            MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
+            named_tensors=MultiprocessingSerializer.deserialize(
+                recv_req.serialized_named_tensors
+            ),
+            load_format=recv_req.load_format,
         )
         return success, message
python/sglang/srt/model_executor/model_runner.py

@@ -17,7 +17,8 @@ import gc
 import json
 import logging
 import time
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist

@@ -56,10 +57,12 @@ from sglang.srt.mem_cache.memory_pool import (
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,

@@ -514,8 +517,21 @@ class ModelRunner:
             logger.error(error_msg)
             return False, error_msg

-    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
-        self.model.load_weights(named_tensors)
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
+        load_format: Optional[str] = None,
+    ):
+        named_tensors = [
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            for name, tensor in named_tensors
+        ]
+        if load_format == "direct":
+            _model_load_weights_direct(self.model, named_tensors)
+        elif load_format is None:
+            self.model.load_weights(named_tensors)
+        else:
+            raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"

     def get_weights_by_name(

@@ -836,3 +852,26 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
+
+
+def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
+    params_dict = dict(model.named_parameters())
+    for name, tensor in named_tensors:
+        default_weight_loader(params_dict[name], tensor)
+
+
+def _unwrap_tensor(tensor, tp_rank):
+    if isinstance(tensor, LocalSerializedTensor):
+        return tensor.get(tp_rank)
+    return tensor
+
+
+@dataclass
+class LocalSerializedTensor:
+    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
+    The i-th element in the list corresponds to i-th rank's GPU."""
+
+    values: List[bytes]
+
+    def get(self, rank: int):
+        return MultiprocessingSerializer.deserialize(self.values[rank])
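The `LocalSerializedTensor` dataclass added above carries one serialized handle per TP rank, and each scheduler process picks out its own entry with `get(tp_rank)`. The sketch below is purely conceptual and not from the commit: it uses plain `pickle` as a stand-in for `MultiprocessingSerializer` (which ships tensor handles rather than tensor data) just to show the per-rank indexing contract.

# Conceptual sketch only: the "list of per-rank handles" contract behind
# LocalSerializedTensor, with pickle standing in for MultiprocessingSerializer.
import pickle
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class _SketchSerializedTensor:
    values: List[bytes]  # values[i] belongs to TP rank i

    def get(self, rank: int):
        return pickle.loads(self.values[rank])


tp_size = 2
# Pretend each rank serialized its own shard of a weight.
per_rank_shards = [torch.full((2, 2), float(rank)) for rank in range(tp_size)]
packed = _SketchSerializedTensor(values=[pickle.dumps(t) for t in per_rank_shards])

# In ModelRunner, _unwrap_tensor(tensor, tp_rank) resolves to the rank's own shard:
for tp_rank in range(tp_size):
    shard = packed.get(tp_rank)
    print(tp_rank, shard[0, 0].item())  # rank 0 sees 0.0, rank 1 sees 1.0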
python/sglang/srt/models/gemma.py

@@ -336,12 +336,6 @@ class GemmaForCausalLM(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
             loaded_params.add(name)
-        unloaded_params = params_dict.keys() - loaded_params
-        if unloaded_params:
-            raise RuntimeError(
-                "Some weights are not initialized from checkpoints: "
-                f"{unloaded_params}"
-            )


 EntryClass = GemmaForCausalLM
python/sglang/srt/models/gemma2.py

@@ -437,12 +437,5 @@ class Gemma2ForCausalLM(nn.Module):
             weight_loader(param, loaded_weight)
             loaded_params.add(name)
-
-        unloaded_params = params_dict.keys() - loaded_params
-        if unloaded_params:
-            raise RuntimeError(
-                "Some weights are not initialized from checkpoints: "
-                f"{unloaded_params}"
-            )


 EntryClass = Gemma2ForCausalLM
python/sglang/srt/server_args.py

@@ -82,6 +82,7 @@ class ServerArgs:
     dist_timeout: Optional[int] = None  # timeout for torch.distributed
     download_dir: Optional[str] = None
     base_gpu_id: int = 0
+    gpu_id_step: int = 1

     # Logging
     log_level: str = "info"

@@ -552,6 +553,12 @@ class ServerArgs:
             default=ServerArgs.base_gpu_id,
             help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
         )
+        parser.add_argument(
+            "--gpu-id-step",
+            type=int,
+            default=ServerArgs.gpu_id_step,
+            help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
+        )

         # Logging
         parser.add_argument(

@@ -957,6 +964,7 @@ class ServerArgs:
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
+        assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

         if isinstance(self.lora_paths, list):
             lora_paths = self.lora_paths
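As a hedged illustration of the new `gpu_id_step` argument (not part of the commit): combined with `base_gpu_id`, it lets several engine instances share one node with interleaved GPU IDs, following the scheduler-side formula gpu_id = base_gpu_id + (tp_rank % tp_size_per_node) * gpu_id_step. The model name is an arbitrary example; a second instance launched from another process with `base_gpu_id=0` and the same step would take the even-numbered GPUs.

# Sketch: a tp_size=4 engine pinned to GPUs 1, 3, 5, 7 of an 8-GPU node.
import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    tp_size=4,
    base_gpu_id=1,  # start at GPU 1
    gpu_id_step=2,  # then skip every other GPU: 1, 3, 5, 7
)
print(engine.generate("Hello", sampling_params={"max_new_tokens": 4}))
engine.shutdown()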
python/sglang/srt/utils.py

@@ -1386,7 +1386,6 @@ def get_ip() -> str:
 def get_open_port() -> int:
-
     port = os.getenv("SGLANG_PORT")
     if port is not None:
         while True:
python/sglang/test/runners.py

@@ -21,9 +21,9 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCausalLM

-from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
+from sglang.srt.server import Engine
+from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

 DEFAULT_PROMPTS = [
     "Apple is red. Banana is Yellow. " * 800 + "Apple is",

@@ -95,9 +95,11 @@ class HFRunner:
         torch_dtype: torch.dtype,
         model_type: str = "generation",
         output_str_only: bool = False,
+        trust_remote_code: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
+        self.trust_remote_code = trust_remote_code

         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()

@@ -130,7 +132,7 @@ class HFRunner:
             self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
-                trust_remote_code=False,
+                trust_remote_code=self.trust_remote_code,
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":

@@ -147,7 +149,11 @@ class HFRunner:
             ).cuda()
         else:
             raise Exception(f"Unrecognized model type {self.model_type}")

-        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
+        self.tokenizer = get_tokenizer(
+            model_path,
+            torch_dtype=torch.dtype,
+            trust_remote_code=self.trust_remote_code,
+        )

         # Run forward
         while True:

@@ -157,74 +163,15 @@ class HFRunner:
             if prompts is not None:
                 if self.model_type == "generation":
-                    output_strs = []
-                    top_input_logprobs = []
-                    top_output_logprobs = []
-                    for i, p in enumerate(prompts):
-                        if isinstance(p, str):
-                            input_ids = self.tokenizer.encode(
-                                p, return_tensors="pt"
-                            ).cuda()
-                        else:
-                            input_ids = torch.tensor([p], device="cuda")
-
-                        if lora_paths is not None and lora_paths[i] is not None:
-                            from peft import PeftModel
-
-                            self.model = PeftModel.from_pretrained(
-                                self.base_model,
-                                lora_paths[i],
-                                torch_dtype=torch_dtype,
-                                is_trainable=False,
-                            )
-                        else:
-                            self.model = self.base_model
-
-                        outputs = self.model.generate(
-                            input_ids,
-                            do_sample=False,
-                            temperature=None,
-                            top_p=None,
-                            max_new_tokens=max_new_tokens,
-                            return_dict_in_generate=True,
-                            output_scores=(not self.output_str_only),
-                        )
-
-                        text = self.tokenizer.decode(
-                            outputs[0][0][len(input_ids[0]) :],
-                            skip_special_tokens=True,
-                        )
-                        # Check if the text is empty or only whitespace.
-                        if not text.strip():
-                            raise ValueError(
-                                "Received an empty text response. Please verify your input or model configuration."
-                            )
-                        output_strs.append(text)
-
-                        if not self.output_str_only:
-                            # outputs.scores: (num_token, 1, vocab_size)
-                            top_output_logprobs.append(
-                                [
-                                    get_top_logprobs(
-                                        logits[0], NUM_TOP_LOGPROBS
-                                    ).tolist()
-                                    for logits in outputs.scores
-                                ]
-                            )
-                            del outputs
-
-                            input_logits = self.model.forward(input_ids).logits[0]
-                            top_input_logprobs.append(
-                                get_top_logprobs(
-                                    input_logits, NUM_TOP_LOGPROBS
-                                ).tolist()
-                            )
-                            del input_logits
-
                     out_queue.put(
-                        ModelOutput(
-                            output_strs=output_strs,
-                            top_input_logprobs=top_input_logprobs,
-                            top_output_logprobs=top_output_logprobs,
+                        self.forward_generation_raw(
+                            prompts=prompts,
+                            max_new_tokens=max_new_tokens,
+                            base_model=self.base_model,
+                            tokenizer=self.tokenizer,
+                            lora_paths=lora_paths,
+                            torch_dtype=torch_dtype,
+                            output_str_only=self.output_str_only,
                         )
                     )

@@ -269,6 +216,79 @@ class HFRunner:
         self.model_proc.terminate()
         self.in_queue = self.out_queue = None

+    @staticmethod
+    def forward_generation_raw(
+        prompts: Union[List[str], List[torch.Tensor]],
+        max_new_tokens,
+        base_model,
+        tokenizer,
+        lora_paths,
+        torch_dtype: torch.dtype,
+        output_str_only: bool,
+    ) -> ModelOutput:
+        output_strs = []
+        top_input_logprobs = []
+        top_output_logprobs = []
+        for i, p in enumerate(prompts):
+            if isinstance(p, str):
+                input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
+            else:
+                input_ids = torch.tensor([p], device="cuda")
+
+            if lora_paths is not None and lora_paths[i] is not None:
+                from peft import PeftModel
+
+                model = PeftModel.from_pretrained(
+                    base_model,
+                    lora_paths[i],
+                    torch_dtype=torch_dtype,
+                    is_trainable=False,
+                )
+            else:
+                model = base_model
+
+            outputs = model.generate(
+                input_ids,
+                do_sample=False,
+                temperature=None,
+                top_p=None,
+                max_new_tokens=max_new_tokens,
+                return_dict_in_generate=True,
+                output_scores=(not output_str_only),
+            )
+
+            text = tokenizer.decode(
+                outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
+            )
+            # Check if the text is empty or only whitespace.
+            if not text.strip():
+                raise ValueError(
+                    "Received an empty text response. Please verify your input or model configuration."
+                )
+            output_strs.append(text)
+
+            if not output_str_only:
+                # outputs.scores: (num_token, 1, vocab_size)
+                top_output_logprobs.append(
+                    [
+                        get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
+                        for logits in outputs.scores
+                    ]
+                )
+                del outputs
+
+                input_logits = model.forward(input_ids).logits[0]
+                top_input_logprobs.append(
+                    get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
+                )
+                del input_logits
+
+        return ModelOutput(
+            output_strs=output_strs,
+            top_input_logprobs=top_input_logprobs,
+            top_output_logprobs=top_output_logprobs,
+        )
+

 class SRTRunner:
     def __init__(

@@ -284,6 +304,7 @@ class SRTRunner:
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
         mem_fraction_static: float = 0.65,
+        trust_remote_code: bool = False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"

@@ -293,7 +314,7 @@ class SRTRunner:
             dtype=get_dtype_str(torch_dtype),
             port=port,
             mem_fraction_static=mem_fraction_static,
-            trust_remote_code=False,
+            trust_remote_code=trust_remote_code,
             is_embedding=not self.is_generation,
             lora_paths=lora_paths,
             max_loras_per_batch=max_loras_per_batch,

@@ -301,7 +322,7 @@ class SRTRunner:
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
         )
-        self.tokenizer = get_tokenizer(model_path)
+        self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)

     def forward(
         self,

@@ -310,54 +331,11 @@ class SRTRunner:
         lora_paths=None,
     ):
         if self.is_generation:
-            # the return value contains logprobs from prefill
-            output_strs = []
-            top_input_logprobs = []
-            top_output_logprobs = []
-            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            for i, prompt in enumerate(prompts):
-                response = self.engine.generate(
-                    prompt,
-                    lora_path=lora_paths[i] if lora_paths else None,
-                    sampling_params=sampling_params,
-                    return_logprob=True,
-                    logprob_start_len=0,
-                    top_logprobs_num=NUM_TOP_LOGPROBS,
-                )
-                text = response["text"]
-                # Check if the text is empty or only whitespace.
-                if not text.strip():
-                    raise ValueError(
-                        "Received an empty text response. Please verify your input or model configuration."
-                    )
-                output_strs.append(text)
-                top_input_logprobs.append(
-                    [
-                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                        for x in response["meta_info"]["input_top_logprobs"][1:]
-                    ]
-                    + [
-                        [
-                            tup[0]
-                            for tup in response["meta_info"]["output_top_logprobs"][0][
-                                :NUM_TOP_LOGPROBS
-                            ]
-                        ]
-                    ]
-                )
-                top_output_logprobs.append(
-                    [
-                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                        for x in response["meta_info"]["output_top_logprobs"]
-                    ]
-                )
-
-            return ModelOutput(
-                output_strs=output_strs,
-                top_input_logprobs=top_input_logprobs,
-                top_output_logprobs=top_output_logprobs,
+            return self.forward_generation_raw(
+                prompts=prompts,
+                max_new_tokens=max_new_tokens,
+                lora_paths=lora_paths,
+                engine=self.engine,
             )
         else:
             response = self.engine.encode(prompts)

@@ -379,18 +357,11 @@ class SRTRunner:
         only return output strings and no logprobs
         """
         if self.is_generation:
-            # the return value contains logprobs from prefill
-            output_strs = []
-            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            response = self.engine.generate(
-                prompts,
-                lora_path=lora_paths if lora_paths else None,
-                sampling_params=sampling_params,
-            )
-            output_strs = [r["text"] for r in response]
-
-            return ModelOutput(
-                output_strs=output_strs,
+            return self.batch_forward_generation_raw(
+                prompts=prompts,
+                max_new_tokens=max_new_tokens,
+                lora_paths=lora_paths,
+                engine=self.engine,
             )
         else:
             response = self.engine.encode(prompts)

@@ -408,6 +379,84 @@ class SRTRunner:
         self.engine.shutdown()
         del self.engine

+    @staticmethod
+    def forward_generation_raw(
+        prompts: Union[List[str], List[torch.Tensor]],
+        max_new_tokens,
+        lora_paths,
+        engine,
+    ):
+        # the return value contains logprobs from prefill
+        output_strs = []
+        top_input_logprobs = []
+        top_output_logprobs = []
+        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+        for i, prompt in enumerate(prompts):
+            response = engine.generate(
+                prompt,
+                lora_path=lora_paths[i] if lora_paths else None,
+                sampling_params=sampling_params,
+                return_logprob=True,
+                logprob_start_len=0,
+                top_logprobs_num=NUM_TOP_LOGPROBS,
+            )
+            text = response["text"]
+            # Check if the text is empty or only whitespace.
+            if not text.strip():
+                raise ValueError(
+                    "Received an empty text response. Please verify your input or model configuration."
+                )
+            output_strs.append(text)
+            top_input_logprobs.append(
+                [
+                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                    for x in response["meta_info"]["input_top_logprobs"][1:]
+                ]
+                + [
+                    [
+                        tup[0]
+                        for tup in response["meta_info"]["output_top_logprobs"][0][
+                            :NUM_TOP_LOGPROBS
+                        ]
+                    ]
+                ]
+            )
+            top_output_logprobs.append(
+                [
+                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                    for x in response["meta_info"]["output_top_logprobs"]
+                ]
+            )
+
+        return ModelOutput(
+            output_strs=output_strs,
+            top_input_logprobs=top_input_logprobs,
+            top_output_logprobs=top_output_logprobs,
+        )
+
+    @staticmethod
+    def batch_forward_generation_raw(
+        prompts: Union[List[str], List[torch.Tensor]],
+        max_new_tokens,
+        lora_paths,
+        engine,
+    ):
+        # the return value contains logprobs from prefill
+        output_strs = []
+        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+        response = engine.generate(
+            prompts,
+            lora_path=lora_paths if lora_paths else None,
+            sampling_params=sampling_params,
+        )
+        output_strs = [r["text"] for r in response]
+
+        return ModelOutput(
+            output_strs=output_strs,
+        )
+

 def monkey_patch_gemma2_sdpa():
     """

@@ -422,3 +471,52 @@ def monkey_patch_gemma2_sdpa():
         return config

     setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
+
+
+def check_close_model_outputs(
+    hf_outputs: ModelOutput,
+    srt_outputs: ModelOutput,
+    prefill_tolerance: float,
+    decode_tolerance: float,
+    rouge_l_tolerance: float,
+    debug_text: str = "",
+    check_logprobs: bool = True,
+):
+    # Compare output strings
+    print(f"{hf_outputs.output_strs=}")
+    print(f"{srt_outputs.output_strs=}")
+    rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
+    print(f"{rouge_l_scores=}")
+    assert all(
+        score >= rouge_l_tolerance for score in rouge_l_scores
+    ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
+
+    if check_logprobs:
+        for i in range(len(hf_outputs.output_strs)):
+            # Compare input logprobs
+            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
+            input_len = hf_logprobs.shape[0]
+            print(
+                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
+                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
+                    f"prefill logprobs are not all close with {debug_text} "
+                    f"prefill_tolerance={prefill_tolerance}."
+                    f"{hf_logprobs=}, {srt_logprobs=}"
+                )
+
+            # Compare output logprobs
+            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
+            print(
+                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
+                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
+                    f"decode logprobs are not all close with {debug_text} "
+                    f"decode_tolerance={decode_tolerance}."
+                    f"{hf_logprobs=}, {srt_logprobs=}"
+                )
python/sglang/test/test_programs.py

@@ -536,7 +536,7 @@ def test_hellaswag_select():
     # Compute accuracy
     accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
     print(f"{accuracy=}, {accuracy_gen=}")
-    assert np.abs(accuracy_gen - accuracy) < 0.05
+    assert np.abs(accuracy_gen - accuracy) < 0.1
     assert np.abs(latency_gen - latency) < 1

     return accuracy, latency
test/lang/test_srt_backend.py

@@ -74,7 +74,7 @@ class TestSRTBackend(unittest.TestCase):
         # Run twice to capture more bugs
         for _ in range(2):
             accuracy, latency = test_hellaswag_select()
-            self.assertGreater(accuracy, 0.69)
+            self.assertGreater(accuracy, 0.65)

     def test_gen_min_new_tokens(self):
         test_gen_min_new_tokens()
test/srt/models/test_generation_models.py

@@ -27,8 +27,13 @@ from typing import List
 import torch

-from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
-from sglang.test.test_utils import calculate_rouge_l, is_in_ci
+from sglang.test.runners import (
+    DEFAULT_PROMPTS,
+    HFRunner,
+    SRTRunner,
+    check_close_model_outputs,
+)
+from sglang.test.test_utils import is_in_ci


 @dataclasses.dataclass

@@ -39,6 +44,7 @@ class ModelCase:
     decode_tolerance: float = 5e-2
     rouge_l_tolerance: float = 1
     skip_long_prompt: bool = False
+    trust_remote_code: bool = False


 # Popular models that run on the CI

@@ -53,7 +59,9 @@ ALL_OTHER_MODELS = [
     ModelCase("Qwen/Qwen2.5-14B-Instruct"),
     ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
     ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
-    ModelCase("THUDM/glm-4-9b-chat"),
+    ModelCase(
+        "THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
+    ),
     ModelCase("openai-community/gpt2"),
     ModelCase("microsoft/Phi-3-small-8k-instruct"),
     ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),

@@ -87,6 +95,7 @@ class TestGenerationModels(unittest.TestCase):
             model_path,
             torch_dtype=torch_dtype,
             model_type="generation",
+            trust_remote_code=model_case.trust_remote_code,
         ) as hf_runner:
             hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

@@ -95,48 +104,18 @@ class TestGenerationModels(unittest.TestCase):
             tp_size=model_case.tp_size,
             torch_dtype=torch_dtype,
             model_type="generation",
+            trust_remote_code=model_case.trust_remote_code,
         ) as srt_runner:
             srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

-        for i in range(len(prompts)):
-            # Compare input logprobs
-            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
-            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
-            input_len = hf_logprobs.shape[0]
-            print(
-                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
-            )
-            if input_len <= 100:
-                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
-                    f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
-                    f"prefill_tolerance={prefill_tolerance}."
-                    f"{hf_logprobs=}, {srt_logprobs=}"
-                )
-
-            # Compare output logprobs
-            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
-            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
-            print(
-                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
-            )
-            if input_len <= 100:
-                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
-                    f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
-                    f"decode_tolerance={decode_tolerance}."
-                    f"{hf_logprobs=}, {srt_logprobs=}"
-                )
-
-        # Compare output strings
-        print(f"{hf_outputs.output_strs=}")
-        print(f"{srt_outputs.output_strs=}")
-        rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
-        print(f"{rouge_l_scores=}")
-        assert all(
-            score >= rouge_l_tolerance for score in rouge_l_scores
-        ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
+        check_close_model_outputs(
+            hf_outputs=hf_outputs,
+            srt_outputs=srt_outputs,
+            prefill_tolerance=model_case.prefill_tolerance,
+            decode_tolerance=model_case.decode_tolerance,
+            rouge_l_tolerance=model_case.rouge_l_tolerance,
+            debug_text=f"model_path={model_path} prompts={prompts}",
+        )

     def test_ci_models(self):
         for model_case in CI_MODELS:
test/srt/test_update_weights_from_tensor.py

@@ -26,6 +26,34 @@ class TestUpdateWeightsFromTensor(unittest.TestCase):

         engine.shutdown()

+    def test_update_weights_from_tensor_load_format_direct(self):
+        engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+
+        write_param_names = [
+            f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
+        ]
+        read_param_names = [
+            f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
+        ]
+
+        _check_param(
+            engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
+        )
+
+        new_tensor = torch.full((3072, 2048), 1.5)
+        engine.update_weights_from_tensor(
+            [
+                (write_param_name, new_tensor.clone())
+                for write_param_name in write_param_names
+            ],
+            load_format="direct",
+        )
+
+        for read_param_name in read_param_names[:3]:
+            _check_param(engine, read_param_name, [1.5] * 5)
+
+        engine.shutdown()
+

 def _check_param(engine, param_name, expect_values):
     actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
test/srt/test_verl_engine.py (new file)

import multiprocessing
import multiprocessing as mp
import os
import random
import traceback
import unittest
from multiprocessing import Process

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import is_port_available
from sglang.test.runners import (
    HFRunner,
    SRTRunner,
    check_close_model_outputs,
    get_dtype_str,
)
from sglang.test.test_utils import is_in_ci

_MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
_TORCH_DTYPE = torch.float16

# Set to false to temporarily debug issues unrelated to weight update
_ENABLE_UPDATE_WEIGHTS = True
# _ENABLE_UPDATE_WEIGHTS = False

# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
CI_MODELS = [
    dict(model_path="meta-llama/Llama-3.1-8B-Instruct"),
    dict(model_path="google/gemma-2-2b"),
]
ALL_OTHER_MODELS = [
    dict(model_path="meta-llama/Llama-3.2-1B-Instruct"),
    dict(model_path="Qwen/Qwen2-1.5B"),
    dict(
        model_path="Qwen/Qwen2.5-14B-Instruct",
        mem_fraction_static=0.4,
        tp_size=8,
        tight_memory=True,
        decode_tolerance=1.3,
    ),
    # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error
    dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3),
    dict(model_path="allenai/OLMo-1B-0724-hf"),
    dict(
        model_path="THUDM/glm-4-9b-chat",
        mem_fraction_static=0.1,
        tp_size=8,
        tight_memory=True,
    ),
    dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
    dict(
        model_path="ibm-granite/granite-3.0-2b-instruct",
        prefill_tolerance=0.22,
        decode_tolerance=0.22,
    ),
    # Fail to run these models in test_generation_models.py, need to fix that first
    # dict(model_path="openai-community/gpt2"),
    # dict(model_path="microsoft/Phi-3-small-8k-instruct"),
]


class TestVerlEngine(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        multiprocessing.set_start_method("spawn")

    def assert_fragment_e2e_execution(
        self,
        index: int,
        model_path: str,
        mem_fraction_static: float = 0.4,
        tp_size: int = 2,
        tight_memory: bool = False,
        prefill_tolerance: float = 0.1,
        decode_tolerance: float = 0.1,
    ):
        master_port = find_available_port(23456)
        print(f"assert_fragment_e2e_execution START {index=} {model_path=}")

        processes = []
        output_reader, output_writer = mp.Pipe(duplex=False)
        for tp_rank in range(tp_size):
            p = Process(
                target=_run_subprocess,
                kwargs=dict(
                    tp_rank=tp_rank,
                    tp_size=tp_size,
                    master_port=master_port,
                    output_writer=output_writer,
                    model_path=model_path,
                    mem_fraction_static=mem_fraction_static,
                    tight_memory=tight_memory,
                    prefill_tolerance=prefill_tolerance,
                    decode_tolerance=decode_tolerance,
                ),
            )
            p.start()
            processes.append(p)

        for _ in range(tp_size):
            self.assertTrue(
                output_reader.recv(),
                f"Subprocess has error, please see logs above. ({index=} {model_path=})",
            )

        for p in processes:
            p.join()

    def test_ci_models(self):
        for index, model_info in enumerate(CI_MODELS):
            self.assert_fragment_e2e_execution(index=index, **model_info)

    def test_others(self):
        if is_in_ci():
            return

        for index, model_info in enumerate(ALL_OTHER_MODELS):
            self.assert_fragment_e2e_execution(index=index, **model_info)

    # def test_adhoc(self):
    #     self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct")


def _run_subprocess(
    tp_rank: int,
    tp_size: int,
    master_port: int,
    output_writer,
    model_path: str,
    mem_fraction_static: float,
    tight_memory: bool,
    prefill_tolerance: float,
    decode_tolerance: float,
):
    try:
        print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        torch.distributed.init_process_group(rank=tp_rank, world_size=tp_size)
        torch.cuda.set_device(tp_rank)

        mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"])
        inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
        inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
        print(
            f"subprocess[{tp_rank=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
        )

        # hf model is used for comparison
        hf_model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
        ).cuda()
        hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
        hf_outputs = HFRunner.forward_generation_raw(
            prompts=_PROMPTS,
            max_new_tokens=_MAX_NEW_TOKENS,
            base_model=hf_model,
            tokenizer=hf_tokenizer,
            lora_paths=None,
            torch_dtype=_TORCH_DTYPE,
            output_str_only=False,
        )
        print(
            f"subprocess[{tp_rank=}] call hf.forward {hf_outputs=}",
            flush=True,
        )

        if _ENABLE_UPDATE_WEIGHTS:
            if tight_memory:
                hf_model.cpu()
                torch.cuda.empty_cache()

            # test update weights
            print(f"subprocess[{tp_rank=}] get_fsdp_state_dict", flush=True)
            fsdp_state_dict = _get_fsdp_state_dict(hf_model=hf_model, tp_size=tp_size)

        engine = VerlEngine(
            model_path=model_path,
            load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
            mem_fraction_static=mem_fraction_static,
            random_seed=42,
            trust_remote_code=True,
            dtype=get_dtype_str(_TORCH_DTYPE),
            device_mesh_cpu=inference_device_mesh_cpu["tp"],
        )
        print(f"subprocess[{tp_rank=}] {engine=}", flush=True)

        if _ENABLE_UPDATE_WEIGHTS:
            print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True)
            engine.update_weights_from_tensor(
                [(k, v) for k, v in fsdp_state_dict.items()]
            )

        for enable_batch in [False, True]:
            if enable_batch:
                fn = SRTRunner.batch_forward_generation_raw
            else:
                fn = SRTRunner.forward_generation_raw

            srt_outputs = fn(
                prompts=_PROMPTS,
                max_new_tokens=_MAX_NEW_TOKENS,
                lora_paths=None,
                engine=engine,
            )
            print(
                f"subprocess[{tp_rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
                flush=True,
            )

            check_close_model_outputs(
                hf_outputs=hf_outputs,
                srt_outputs=srt_outputs,
                prefill_tolerance=prefill_tolerance,
                decode_tolerance=decode_tolerance,
                rouge_l_tolerance=1,
                check_logprobs=not enable_batch,
                debug_text=f"{enable_batch=} {tp_rank=}",
            )

        execution_ok = True

    except Exception as e:
        print(f"subprocess[{tp_rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    output_writer.send(execution_ok)
    output_writer.close()

    engine.shutdown()
    print(f"subprocess[{tp_rank=}] end", flush=True)


# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
def _get_fsdp_state_dict(hf_model, tp_size: int):
    device_mesh = init_device_mesh(
        "cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"]
    )

    mixed_precision = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
    )
    fsdp_model = FSDP(
        hf_model,
        use_orig_params=True,
        auto_wrap_policy=None,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision,
        cpu_offload=CPUOffload(offload_params=False),
        sync_module_states=False,
        device_mesh=device_mesh,
    )
    print(f"{fsdp_model=}")

    FSDP.set_state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(),
    )

    return fsdp_model.state_dict()


# TODO Ask: this is extracted from PortArgs.init_new, is it allowed to extract it, i.e. touch that old code
def find_available_port(base_port: int):
    port = base_port + random.randint(100, 1000)
    while True:
        if is_port_available(port):
            return port
        if port < 60000:
            port += 42
        else:
            port -= 43


if __name__ == "__main__":
    unittest.main()
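Finally, a rough sketch (not part of the commit) of the RL-rollout loop shape that the VerlEngine API exercised by this test is designed for; `engine`, `actor_module`, `prompts`, and `ppo_update` are placeholders for objects a verl-style trainer would provide.

# Conceptual sketch only: one rollout/update step built on the VerlEngine API.
def rollout_step(engine, actor_module, prompts, ppo_update):
    # 1. Re-occupy inference memory and push the latest actor weights.
    engine.resume_memory_occupation()
    engine.update_weights_from_tensor(list(actor_module.state_dict().items()))

    # 2. Generate rollouts with SGLang (every TP rank makes the same call).
    samples = engine.generate(
        prompt=prompts, sampling_params={"max_new_tokens": 64, "temperature": 1.0}
    )

    # 3. Free inference-side GPU memory, then run the training update.
    engine.release_memory_occupation()
    ppo_update(actor_module, samples)  # hypothetical trainer step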