Unverified commit ed94d0bb authored by eqy, committed by GitHub

check size in kth microbatch (#1247)

parent 0e25fcc4
@@ -126,7 +126,15 @@ def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
     `a local minibatch` consists of `global_batch_size / data_parallel_size` samples.
     """
     micro_batch_size = get_micro_batch_size()
-    return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]
+    start = k * micro_batch_size
+    end = start + micro_batch_size
+    microbatch = list()
+    for x in batch:
+        size = x.size(0)
+        assert size > start and size >= end
+        microbatch.append(x[start:end])
+    assert len(microbatch) > 0
+    return microbatch


 def get_autoresume():
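A note on what the new assertions catch (my reading of the diff, not stated in the commit message): plain tensor slicing never raises on an out-of-range microbatch index, it just returns a truncated or empty slice, so a mis-sized local minibatch used to propagate silently. A minimal standalone sketch; the helper name and the explicit micro_batch_size argument are illustrative, not apex's API:

import torch

# Stand-in for the pre-commit behavior of get_kth_microbatch:
# out-of-range slices do not raise, they silently shrink to empty.
def kth_microbatch_unchecked(batch, k, micro_batch_size):
    return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]

batch = [torch.arange(8).reshape(8, 1)]                   # local minibatch of 8 samples
print(kth_microbatch_unchecked(batch, 1, 4)[0].size(0))   # 4, as expected
print(kth_microbatch_unchecked(batch, 2, 4)[0].size(0))   # 0 -- silently empty

# With the checks added in this commit, the second call would instead trip
# `assert size > start and size >= end` (size=8, start=8, end=12).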
@@ -31,7 +31,7 @@ _logger = get_transformer_logger("pipeline_parallel_test")
 # note(mkozuki): To see if local batch size increases, uncomment the line below
 # _logger.setLevel("INFO")
 global_vars.set_global_variables(
-    args_defaults={"global_batch_size": 512, "rampup_batch_size": [32, 32, 1000],},
+    args_defaults={"global_batch_size": 512, "rampup_batch_size": [64, 64, 1000],},
     ignore_unknown_args=True,
 )
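For readers unfamiliar with the triple: assuming Megatron-style semantics, rampup_batch_size is [start batch size, batch size increment, ramp-up samples], so the test now ramps from 64 up to the global batch size of 512 in steps of 64 spread over 1000 samples. A rough sketch of that schedule (simplified; apex's actual ramp-up calculator is more involved):

# Hedged sketch of Megatron-style batch size ramp-up, assuming
# rampup_batch_size = [start, increment, ramp_up_samples].
def global_batch_size_at(consumed_samples,
                         start=64, increment=64, ramp_up_samples=1000,
                         global_batch_size=512):
    num_increments = (global_batch_size - start) // increment   # 7 steps here
    samples_per_increment = ramp_up_samples / num_increments
    steps = int(consumed_samples / samples_per_increment)
    return min(global_batch_size, start + steps * increment)

print(global_batch_size_at(0))      # 64 at the start of training
print(global_batch_size_at(500))    # 256, mid-ramp
print(global_batch_size_at(1500))   # 512 once the ramp-up samples are consumed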
@@ -9,7 +9,7 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model
 from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
 from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
 from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
-from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
+from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
 from apex.transformer.pipeline_parallel.utils import update_num_microbatches
 from apex.transformer.testing import global_vars
 from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
@@ -37,6 +37,7 @@ fwd_bwd_functions = {
 # TODO (mkozuki): Add a case with `autocast` and `GradScaler`.
 # Run forward & backward for one minibatch.
 def forward_backward_func_template(
+    args,
     name: str,
     forward_backward_func,
     pipeline_model_parallel_size: int,
@@ -49,12 +50,26 @@ def forward_backward_func_template(
         # pipeline_model_parallel_size>1. So use pipeline_model_parallel_size as
         # tensor_model_parallel_size and set pipeline_model_parallel_size to 1.
         parallel_state.initialize_model_parallel(1, 1, None)
+        _reconfigure_microbatch_calculator(
+            args.rank,
+            args.rampup_batch_size,
+            args.global_batch_size,
+            args.micro_batch_size,
+            parallel_state.get_data_parallel_world_size(),
+        )
     else:
         # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is necessary to enable interleaving scheduling
         # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
         # used ubiquitously but this test uses custom model so it's safe to abuse.
         parallel_state.initialize_model_parallel(
             1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
+        _reconfigure_microbatch_calculator(
+            args.rank,
+            args.rampup_batch_size,
+            args.global_batch_size,
+            args.micro_batch_size,
+            1,  # args.data_parallel_size,
+        )
     if virtual_pipeline_model_parallel_size is not None:
         # Check the experimental warning message
         get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size)
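Why reconfigure here instead of once in __main__ (my reading of the change): the number of microbatches each rank runs depends on the data-parallel world size, which differs between the two branches above, so the calculator has to be rebuilt after each initialize_model_parallel call. The arithmetic it ends up doing, ignoring ramp-up, is roughly this (micro_batch_size=2 is an illustrative value, not the test's default):

# Hedged back-of-the-envelope: how many microbatches each data-parallel rank
# runs per global batch, once ramp-up is finished.
def num_microbatches(global_batch_size, micro_batch_size, data_parallel_size):
    samples_per_rank = global_batch_size // data_parallel_size
    return samples_per_rank // micro_batch_size

print(num_microbatches(512, 2, data_parallel_size=8))  # 32 with 8-way data parallel
print(num_microbatches(512, 2, data_parallel_size=1))  # 256 when all ranks are pipeline stages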
@@ -100,13 +115,7 @@ if __name__ == "__main__":
     args = global_vars.get_args()
     batch_size = args.global_batch_size
     micro_batch_size = args.micro_batch_size
-    setup_microbatch_calculator(
-        args.rank,
-        args.rampup_batch_size,
-        args.global_batch_size,
-        args.micro_batch_size,
-        1,  # args.data_parallel_size,
-    )
     for forward_only in (True, False):
         for name, forward_backward_func in fwd_bwd_functions.items():
             n_tests += 1
@@ -114,6 +123,7 @@ if __name__ == "__main__":
                 pipeline_model_parallel_size = world_size
                 try:
                     forward_backward_func_template(
+                        args,
                         name,
                         forward_backward_func,
                         pipeline_model_parallel_size,