Unverified Commit c3018b13 authored by eqy's avatar eqy Committed by GitHub
Browse files

[transformer][pipeline parallel] fix typo in test (#1370)

* fix typo

* Update test_pipeline_parallel_fwd_bwd.py
parent 2b7d280b
...@@ -45,7 +45,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase): ...@@ -45,7 +45,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
forward_only: bool, forward_only: bool,
fwd_bwd_func: FwdStepFunc, fwd_bwd_func: FwdStepFunc,
pipeline_model_parallel_world_size: Optional[int], pipeline_model_parallel_world_size: Optional[int],
vriatual_pipeline_model_parallel_size: Optional[int], virtual_pipeline_model_parallel_size: Optional[int],
) -> None: ) -> None:
for dtype, deallocate_pipeline_outputs in itertools.product( for dtype, deallocate_pipeline_outputs in itertools.product(
[torch.float32] + _get_autocast_dtypes(), (True, False), [torch.float32] + _get_autocast_dtypes(), (True, False),
...@@ -67,7 +67,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase): ...@@ -67,7 +67,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
parallel_state.initialize_model_parallel( parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size, tensor_model_parallel_size_=tensor_model_parallel_world_size,
pipeline_model_parallel_size_=pipeline_model_parallel_world_size, pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
virtual_pipeline_model_parallel_size_=vriatual_pipeline_model_parallel_size, virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
) )
pp_utils._reconfigure_microbatch_calculator( pp_utils._reconfigure_microbatch_calculator(
rank=parallel_state.get_tensor_model_parallel_rank(), rank=parallel_state.get_tensor_model_parallel_rank(),
...@@ -88,7 +88,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase): ...@@ -88,7 +88,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
model = build_model( model = build_model(
testing_utils.model_provider_func, testing_utils.model_provider_func,
wrap_with_ddp=True, wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=vriatual_pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=PipelineParallelForwardBackwardTest.HIDDEN_SIZE, hidden_size=PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
) )
_param_groups = _get_params_for_weight_decay_optimization(model) _param_groups = _get_params_for_weight_decay_optimization(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment