sglang / Commits / 2103b806
Unverified commit 2103b806, authored May 28, 2025 by Junrong Lin, committed by GitHub May 27, 2025

[CI] update verlengine ci to 4-gpu test (#6007)

parent e806f708
Showing 4 changed files with 323 additions and 42 deletions (+323 −42)
.github/workflows/pr-test.yml        +1    −1
test/srt/run_suite.py                +2    −1
test/srt/test_verl_engine_2_gpu.py   +276  −0
test/srt/test_verl_engine_4_gpu.py   +44   −40
.github/workflows/pr-test.yml

@@ -103,7 +103,7 @@ jobs:
           bash scripts/ci_install_dependency.sh

       - name: Run test
-        timeout-minutes: 20
+        timeout-minutes: 30
         run: |
           cd test/srt
           python3 run_suite.py --suite per-commit-4-gpu
test/srt/run_suite.py

@@ -101,7 +101,7 @@ suites = {
         TestFile("test_moe_ep.py", 181),
         TestFile("test_patch_torch.py", 19),
         TestFile("test_update_weights_from_distributed.py", 103),
-        TestFile("test_verl_engine.py", 64),
+        TestFile("test_verl_engine_2_gpu.py", 64),
     ],
     "per-commit-2-gpu-amd": [
         TestFile("test_mla_tp.py", 170),
@@ -109,6 +109,7 @@ suites = {
     "per-commit-4-gpu": [
         TestFile("test_local_attn.py", 250),
         TestFile("test_pp_single_node.py", 150),
+        TestFile("test_verl_engine_4_gpu.py", 64),
     ],
     "per-commit-8-gpu": [
         # Disabled deepep tests temporarily because it takes too much time.
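For orientation (not part of this commit): run_suite.py keys each CI suite to a list of TestFile entries and runs every file in the suite passed via --suite, so the new 2-GPU file replaces test_verl_engine.py in its old suite while the 4-GPU file joins per-commit-4-gpu. The sketch below is a hypothetical reconstruction of that registry shape; the field names and the reading of the second argument as an estimated runtime in seconds are assumptions, not the repo's exact definition.

# Hypothetical sketch of the suite registry shape used by run_suite.py.
import dataclasses


@dataclasses.dataclass
class TestFile:
    name: str  # test module under test/srt
    estimated_time: float = 60  # assumed: rough runtime in seconds, for reporting


suites = {
    "per-commit-4-gpu": [
        TestFile("test_local_attn.py", 250),
        TestFile("test_pp_single_node.py", 150),
        TestFile("test_verl_engine_4_gpu.py", 64),
    ],
}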
test/srt/test_verl_engine_2_gpu.py (new file, mode 0 → 100644)

import multiprocessing
import multiprocessing as mp
import os
import random
import traceback
import unittest
from multiprocessing import Process

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import is_port_available
from sglang.test.runners import (
    HFRunner,
    SRTRunner,
    check_close_model_outputs,
    get_dtype_str,
)
from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci

_MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
_TORCH_DTYPE = torch.float16

# Set to false to temporarily debug issues unrelated to weight update
_ENABLE_UPDATE_WEIGHTS = True
# _ENABLE_UPDATE_WEIGHTS = False

# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
ALL_MODELS = [
    dict(model_path="meta-llama/Llama-3.2-1B-Instruct"),
    dict(model_path="Qwen/Qwen2-1.5B"),
    dict(model_path="allenai/OLMo-1B-0724-hf"),
    dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
    dict(
        model_path="ibm-granite/granite-3.0-2b-instruct",
        prefill_tolerance=0.22,
        decode_tolerance=0.22,
    ),
]


class TestVerlEngine(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        multiprocessing.set_start_method("spawn")

    def assert_fragment_e2e_execution(
        self,
        index: int,
        model_path: str,
        mem_fraction_static: float = 0.4,
        dp_size: int = 1,
        tp_size: int = 2,
        tight_memory: bool = False,
        prefill_tolerance: float = 0.1,
        decode_tolerance: float = 0.1,
    ):
        master_port = find_available_port(23456)
        print(f"assert_fragment_e2e_execution START {index=} {model_path=}")

        processes = []
        output_reader, output_writer = mp.Pipe(duplex=False)
        world_size = dp_size * tp_size
        for rank in range(world_size):
            p = Process(
                target=_run_subprocess,
                kwargs=dict(
                    rank=rank,
                    dp_size=dp_size,
                    tp_size=tp_size,
                    master_port=master_port,
                    output_writer=output_writer,
                    model_path=model_path,
                    mem_fraction_static=mem_fraction_static,
                    tight_memory=tight_memory,
                    prefill_tolerance=prefill_tolerance,
                    decode_tolerance=decode_tolerance,
                ),
            )
            p.start()
            processes.append(p)

        for _ in range(tp_size):
            self.assertTrue(
                output_reader.recv(),
                f"Subprocess has error, please see logs above. ({index=} {model_path=})",
            )

        for p in processes:
            p.join()

    def test_ci_models(self):
        ci_models = [random.choice(ALL_MODELS)]
        for index, model_info in enumerate(ci_models):
            self.assert_fragment_e2e_execution(index=index, **model_info)

    def test_others(self):
        if is_in_ci():
            return

        for index, model_info in enumerate(ALL_MODELS):
            self.assert_fragment_e2e_execution(index=index, **model_info)

    # def test_adhoc(self):
    #     self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct")


def _run_subprocess(
    rank: int,
    dp_size: int,
    tp_size: int,
    master_port: int,
    output_writer,
    model_path: str,
    mem_fraction_static: float,
    tight_memory: bool,
    prefill_tolerance: float,
    decode_tolerance: float,
):
    try:
        print(f"subprocess[{rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        torch.distributed.init_process_group(rank=rank, world_size=dp_size * tp_size)
        torch.cuda.set_device(rank)
        base_gpu_id = rank // tp_size * tp_size

        mesh_kwargs = dict(
            mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
        )
        inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
        inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
        print(
            f"subprocess[{rank=},{base_gpu_id=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
        )

        # hf model is used for comparison
        hf_model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
        ).cuda()
        hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)

        hf_outputs = HFRunner.forward_generation_raw(
            base_model=hf_model,
            prompts=_PROMPTS,
            max_new_tokens=_MAX_NEW_TOKENS,
            tokenizer=hf_tokenizer,
            lora_paths=None,
            torch_dtype=_TORCH_DTYPE,
            output_str_only=False,
        )
        print(
            f"subprocess[{rank=}] call hf.forward {hf_outputs=}",
            flush=True,
        )

        if _ENABLE_UPDATE_WEIGHTS:
            if tight_memory:
                hf_model.cpu()
                torch.cuda.empty_cache()

            # test update weights
            print(f"subprocess[{rank=}] get_fsdp_state_dict", flush=True)
            fsdp_state_dict = _get_fsdp_state_dict(
                hf_model=hf_model, world_size=dp_size * tp_size
            )

        engine = VerlEngine(
            model_path=model_path,
            load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
            mem_fraction_static=mem_fraction_static,
            random_seed=42,
            base_gpu_id=base_gpu_id,
            trust_remote_code=True,
            dtype=get_dtype_str(_TORCH_DTYPE),
            device_mesh_cpu=inference_device_mesh_cpu["tp"],
        )
        print(f"subprocess[{rank=}] {engine=}", flush=True)

        if _ENABLE_UPDATE_WEIGHTS:
            print(f"subprocess[{rank=}] call update_weights_from_tensor", flush=True)
            engine.update_weights_from_tensor(
                [(k, v) for k, v in fsdp_state_dict.items()]
            )

        for enable_batch in [False, True]:
            if enable_batch:
                fn = SRTRunner.batch_forward_generation_raw
            else:
                fn = SRTRunner.forward_generation_raw

            srt_outputs = fn(
                prompts=_PROMPTS,
                max_new_tokens=_MAX_NEW_TOKENS,
                lora_paths=None,
                engine=engine,
            )
            print(
                f"subprocess[{rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
                flush=True,
            )

            check_close_model_outputs(
                hf_outputs=hf_outputs,
                srt_outputs=srt_outputs,
                prefill_tolerance=prefill_tolerance,
                decode_tolerance=decode_tolerance,
                rouge_l_tolerance=1,
                check_logprobs=not enable_batch,
                debug_text=f"{enable_batch=} {rank=}",
            )

        execution_ok = True

    except Exception as e:
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    output_writer.send(execution_ok)
    output_writer.close()

    engine.shutdown()
    print(f"subprocess[{rank=}] end", flush=True)


# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
def _get_fsdp_state_dict(hf_model, world_size: int):
    device_mesh = init_device_mesh(
        "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
    )

    mixed_precision = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
    )
    fsdp_model = FSDP(
        hf_model,
        use_orig_params=True,
        auto_wrap_policy=None,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision,
        cpu_offload=CPUOffload(offload_params=False),
        sync_module_states=False,
        device_mesh=device_mesh,
    )

    print(f"{fsdp_model=}")

    FSDP.set_state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(),
    )

    return fsdp_model.state_dict()


if __name__ == "__main__":
    unittest.main()
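Aside (not part of the commit): the parent/worker handshake in assert_fragment_e2e_execution is a generic pattern — every rank pushes a pass/fail boolean through the write end of a Pipe, and the parent asserts on each message before joining the processes. A minimal self-contained sketch of that pattern, independent of sglang:

# Minimal sketch of the pass/fail rendezvous used by the test's parent process.
import multiprocessing as mp


def _worker(rank: int, output_writer) -> None:
    ok = True  # a real worker would run the engine here and catch exceptions
    output_writer.send(ok)
    output_writer.close()


if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    procs = [mp.Process(target=_worker, args=(rank, writer)) for rank in range(4)]
    for p in procs:
        p.start()
    # One message per worker; any False surfaces as an assertion failure.
    assert all(reader.recv() for _ in procs), "a worker reported failure"
    for p in procs:
        p.join()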
test/srt/test_verl_engine.py → test/srt/test_verl_engine_4_gpu.py

@@ -38,35 +38,27 @@ _ENABLE_UPDATE_WEIGHTS = True
 # _ENABLE_UPDATE_WEIGHTS = False
 
 # TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
-CI_MODELS = [
-    dict(model_path="meta-llama/Llama-3.1-8B-Instruct"),
-    # Fail to run gemma-2-2b after transformers==4.48.3 -> 4.50.0
-    # dict(model_path="google/gemma-2-2b"),
-]
-ALL_OTHER_MODELS = [
-    dict(model_path="meta-llama/Llama-3.2-1B-Instruct"),
-    dict(model_path="Qwen/Qwen2-1.5B"),
+ALL_MODELS = [
+    dict(
+        model_path="Qwen/Qwen2.5-0.5B",
+        dp_size=2,
+        tp_size=2,  # default to 2
+    ),
     dict(
         model_path="Qwen/Qwen2.5-14B-Instruct",
-        mem_fraction_static=0.4,
-        tp_size=8,
+        mem_fraction_static=0.7,
+        dp_size=2,
+        tp_size=2,
         tight_memory=True,
         decode_tolerance=1.3,
     ),  # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error
-    dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3),
-    dict(model_path="allenai/OLMo-1B-0724-hf"),
     dict(
         model_path="THUDM/glm-4-9b-chat",
-        mem_fraction_static=0.1,
-        tp_size=8,
+        mem_fraction_static=0.5,
+        dp_size=2,
+        tp_size=2,
         tight_memory=True,
     ),
-    dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
-    dict(
-        model_path="ibm-granite/granite-3.0-2b-instruct",
-        prefill_tolerance=0.22,
-        decode_tolerance=0.22,
-    ),
     # Fail to run these models in test_generation_models.py, need to fix that first
     # dict(model_path="openai-community/gpt2"),
     # dict(model_path="microsoft/Phi-3-small-8k-instruct"),
@@ -83,6 +75,7 @@ class TestVerlEngine(CustomTestCase):
         index: int,
         model_path: str,
         mem_fraction_static: float = 0.4,
+        dp_size: int = 1,
         tp_size: int = 2,
         tight_memory: bool = False,
         prefill_tolerance: float = 0.1,
@@ -94,11 +87,13 @@ class TestVerlEngine(CustomTestCase):
         processes = []
         output_reader, output_writer = mp.Pipe(duplex=False)
 
-        for tp_rank in range(tp_size):
+        world_size = dp_size * tp_size
+        for rank in range(world_size):
             p = Process(
                 target=_run_subprocess,
                 kwargs=dict(
-                    tp_rank=tp_rank,
+                    rank=rank,
+                    dp_size=dp_size,
                     tp_size=tp_size,
                     master_port=master_port,
                     output_writer=output_writer,
@@ -122,7 +117,8 @@ class TestVerlEngine(CustomTestCase):
             p.join()
 
     def test_ci_models(self):
-        for index, model_info in enumerate(CI_MODELS):
+        ci_models = [random.choice(ALL_MODELS)]
+        for index, model_info in enumerate(ci_models):
             self.assert_fragment_e2e_execution(index=index, **model_info)
 
     def test_others(self):
@@ -137,7 +133,8 @@ class TestVerlEngine(CustomTestCase):
 
 
 def _run_subprocess(
-    tp_rank: int,
+    rank: int,
+    dp_size: int,
     tp_size: int,
     master_port: int,
     output_writer,
@@ -148,18 +145,22 @@ def _run_subprocess(
     decode_tolerance: float,
 ):
     try:
-        print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
+        print(f"subprocess[{rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
 
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = str(master_port)
-        torch.distributed.init_process_group(rank=tp_rank, world_size=tp_size)
-        torch.cuda.set_device(tp_rank)
+        torch.distributed.init_process_group(rank=rank, world_size=dp_size * tp_size)
+        torch.cuda.set_device(rank)
+        base_gpu_id = rank // tp_size * tp_size
 
-        mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"])
+        mesh_kwargs = dict(
+            mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
+        )
         inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
         inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
         print(
-            f"subprocess[{tp_rank=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
+            f"subprocess[{rank=},{base_gpu_id=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
         )
 
         # hf model is used for comparison
@@ -178,7 +179,7 @@ def _run_subprocess(
             output_str_only=False,
         )
         print(
-            f"subprocess[{tp_rank=}] call hf.forward {hf_outputs=}",
+            f"subprocess[{rank=}] call hf.forward {hf_outputs=}",
             flush=True,
         )
@@ -188,22 +189,25 @@ def _run_subprocess(
                 torch.cuda.empty_cache()
 
             # test update weights
-            print(f"subprocess[{tp_rank=}] get_fsdp_state_dict", flush=True)
-            fsdp_state_dict = _get_fsdp_state_dict(hf_model=hf_model, tp_size=tp_size)
+            print(f"subprocess[{rank=}] get_fsdp_state_dict", flush=True)
+            fsdp_state_dict = _get_fsdp_state_dict(
+                hf_model=hf_model, world_size=dp_size * tp_size
+            )
 
         engine = VerlEngine(
             model_path=model_path,
             load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
             mem_fraction_static=mem_fraction_static,
             random_seed=42,
+            base_gpu_id=base_gpu_id,
             trust_remote_code=True,
             dtype=get_dtype_str(_TORCH_DTYPE),
             device_mesh_cpu=inference_device_mesh_cpu["tp"],
         )
-        print(f"subprocess[{tp_rank=}] {engine=}", flush=True)
+        print(f"subprocess[{rank=}] {engine=}", flush=True)
 
         if _ENABLE_UPDATE_WEIGHTS:
-            print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True)
+            print(f"subprocess[{rank=}] call update_weights_from_tensor", flush=True)
             engine.update_weights_from_tensor(
                 [(k, v) for k, v in fsdp_state_dict.items()]
             )
@@ -221,7 +225,7 @@ def _run_subprocess(
                 engine=engine,
             )
             print(
-                f"subprocess[{tp_rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
+                f"subprocess[{rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
                 flush=True,
             )
@@ -232,13 +236,13 @@ def _run_subprocess(
                 decode_tolerance=decode_tolerance,
                 rouge_l_tolerance=1,
                 check_logprobs=not enable_batch,
-                debug_text=f"{enable_batch=} {tp_rank=}",
+                debug_text=f"{enable_batch=} {rank=}",
             )
 
         execution_ok = True
 
     except Exception as e:
-        print(f"subprocess[{tp_rank=}] has error: {e}", flush=True)
+        print(f"subprocess[{rank=}] has error: {e}", flush=True)
         traceback.print_exc()
         execution_ok = False
@@ -246,13 +250,13 @@ def _run_subprocess(
     output_writer.close()
 
     engine.shutdown()
-    print(f"subprocess[{tp_rank=}] end", flush=True)
+    print(f"subprocess[{rank=}] end", flush=True)
 
 
 # Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
-def _get_fsdp_state_dict(hf_model, tp_size: int):
-    device_mesh = init_device_mesh(
-        "cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"]
-    )
+def _get_fsdp_state_dict(hf_model, world_size: int):
+    device_mesh = init_device_mesh(
+        "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
+    )
 
     mixed_precision = MixedPrecision(
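For illustration (not part of the commit): with the 4-GPU defaults above (dp_size=2, tp_size=2), the expression base_gpu_id = rank // tp_size * tp_size groups ranks 0-1 onto GPUs starting at 0 and ranks 2-3 onto GPUs starting at 2, so each data-parallel replica drives its own 2-way tensor-parallel engine:

# Worked example of the rank -> base_gpu_id mapping used by both test files.
dp_size, tp_size = 2, 2
for rank in range(dp_size * tp_size):
    base_gpu_id = rank // tp_size * tp_size
    print(f"{rank=} {base_gpu_id=}")
# rank=0 base_gpu_id=0
# rank=1 base_gpu_id=0
# rank=2 base_gpu_id=2
# rank=3 base_gpu_id=2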