OpenDAS / apex · Commits

Commit ed94d0bb (unverified)
Authored Dec 13, 2021 by eqy; committed by GitHub on Dec 14, 2021
check size in kth microbatch (#1247)
Parent: 0e25fcc4

Showing 3 changed files with 28 additions and 10 deletions:

    apex/transformer/pipeline_parallel/utils.py               +9  -1
    tests/L0/run_transformer/run_dynamic_batchsize_test.py    +1  -1
    tests/L0/run_transformer/run_pipeline_parallel_test.py   +18  -8
apex/transformer/pipeline_parallel/utils.py
@@ -126,7 +126,15 @@ def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
     `a local minibatch` consists of `global_batch_size / data_parallel_size` samples.
     """
     micro_batch_size = get_micro_batch_size()
-    return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]
+    start = k * micro_batch_size
+    end = start + micro_batch_size
+    microbatch = list()
+    for x in batch:
+        size = x.size(0)
+        assert size > start and size >= end
+        microbatch.append(x[start:end])
+    assert len(microbatch) > 0
+    return microbatch


 def get_autoresume():
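For context (not part of the commit; sizes are illustrative), a minimal sketch of the failure mode the new assertions catch: with the old one-line slice, a local minibatch whose length does not cover the k-th microbatch window would silently yield a short or empty microbatch, whereas the rewritten loop fails fast.

    import torch

    micro_batch_size = 4
    batch = [torch.randn(10, 8)]  # local minibatch of 10 samples: not a multiple of 4

    k = 2  # third microbatch would slice [8:12]
    start = k * micro_batch_size   # 8
    end = start + micro_batch_size # 12

    # Old behavior: the slice silently truncates to 2 samples.
    print(batch[0][start:end].size(0))  # 2

    # New behavior: size (10) > start (8) holds, but size (10) >= end (12)
    # fails, so get_kth_microbatch now raises instead of returning a short batch.
    size = batch[0].size(0)
    assert size > start and size >= end  # AssertionError fires here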
tests/L0/run_transformer/run_dynamic_batchsize_test.py
@@ -31,7 +31,7 @@ _logger = get_transformer_logger("pipeline_parallel_test")
 # note(mkozuki): To see if local batch size increases, uncomment the line below
 # _logger.setLevel("INFO")
 global_vars.set_global_variables(
-    args_defaults={"global_batch_size": 512, "rampup_batch_size": [32, 32, 1000],},
+    args_defaults={"global_batch_size": 512, "rampup_batch_size": [64, 64, 1000],},
     ignore_unknown_args=True,
 )
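A hedged note on the changed default: assuming the Megatron-LM convention that `rampup_batch_size` is `[start_batch_size, batch_size_increment, ramp-up samples]`, doubling the first two entries starts the ramp at 64 and grows it in steps of 64, halving the number of increments needed to reach `global_batch_size = 512`. A quick sanity check of the implied schedule:

    # Illustrative arithmetic only. The [start, increment, ramp-up samples]
    # reading of rampup_batch_size is an assumption based on the Megatron-LM
    # convention this calculator follows; it is not stated in the diff itself.
    start, increment, ramp_samples = 64, 64, 1000
    global_batch_size = 512

    num_increments = (global_batch_size - start) // increment  # 7 steps: 64 -> 512
    samples_per_step = ramp_samples / num_increments           # ~142.9 samples per step
    print(num_increments, samples_per_step)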
tests/L0/run_transformer/run_pipeline_parallel_test.py
@@ -9,7 +9,7 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model
 from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
 from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
 from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
-from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
+from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
 from apex.transformer.pipeline_parallel.utils import update_num_microbatches
 from apex.transformer.testing import global_vars
 from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
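(A hedged reading of this import swap: the template below now reconfigures the microbatch calculator once per parallel configuration, so the test needs an entry point that can be called repeatedly; `setup_microbatch_calculator` is presumably meant for one-time initialization at startup, whose call site in the `__main__` block this commit also removes.)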
@@ -37,6 +37,7 @@ fwd_bwd_functions = {
 # TODO (mkozuki): Add a case with `autocast` and `GradScaler`.
 # Run forward & backward for one minibatch.
 def forward_backward_func_template(
+    args,
     name: str,
     forward_backward_func,
     pipeline_model_parallel_size: int,
@@ -49,12 +50,26 @@ def forward_backward_func_template(
         # pipeline_model_parallel_size>1. So use pipeline_model_parallel_size as
         # tensor_model_parallel_size and set pipeline_model_parallel_size to 1.
         parallel_state.initialize_model_parallel(1, 1, None)
+        _reconfigure_microbatch_calculator(
+            args.rank,
+            args.rampup_batch_size,
+            args.global_batch_size,
+            args.micro_batch_size,
+            parallel_state.get_data_parallel_world_size(),
+        )
     else:
         # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is necessary to enable interleaving scheduling
         # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
         # used ubiquitously but this test uses custom model so it's safe to abuse.
         parallel_state.initialize_model_parallel(
             1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
+        _reconfigure_microbatch_calculator(
+            args.rank,
+            args.rampup_batch_size,
+            args.global_batch_size,
+            args.micro_batch_size,
+            1,  # args.data_parallel_size,
+        )
     if virtual_pipeline_model_parallel_size is not None:
         # Check the experimental warning message
         get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
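For readers wiring this up elsewhere, a minimal standalone sketch of the reconfigure call, assuming only the positional signature visible in this hunk (rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size); the values are illustrative:

    # A minimal sketch, not taken verbatim from the repo: arguments follow
    # the positional order used in the calls above.
    from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator

    _reconfigure_microbatch_calculator(
        0,     # rank of this process
        None,  # rampup_batch_size; None presumably means a fixed global batch size
        512,   # global_batch_size
        4,     # micro_batch_size
        1,     # data_parallel_size (1 for a pipeline-parallel-only test)
    )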
@@ -100,13 +115,7 @@ if __name__ == "__main__":
     args = global_vars.get_args()
     batch_size = args.global_batch_size
     micro_batch_size = args.micro_batch_size
-    setup_microbatch_calculator(
-        args.rank,
-        args.rampup_batch_size,
-        args.global_batch_size,
-        args.micro_batch_size,
-        1,  # args.data_parallel_size,
-    )
     for forward_only in (True, False):
         for name, forward_backward_func in fwd_bwd_functions.items():
             n_tests += 1
@@ -114,6 +123,7 @@ if __name__ == "__main__":
                 pipeline_model_parallel_size = world_size
             try:
                 forward_backward_func_template(
+                    args,
                     name,
                     forward_backward_func,
                     pipeline_model_parallel_size,