Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
c3018b13
Unverified
Commit
c3018b13
authored
Apr 29, 2022
by
eqy
Committed by
GitHub
Apr 29, 2022
Browse files
[transformer][pipeline parallel] fix typo in test (#1370)
* fix typo * Update test_pipeline_parallel_fwd_bwd.py
parent
2b7d280b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py
tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py
+3
-3
No files found.
tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py
View file @
c3018b13
...
@@ -45,7 +45,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
...
@@ -45,7 +45,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
forward_only
:
bool
,
forward_only
:
bool
,
fwd_bwd_func
:
FwdStepFunc
,
fwd_bwd_func
:
FwdStepFunc
,
pipeline_model_parallel_world_size
:
Optional
[
int
],
pipeline_model_parallel_world_size
:
Optional
[
int
],
vr
ia
tual_pipeline_model_parallel_size
:
Optional
[
int
],
v
i
rtual_pipeline_model_parallel_size
:
Optional
[
int
],
)
->
None
:
)
->
None
:
for
dtype
,
deallocate_pipeline_outputs
in
itertools
.
product
(
for
dtype
,
deallocate_pipeline_outputs
in
itertools
.
product
(
[
torch
.
float32
]
+
_get_autocast_dtypes
(),
(
True
,
False
),
[
torch
.
float32
]
+
_get_autocast_dtypes
(),
(
True
,
False
),
...
@@ -67,7 +67,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
...
@@ -67,7 +67,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
parallel_state
.
initialize_model_parallel
(
parallel_state
.
initialize_model_parallel
(
tensor_model_parallel_size_
=
tensor_model_parallel_world_size
,
tensor_model_parallel_size_
=
tensor_model_parallel_world_size
,
pipeline_model_parallel_size_
=
pipeline_model_parallel_world_size
,
pipeline_model_parallel_size_
=
pipeline_model_parallel_world_size
,
virtual_pipeline_model_parallel_size_
=
vr
ia
tual_pipeline_model_parallel_size
,
virtual_pipeline_model_parallel_size_
=
v
i
rtual_pipeline_model_parallel_size
,
)
)
pp_utils
.
_reconfigure_microbatch_calculator
(
pp_utils
.
_reconfigure_microbatch_calculator
(
rank
=
parallel_state
.
get_tensor_model_parallel_rank
(),
rank
=
parallel_state
.
get_tensor_model_parallel_rank
(),
...
@@ -88,7 +88,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
...
@@ -88,7 +88,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
model
=
build_model
(
model
=
build_model
(
testing_utils
.
model_provider_func
,
testing_utils
.
model_provider_func
,
wrap_with_ddp
=
True
,
wrap_with_ddp
=
True
,
virtual_pipeline_model_parallel_size
=
vr
ia
tual_pipeline_model_parallel_size
,
virtual_pipeline_model_parallel_size
=
v
i
rtual_pipeline_model_parallel_size
,
hidden_size
=
PipelineParallelForwardBackwardTest
.
HIDDEN_SIZE
,
hidden_size
=
PipelineParallelForwardBackwardTest
.
HIDDEN_SIZE
,
)
)
_param_groups
=
_get_params_for_weight_decay_optimization
(
model
)
_param_groups
=
_get_params_for_weight_decay_optimization
(
model
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment