Unverified Commit 433c4a49 authored by ZSL98's avatar ZSL98 Committed by GitHub
Browse files

Make vllm compatible with verl (#12824)


Co-authored-by: default avatarzhangshulai <zhangshulai@bytedance.com>
parent ef533d25
...@@ -1024,13 +1024,6 @@ def initialize_model_parallel( ...@@ -1024,13 +1024,6 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend( backend = backend or torch.distributed.get_backend(
get_world_group().device_group) get_world_group().device_group)
if (world_size
!= tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError(
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
# Build the tensor model-parallel groups. # Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size // num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size) tensor_model_parallel_size)
......
...@@ -101,7 +101,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor): ...@@ -101,7 +101,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
# - MASTER_PORT # - MASTER_PORT
distributed_init_method = "env://" distributed_init_method = "env://"
rank = int(os.environ["RANK"]) rank = int(os.environ["RANK"])
local_rank = rank local_rank = int(os.environ["LOCAL_RANK"])
is_driver_worker = True is_driver_worker = True
kwargs = dict( kwargs = dict(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment