Unverified commit c3018b13 authored by eqy, committed by GitHub

[transformer][pipeline parallel] fix typo in test (#1370)

* fix typo

* Update test_pipeline_parallel_fwd_bwd.py
parent 2b7d280b
@@ -45,7 +45,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
         forward_only: bool,
         fwd_bwd_func: FwdStepFunc,
         pipeline_model_parallel_world_size: Optional[int],
-        vriatual_pipeline_model_parallel_size: Optional[int],
+        virtual_pipeline_model_parallel_size: Optional[int],
     ) -> None:
         for dtype, deallocate_pipeline_outputs in itertools.product(
             [torch.float32] + _get_autocast_dtypes(), (True, False),
@@ -67,7 +67,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
             parallel_state.initialize_model_parallel(
                 tensor_model_parallel_size_=tensor_model_parallel_world_size,
                 pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
-                virtual_pipeline_model_parallel_size_=vriatual_pipeline_model_parallel_size,
+                virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
             )
             pp_utils._reconfigure_microbatch_calculator(
                 rank=parallel_state.get_tensor_model_parallel_rank(),
@@ -88,7 +88,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
             model = build_model(
                 testing_utils.model_provider_func,
                 wrap_with_ddp=True,
-                virtual_pipeline_model_parallel_size=vriatual_pipeline_model_parallel_size,
+                virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
                 hidden_size=PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
            )
             _param_groups = _get_params_for_weight_decay_optimization(model)
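For context on why the spelling matters: Python matches keyword arguments against parameter names literally, so a call site that uses the intended spelling against the misspelled signature fails with a TypeError. A minimal, self-contained sketch illustrating this (hypothetical run_case function, not code from the apex test):

from typing import Optional

def run_case(virtual_pipeline_model_parallel_size: Optional[int] = None) -> None:
    # With the old typo ("vriatual_...") in the signature, the call below,
    # which uses the intended keyword, would raise:
    #   TypeError: run_case() got an unexpected keyword argument
    #   'virtual_pipeline_model_parallel_size'
    print(f"virtual pipeline size: {virtual_pipeline_model_parallel_size}")

run_case(virtual_pipeline_model_parallel_size=2)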