Commit 54e084f7 authored by hoshi-hiyouga's avatar hoshi-hiyouga Committed by simon-mo
Browse files

[Bugfix] torchrun compatibility (#14899)


Signed-off-by: default avatarhiyouga <hiyouga@buaa.edu.cn>
Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avataryoukaichao <youkaichao@gmail.com>
parent 9e8f089d
......@@ -904,7 +904,9 @@ class ModelConfig:
else:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
# the layout order is: DP x PP x TP
pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
) % parallel_config.pipeline_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return start, end
......
......@@ -897,9 +897,22 @@ def initialize_model_parallel(
get_world_group().device_group)
data_parallel_size = 1
has_external_dp = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
if config.parallel_config.world_size != world_size:
# detect external data parallelism.
# dp in vllm means all dp instances need to run together.
# if the world size does not match, it means this dp is external,
# and the dp instances can run independently, e.g. in rlhf workflow
# from https://github.com/volcengine/verl .
# in that case, we treat the rest dimensions as if they are
# data parallel, and create a dummy dp group that is not used.
data_parallel_size = world_size // (pipeline_model_parallel_size *
tensor_model_parallel_size)
has_external_dp = True
else:
data_parallel_size = config.parallel_config.data_parallel_size
# the layout order is: DP x PP x TP
......@@ -940,6 +953,12 @@ def initialize_model_parallel(
2).reshape(-1,
data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if has_external_dp:
# create a dummy dp group that is not used actually,
# since this dp is external.
# a dummy dp group means every rank is a group itself.
# this way, no communication is needed, no memory is wasted.
group_ranks = [[x] for x in range(world_size)]
_DP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment