OpenDAS / EasyR1 · Commits

Commit 2369eb2b
authored Apr 21, 2025 by chenych

update

parent ac9d2b05
Showing 3 changed files with 53 additions and 46 deletions.
verl/workers/rollout/config.py  +4 −2
verl/workers/rollout/vllm_rollout_spmd.py  +30 −25
verl/workers/sharding_manager/fsdp_vllm.py  +19 −19
verl/workers/rollout/config.py
@@ -16,7 +16,7 @@ Rollout config
 """
 
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 
 @dataclass
@@ -26,6 +26,7 @@ class RolloutConfig:
     temperature: float = 1.0
     top_p: float = 1.0
     top_k: int = -1
+    seed: int = 1
     limit_images: int = 0
     dtype: str = "bf16"
     gpu_memory_utilization: float = 0.6
@@ -33,13 +34,14 @@ class RolloutConfig:
     enforce_eager: bool = False
     enable_chunked_prefill: bool = False  # only for v0 engine
     tensor_parallel_size: int = 2
+    max_model_len: Optional[int] = None
     max_num_batched_tokens: int = 8192
-    max_num_seqs: int = 1024
     disable_log_stats: bool = True
     val_override_config: Dict[str, Any] = field(default_factory=dict)
     """auto keys"""
     prompt_length: int = field(default=-1, init=False)
     response_length: int = field(default=-1, init=False)
+    trust_remote_code: bool = field(default=False, init=False)
 
     def to_dict(self):
         return asdict(self)
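Note: the fields touched by this diff are easier to see in use. Below is a minimal sketch, assuming the repository root is on PYTHONPATH and the field set shown above; the concrete values are illustrative only, not defaults taken from the repo.

from verl.workers.rollout.config import RolloutConfig

config = RolloutConfig(seed=1, tensor_parallel_size=2, max_model_len=None)

# prompt_length, response_length and trust_remote_code are "auto keys"
# (init=False), so the worker fills them in after construction.
config.prompt_length = 2048
config.response_length = 2048
config.trust_remote_code = True

# When max_model_len stays None, the rollout falls back to
# prompt_length + response_length (see the `or` expression in the next file).
max_model_len = config.max_model_len or config.prompt_length + config.response_length
print(max_model_len)             # 4096
print(config.to_dict()["seed"])  # 1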
verl/workers/rollout/vllm_rollout_spmd.py
@@ -11,16 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-The vllm_rollout that can be applied in different backend
-When working with FSDP:
-    - Use DTensor weight loader (recommended) or HF weight loader
-    - Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
-"""
 
 import os
 from contextlib import contextmanager
-from typing import Any, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -31,6 +25,7 @@ from vllm import LLM, RequestOutput, SamplingParams
 
 from ...protocol import DataProto
 from ...utils import torch_functional as VF
+from ...utils.tokenizer import get_processor
 from ...utils.torch_dtypes import PrecisionType
 from .base import BaseRollout
 from .config import RolloutConfig
@@ -43,6 +38,15 @@ def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) ->
         return np.repeat(value, repeats, axis=0)
 
 
+def _get_logit_bias(model_path: str, trust_remote_code: bool) -> Optional[Dict[int, float]]:
+    processor = get_processor(model_path, trust_remote_code=trust_remote_code)
+    if processor is not None and hasattr(processor, "image_token"):
+        image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
+        return {image_token_id: -100}
+    else:
+        return None
+
+
 class vLLMRollout(BaseRollout):
     def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
         """A vLLM rollout. It requires the module is supported by the vllm.
@@ -62,33 +66,38 @@ class vLLMRollout(BaseRollout):
         if config.max_num_batched_tokens < config.prompt_length + config.response_length:
             raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")
 
+        vllm_init_kwargs = {}
+        if config.limit_images > 0:
+            vllm_init_kwargs = {"limit_mm_per_prompt": {"image": config.limit_images}}
+
         self.inference_engine = LLM(
             model=model_path,
             skip_tokenizer_init=False,
-            tensor_parallel_size=config.tensor_parallel_size,
+            trust_remote_code=config.trust_remote_code,
+            load_format="dummy",
             dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
+            seed=config.seed,
+            max_model_len=config.max_model_len or config.prompt_length + config.response_length,
+            distributed_executor_backend="external_launcher",
+            tensor_parallel_size=config.tensor_parallel_size,
             gpu_memory_utilization=config.gpu_memory_utilization,
-            enforce_eager=config.enforce_eager,
-            max_model_len=config.prompt_length + config.response_length,
             max_num_batched_tokens=config.max_num_batched_tokens,
-            enable_sleep_mode=False,
-            disable_log_stats=config.disable_log_stats,
-            distributed_executor_backend="external_launcher",
+            enforce_eager=config.enforce_eager,
             disable_custom_all_reduce=True,
-            limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None,
             disable_mm_preprocessor_cache=True,
+            disable_log_stats=config.disable_log_stats,
             enable_chunked_prefill=config.enable_chunked_prefill,
-            seed=self.rank // config.tensor_parallel_size,  # dp rank
+            enable_sleep_mode=False,  # nv True rocm False
+            **vllm_init_kwargs,
+            # swap_space=20,
         )
 
         # Offload vllm model to reduce peak memory usage
         # self.inference_engine.sleep(level=1)
-        ## TODO DCU: how to release the GPU memory
-        sampling_kwargs = {"max_tokens": config.response_length, "detokenize": False}
+        # self.inference_engine.offload_model_weights()
+        sampling_kwargs = {
+            "max_tokens": config.response_length,
+            "detokenize": False,
+            "logit_bias": _get_logit_bias(model_path, trust_remote_code=config.trust_remote_code),
+        }
 
         default_sampling_params = SamplingParams()
         for key in config.to_dict().keys():
             if hasattr(default_sampling_params, key):
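Note: the engine construction above collects the optional multi-image setting in vllm_init_kwargs and splats it into the call, so limit_mm_per_prompt is only passed when limit_images > 0. Below is a minimal sketch of just that keyword-splat pattern; fake_engine is a stand-in for illustration, not the vLLM API.

def fake_engine(**kwargs):
    return kwargs


def build_init_kwargs(limit_images: int) -> dict:
    # Only populate the dict when multi-image support is requested.
    init_kwargs = {}
    if limit_images > 0:
        init_kwargs = {"limit_mm_per_prompt": {"image": limit_images}}
    return init_kwargs


print(fake_engine(model="dummy", **build_init_kwargs(0)))
# {'model': 'dummy'}
print(fake_engine(model="dummy", **build_init_kwargs(4)))
# {'model': 'dummy', 'limit_mm_per_prompt': {'image': 4}}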
@@ -152,10 +161,6 @@ class vLLMRollout(BaseRollout):
             input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
             attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
             position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
-            if "multi_modal_inputs" in non_tensor_batch.keys():
-                non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(
-                    non_tensor_batch["multi_modal_inputs"], self.sampling_params.n
-                )
 
         sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
         response_length = response_ids.size(1)
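Note: the _repeat_interleave helper that appears as context in the hunk above duplicates each prompt row n times when sampling_params.n > 1, so generated responses line up row-for-row with their prompts. A minimal, self-contained sketch of that behavior follows; the helper here is a standalone re-implementation, not an import from the repo.

import numpy as np
import torch


def repeat_interleave(value, repeats: int):
    # Tensors are repeated along dim 0; numpy object arrays use np.repeat.
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    return np.repeat(value, repeats, axis=0)


input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(repeat_interleave(input_ids, 2).shape)       # torch.Size([4, 3])
print(repeat_interleave(np.array(["a", "b"]), 2))  # ['a' 'a' 'b' 'b']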
verl/workers/sharding_manager/fsdp_vllm.py
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import warnings
+import inspect
 from typing import Dict, Iterable, Tuple, Union
 
 import torch
 import torch.distributed as dist
 from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.state_dict import get_model_state_dict
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
 from vllm import LLM
 from vllm.distributed import parallel_state as vllm_ps
@@ -34,18 +34,11 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         self,
         module: FSDP,
         inference_engine: LLM,
-        device_mesh: DeviceMesh = None,
+        device_mesh: DeviceMesh,
     ):
         self.module = module
         self.inference_engine = inference_engine
         self.device_mesh = device_mesh
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            FSDP.set_state_dict_type(
-                self.module,
-                state_dict_type=StateDictType.SHARDED_STATE_DICT,
-                state_dict_config=ShardedStateDictConfig(),
-            )
 
         self.world_size = dist.get_world_size()
         self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
@@ -59,13 +52,10 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         # Note that torch_random_states may be different on each dp rank
         self.torch_random_states = torch.cuda.get_rng_state()
         # get a random rng states
-        if self.device_mesh is not None:
         gen_dp_rank = self.device_mesh["dp"].get_local_rank()
         torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
         self.gen_random_states = torch.cuda.get_rng_state()
         torch.cuda.set_rng_state(self.torch_random_states)
-        else:
-            self.gen_random_states = None
 
     def _make_weight_iterator(
         self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
@@ -83,16 +73,24 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
         torch.cuda.empty_cache()
         print_gpu_memory_usage("Before state_dict() in sharding manager")
-        actor_weights = self.module.state_dict()
+        actor_weights = get_model_state_dict(self.module)
         print_gpu_memory_usage("After state_dict() in sharding manager")
 
-        self.inference_engine.wake_up()
+        if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
+            self.inference_engine.wake_up(tags=["weights"])
+        else:
+            self.inference_engine.wake_up()
+
         model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
         model.load_weights(self._make_weight_iterator(actor_weights))
         print_gpu_memory_usage("After sync model weights in sharding manager")
 
         del actor_weights
         torch.cuda.empty_cache()
+        if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
+            self.inference_engine.wake_up(tags=["kv_cache"])
+
         print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
 
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
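Note: the wake_up change above is a version-compatibility shim: newer vLLM releases accept wake_up(tags=[...]) so the KV cache can be re-allocated separately from the weights, while older releases only expose wake_up(). A minimal sketch of the inspect.signature feature check follows, using stub functions rather than the vLLM API itself.

import inspect


def wake_up_old():
    return "woke everything"


def wake_up_new(tags=None):
    return f"woke {tags or 'everything'}"


for wake_up in (wake_up_old, wake_up_new):
    # Same check as the diff: only pass tags= when the callable accepts it.
    if "tags" in inspect.signature(wake_up).parameters:
        print(wake_up(tags=["weights"]))
    else:
        print(wake_up())
# woke everything
# woke ['weights']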
@@ -103,6 +101,8 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         print_gpu_memory_usage("Before vllm offload in sharding manager")
         free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
         # self.inference_engine.sleep(level=1)
+        ## rocm
+        # self.inference_engine.offload_model_weights()
         free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
         self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
         print_gpu_memory_usage("After vllm offload in sharding manager")
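Note: the freed_bytes bookkeeping kept in this hunk relies on torch.cuda.mem_get_info(), which returns (free_bytes, total_bytes); differencing the free bytes around an offload step approximates how much memory that step released. Below is a minimal sketch with a throwaway allocation standing in for the vLLM offload, guarded so it also runs on machines without a CUDA device.

import torch

if torch.cuda.is_available():
    # Hold a 256 MiB buffer, then release it and measure the change in free memory.
    dummy = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    free_before = torch.cuda.mem_get_info()[0]
    del dummy
    torch.cuda.empty_cache()
    free_after = torch.cuda.mem_get_info()[0]
    print(f"freed bytes (approx): {free_after - free_before}")  # roughly 268435456
else:
    print("CUDA not available; nothing to measure.")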