"vscode:/vscode.git/clone" did not exist on "f00cd6efbd00b0273f58c393a617415b5d1d410e"
Unverified Commit 59978d5e authored by Masaki Kozuki, committed by GitHub

[transformer] Update `build_model` function to support encoder&decoder model (#1307)

* update build_model to support enc&dec model

* fix typo: cur_sargs -> cur_args

* enc&dec path: correctly update pre/post process
parent 47c269b6
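For context, a rough usage sketch (not part of this commit) of how a caller might exercise the new encoder-and-decoder path. The import locations, the toy provider, and the exact initialization keywords below are assumptions based on apex/Megatron-LM conventions; a real run also needs an initialized torch.distributed process group.

# Usage sketch only -- names and import paths below are assumptions, not part of this commit.
import torch

from apex.transformer import parallel_state
from apex.transformer.enums import ModelType                 # assumed location of ModelType
from apex.transformer.pipeline_parallel import build_model   # assumed export path

class ToyEncDec(torch.nn.Module):
    # Minimal stand-in that records which pieces this pipeline stage owns.
    def __init__(self, pre_process, post_process, add_encoder, add_decoder):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder
        self.linear = torch.nn.Linear(8, 8)

def provider(pre_process=True, post_process=True, add_encoder=True, add_decoder=True):
    # Must accept the keyword arguments that build_model injects for
    # ModelType.encoder_and_decoder and return a torch.nn.Module.
    return ToyEncDec(pre_process, post_process, add_encoder, add_decoder)

# With 4 pipeline stages split at rank 2, ranks 0-1 build encoder stages and
# ranks 2-3 build decoder stages. The keyword argument names here are assumptions.
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=1,
    pipeline_model_parallel_size_=4,
    pipeline_model_parallel_split_rank_=2,
)
models = build_model(provider, wrap_with_ddp=False, model_type=ModelType.encoder_and_decoder)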
@@ -26,6 +26,7 @@ def build_model(
     model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
     wrap_with_ddp: bool = True,
     virtual_pipeline_model_parallel_size: Optional[int] = None,
+    model_type: ModelType = ModelType.encoder_or_decoder,
     *args: Any,
     **kwargs: Any,
 ) -> List[torch.nn.Module]:
@@ -39,6 +40,7 @@ def build_model(
         wrap_with_ddp: If :obj:`True`, wrap the instantiated model
             with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
         virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
+        model_type: Specify whether the model has an encoder or decoder only (``ModelType.encoder_or_decoder``) or both (``ModelType.encoder_and_decoder``).
         *args: arguments for model provider func
         **kwargs: Keyword arguments for model provider func
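The new model_type argument (and the ModelType references in the hunk below) refer to an enum along these lines; this is a sketch mirroring the Megatron-LM convention, and the exact module path and member values in apex may differ.

# Sketch of the assumed ModelType enum (mirrors the Megatron-LM convention).
import enum

class ModelType(enum.Enum):
    encoder_or_decoder = 1   # GPT/BERT-style: every pipeline stage is encoder-only or decoder-only
    encoder_and_decoder = 2  # T5-style: encoder stages and decoder stages share one pipeline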
@@ -67,13 +69,39 @@ def build_model(
     else:
         cur_args = args
         cur_kwargs = kwargs
-        pre_process = parallel_state.is_pipeline_first_stage()
-        post_process = parallel_state.is_pipeline_last_stage()
-        cur_kwargs.update({
-            "pre_process": pre_process,
-            "post_process": post_process,
-        })
-        model = model_provider_func(*cur_args, **cur_kwargs)
+        if model_type == ModelType.encoder_or_decoder:
+            pre_process = parallel_state.is_pipeline_first_stage()
+            post_process = parallel_state.is_pipeline_last_stage()
+            cur_kwargs.update({
+                "pre_process": pre_process,
+                "post_process": post_process,
+            })
+            model = model_provider_func(*cur_args, **cur_kwargs)
+        elif model_type == ModelType.encoder_and_decoder:
+            pre_process = parallel_state.is_pipeline_first_stage()
+            post_process = parallel_state.is_pipeline_last_stage()
+            # `add_encoder` & `add_decoder` logic.
+            add_encoder, add_decoder = True, True
+            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
+                if split_rank is None:
+                    raise RuntimeError(
+                        "Split rank needs to be specified for model with both encoder and decoder."
+                    )
+                rank = parallel_state.get_pipeline_model_parallel_rank()
+                world_size = parallel_state.get_pipeline_model_parallel_world_size()
+                pre_process = rank == 0 or rank == split_rank
+                post_process = rank == (split_rank - 1) or rank == (world_size - 1)
+                add_encoder = parallel_state.is_pipeline_stage_before_split()
+                add_decoder = parallel_state.is_pipeline_stage_after_split()
+            cur_kwargs.update({
+                "pre_process": pre_process,
+                "post_process": post_process,
+                "add_encoder": add_encoder,
+                "add_decoder": add_decoder,
+            })
+            model = model_provider_func(*cur_args, **cur_kwargs)
+        model.model_type = model_type
 
     if not isinstance(model, list):
         model = [model]
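To make the encoder_and_decoder branch above concrete, here is a standalone illustration (no apex dependencies) of how the four flags fall out per pipeline rank for a 4-stage pipeline split at rank 2, assuming is_pipeline_stage_before_split()/is_pipeline_stage_after_split() reduce to rank < split_rank and rank >= split_rank.

# Standalone illustration of the flag logic in the encoder_and_decoder branch above.
def stage_flags(rank, world_size=4, split_rank=2):
    pre_process = rank == 0 or rank == split_rank                        # first encoder or first decoder stage
    post_process = rank == (split_rank - 1) or rank == (world_size - 1)  # last encoder or last decoder stage
    add_encoder = rank < split_rank    # assumed meaning of is_pipeline_stage_before_split()
    add_decoder = rank >= split_rank   # assumed meaning of is_pipeline_stage_after_split()
    return pre_process, post_process, add_encoder, add_decoder

for r in range(4):
    print(r, stage_flags(r))
# 0 (True, False, True, False)   -> first encoder stage
# 1 (False, True, True, False)   -> last encoder stage
# 2 (True, False, False, True)   -> first decoder stage
# 3 (False, True, False, True)   -> last decoder stage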