"src/vscode:/vscode.git/clone" did not exist on "b785ddb654e4be3ae0066e231734754bdb2a191c"
Unverified commit 35336133 authored by Masaki Kozuki, committed by GitHub

[POC] Support Megatron-LM's `rampup_batch_size` argument (#1212)

* init logging use

* fix

* clean up

* fp32 p2p comm

* init

* Dynamic global batch size with `MegatronPretrainingSampler`

I couldn't make this script work with `MegatronPretrainingRandomSampler` because the random sampler seems to have requirements on the
global batch size, total number of samples, and local minibatch size that I'm not yet familiar with.

* revive original pipeline parallel test

* update MULTIGPU_TEST: add dynamic batchsize test

* run MegatronPretrainingRandomSampler

* fix comment

* fix

* update

* cosmetic

* add note

* Apply 2 suggestion(s) to 2 file(s)

* change following https://github.com/NVIDIA/apex/pull/1210

* fix
parent 25bfcb91
import logging
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
if torch.distributed.is_available():
from . import parallel
......@@ -18,3 +21,19 @@ from . import optimizers
from . import normalization
from . import pyprof
from . import transformer
# Logging utilities mainly for apex.transformer module
class RankInfoFormatter(logging.Formatter):
def format(self, record):
from apex.transformer.parallel_state import get_rank_info
record.rank_info = get_rank_info()
return super().format(record)
_library_root_logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - %(name)s - %(levelname)s - %(rank_info)s - %(message)s"))
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False
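A rough usage sketch (illustrative only: the child logger name, timestamp, and rank tuple below are made up; any logger under the `apex` namespace inherits this rank-aware handler):
import logging

import apex  # attaches the RankInfoFormatter handler to the "apex" root logger as above

demo_logger = logging.getLogger("apex.transformer.demo")  # hypothetical child logger
demo_logger.setLevel(logging.INFO)
demo_logger.info("hello from apex.transformer")
# Expected shape of the emitted record, where the tuple is the (tensor, pipeline, data) parallel rank:
# 2021-10-19 10:00:00,000 - apex.transformer.demo - INFO - (0, 0, 0) - hello from apex.transformer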
from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler
from apex.transformer._data._batchsampler import MegatronPretrainingSampler
__all__ = [
"MegatronPretrainingRandomSampler",
"MegatronPretrainingSampler",
]
"""BatchSampler implementations for POC of dynamic batch size or rampup_batch_size support.
Implementations are based on https://github.com/NVIDIA/Megatron-LM/blob/bcd605f8570ebeeb0436c115ebbfafc3c5a40ae5/megatron/data/data_samplers.py.
""" # NOQA
import abc
import torch
__all__ = [
"MegatronPretrainingSampler",
"MegatronPretrainingRandomSampler",
]
class _Base:
"""Base class for Megatron style BatchSampler."""
@abc.abstractmethod
def __len__(self) -> int:
...
@abc.abstractmethod
def __iter__(self):
...
@property
@abc.abstractmethod
def local_minibatch_size(self) -> int:
...
@local_minibatch_size.setter
@abc.abstractmethod
def local_minibatch_size(self, new_local_minibatch_size) -> None:
...
class MegatronPretrainingSampler(_Base):
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
):
# Sanity checks.
if total_samples <= 0:
raise RuntimeError('no sample to consume: {}'.format(total_samples))
if consumed_samples >= total_samples:
raise RuntimeError('no samples left to consume: {}, {}'.format(consumed_samples, total_samples))
if local_minibatch_size <= 0:
raise RuntimeError(f"local minibatch size must be greater than 0: {local_minibatch_size}")
if data_parallel_size <= 0:
raise RuntimeError(f"data parallel size must be greater than 0: {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise RuntimeError('data_parallel_rank should be smaller than data parallel size: {}, {}'.format(data_parallel_rank, data_parallel_size))
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * data_parallel_size
self.drop_last = drop_last
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.local_minibatch_size
end_idx = start_idx + self.local_minibatch_size
return start_idx, end_idx
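# e.g. with local_minibatch_size=4 and data_parallel_size=2, the index list built in `__iter__`
# holds 8 entries per step: rank 0 takes [0:4] and rank 1 takes [4:8].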
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
batch = []
# The last, partial batch will be dropped unless drop_last is set to False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.local_minibatch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Yield the remaining partial batch if drop_last is False
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
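A minimal usage sketch, assuming a toy in-memory dataset and a single data-parallel rank (values are illustrative and not part of this module):
import torch

toy_dataset = [torch.randn(8) for _ in range(10)]
sampler = MegatronPretrainingSampler(
    total_samples=len(toy_dataset),
    consumed_samples=0,
    local_minibatch_size=4,
    data_parallel_rank=0,
    data_parallel_size=1,
    drop_last=False,
)
# Each yielded index list becomes one local minibatch for this rank.
loader = torch.utils.data.DataLoader(toy_dataset, batch_sampler=sampler)
for minibatch in loader:
    print(minibatch.shape)  # torch.Size([4, 8]) twice, then torch.Size([2, 8])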
class MegatronPretrainingRandomSampler(_Base):
"""Megatron style Random Batch Sampler.
Major difference is that `__iter__` yields a local minibatch, not a microbatch.
A local minibatch consists of `global_batch_size / data_parallel_size`
Args:
total_samples: The number of data samples, i.e. ``len(dataset)``.
consumed_samples: The number of samples already consumed in pretraining.
local_minibatch_size: The number of data in each batch returned from `__iter__`. Basically
`local_minibatch_size = global_batch_size / data_parallel_size`.
data_parallel_rank:
data_parallel_size:
"""
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
) -> None:
if total_samples <= 0:
raise ValueError(f"no sample to consume: total_samples of {total_samples}")
if local_minibatch_size <= 0:
raise ValueError(f"Invalid local_minibatch_size: {local_minibatch_size}")
if data_parallel_size <= 0:
raise ValueError(f"Invalid data_parallel_size: {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise ValueError(
f"data_parallel_rank should be smaller than data parallel size: {data_parallel_rank} < {data_parallel_size}"
)
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
self.last_batch_size = self.total_samples % self.local_minibatch_times_data_parallel_size
def __len__(self) -> int:
return self.total_samples
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
# note(mkozuki): might be better to uncomment
# assert current_epoch_samples % (self.data_parallel_size * apex.transformer.pipeline_parallel.utils.get_micro_batch_size()) == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.local_minibatch_times_data_parallel_size) * self.local_minibatch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# The last batch will be dropped if it is incomplete.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.local_minibatch_size:
self.consumed_samples += self.local_minibatch_times_data_parallel_size
yield batch
batch = []
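A sketch of how the `local_minibatch_size` setter supports batch size rampup between iterations (numbers are illustrative; the test script below drives the same pattern through a `DataLoader`):
sampler = MegatronPretrainingRandomSampler(
    total_samples=1024,
    consumed_samples=0,
    local_minibatch_size=32,
    data_parallel_rank=0,
    data_parallel_size=1,
)
it = iter(sampler)
first = next(it)                   # 32 shuffled indices
sampler.local_minibatch_size = 64  # e.g. after the microbatch calculator ramps up
second = next(it)                  # 64 indices from the same shuffled bucket
assert len(first) == 32 and len(second) == 64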
from typing import Optional
import logging
import os
import threading
def get_transformer_logger(name: str) -> logging.Logger:
name_wo_ext = os.path.splitext(name)[0]
return logging.getLogger(name_wo_ext)
def set_logging_level(verbosity) -> None:
"""Change logging severity.
Args:
verbosity
"""
from apex import _library_root_logger
_library_root_logger.setLevel(verbosity)
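A hedged usage sketch of the two helpers (assuming the caller is a module under the `apex` package, so its records reach the rank-aware handler installed in `apex/__init__.py`):
import logging

from apex.transformer.log_util import get_transformer_logger, set_logging_level

_logger = get_transformer_logger(__name__)  # how apex.transformer modules obtain their logger

# From user code: raise the library verbosity so that the warmup / steady / cooldown
# messages emitted by the pipeline schedules become visible.
set_logging_level(logging.INFO)
_logger.info("pipeline schedule selected")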
......@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
from typing import Tuple
import torch
# TODO (mkozuki): Consider dissecting utils as this utils import is here
......@@ -164,6 +166,18 @@ def initialize_model_parallel(
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
def get_rank_info() -> Tuple[int, int, int]:
"""Returns a tuple of (tensor, pipeline, data)-parallel-rank for logger."""
if model_parallel_is_initialized():
return (
get_tensor_model_parallel_rank(),
get_pipeline_model_parallel_rank(),
# get_virtual_pipeline_model_parallel_rank(),
get_data_parallel_rank(),
)
return (0, 0, 0)
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
......@@ -16,6 +16,7 @@
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import warnings
import torch
......@@ -68,8 +69,6 @@ def _run_p2pops(
req.wait()
# NOTE (mkozuki): Leaving `params_dtype` as it is for future development in PyTorch, especially APEX O2 style AMP.
# But as of v1.10, basically all tensors are torch.float32 except for output tensors of `autocast` compatible layers.
def _communicate(
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
......@@ -118,14 +117,21 @@ def _communicate(
tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
else:
tensor_chunk_shape = tensor_shape
dtype = params_dtype or torch.float
if fp32_residual_connection:
dtype = torch.float
# NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
# FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
# It might be possible if we restrict model architecture.
# dtype = params_dtype or torch.float
# if fp32_residual_connection:
# dtype = torch.float
# if dtype_ is not None:
# dtype = dtype_
# requires_grad = False
if dtype_ != torch.float32 or params_dtype is not None:
if torch.distributed.get_rank() == 0:
warnings.warn("Tensor P2P communications are executed in FP32")
dtype = torch.float32
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(
......@@ -9,6 +9,13 @@ from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
_all__ = ["forward_backward_no_pipelining"]
_logger = get_transformer_logger(__name__)
@contextmanager
......@@ -19,9 +26,6 @@ def placeholder_handler():
pass
# TODO (mkozuki): Confirm this will be used or not.
# TODO (mkozuki): Fix if necessary. Currently I'm seeing failure if `not forward_only` and
# the last `backward_step` seems to fail. However, note the possibility of my test script is wrong.
def forward_backward_no_pipelining(
forward_step_func: FwdStepFunc,
batch: Batch,
......@@ -64,18 +68,24 @@ def forward_backward_no_pipelining(
num_micro_batches = get_num_microbatches()
with context_handler():
for i in range(num_micro_batches - 1):
_logger.info(f"Iter {i} of {num_micro_batches - 1}")
cur_micro_batch = get_kth_microbatch(batch, i)
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
_logger.info("Cooldown")
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
return losses_reduced
......@@ -9,7 +9,13 @@ from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.utils import rank_print
from apex.transformer.log_util import get_transformer_logger
__all__ = ["_forward_backward_pipelining_with_interleaving"]
_logger = get_transformer_logger(__name__)
# TODO (mkozuki): Reduce cyclomatic complexity
......@@ -82,13 +88,11 @@ def _forward_backward_pipelining_with_interleaving(
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
# TODO (mkozuki): Remove once debug gets done
# rank_print(
# f"num_microbatches: {num_microbatches}, "
# f"num_warmup_microbatches: {num_warmup_microbatches}, "
# f"num_microbatches_remaining: {num_microbatches_remaining} -- "
# )
_logger.info(
f"num_microbatches: {num_microbatches}, "
f"num_warmup_microbatches: {num_warmup_microbatches}, "
f"num_microbatches_remaining: {num_microbatches_remaining}"
)
###################################################################################################################
# Helper function definitions.
......@@ -155,8 +159,9 @@ def _forward_backward_pipelining_with_interleaving(
###################################################################################################################
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
_logger.info("Warmup phase")
for k in range(num_warmup_microbatches):
# rank_print(f"warmup iter: {k}")
_logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
output_tensor = forward_step_helper(k, curr_iters)
# Determine if tensor should be received from previous stage.
......@@ -167,20 +172,21 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
_logger.debug(f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage():
_logger.debug("Pipeline last stage, not sending tensor downstream")
output_tensor = None
# rank_print(f"recv_prev: {recv_prev}")
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
# rank_print(f"recv_next: {recv_next}")
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
_logger.debug("send fwd&bwd and receive fwd&bwd")
(
input_tensor,
output_tensor_grad,
......@@ -193,17 +199,17 @@ def _forward_backward_pipelining_with_interleaving(
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
# rank_print("send_forward_recv_forward start")
_logger.debug("send fwd and receive fwd")
input_tensor = p2p_communication.send_forward_recv_forward(output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
# rank_print("send_forward_recv_forward finish")
# rank_print("communication done")
input_tensors[next_forward_model_chunk_id].append(input_tensor)
###################################################################################################################
# Run 1F1B in steady state.
###################################################################################################################
_logger.info("Steady phase")
for k in range(num_microbatches_remaining):
# Forward pass.
_logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k, curr_iters)
......@@ -223,6 +229,7 @@ def _forward_backward_pipelining_with_interleaving(
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
_logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
......@@ -258,6 +265,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False
# Communicate tensors.
_logger.debug("send fwd&bwd and receive fwd&bwd")
(
input_tensor,
output_tensor_grad,
......@@ -279,10 +287,12 @@ def _forward_backward_pipelining_with_interleaving(
###################################################################################################################
# Run cooldown backward passes (flush out pipeline).
###################################################################################################################
_logger.info("Cooldown phase")
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
for k in range(num_microbatches_remaining, num_microbatches):
_logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
......@@ -10,7 +10,13 @@ from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.utils import rank_print
from apex.transformer.log_util import get_transformer_logger
__all__ = ["forward_backward_pipelining_without_interleaving"]
_logger = get_transformer_logger(__name__)
def forward_backward_pipelining_without_interleaving(
......@@ -61,12 +67,10 @@ def forward_backward_pipelining_without_interleaving(
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
# TODO (mkozuki): Remove once debug gets done
print(
f">>> rank: {torch.distributed.get_rank()}, "
_logger.info(
f"num_microbatches: {num_microbatches}, "
f"num_warmup_microbatches: {num_warmup_microbatches}, "
f"num_microbatches_remaining: {num_microbatches_remaining} -- "
f"num_microbatches_remaining: {num_microbatches_remaining}"
)
# Input, output tensors only need to be saved when doing backward passes
......@@ -80,52 +84,48 @@ def forward_backward_pipelining_without_interleaving(
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
# rank_print(f"warmup: {num_warmup_microbatches}")
_logger.info("Warmup")
for i in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
_logger.debug("receive fwd")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
cur_microbatch = get_kth_microbatch(batch, i)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
_logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# rank_print(f"warmup iter: {i + 1} / {num_warmup_microbatches}")
# rank_print("warmup done")
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
# rank_print(f"num microbatches remaining: {num_microbatches_remaining}")
if num_microbatches_remaining > 0:
# rank_print(f"recv_forward before steady state start")
_logger.debug("recv_forward before steady state start")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
# rank_print(f"recv_forward before steady state done")
###################################################################################################################
# Run 1F1B in steady state.
###################################################################################################################
# rank_print(f"steady: {num_microbatches_remaining} iters")
_logger.info("Steady phase")
for i in range(num_microbatches_remaining):
# rank_print(f"steady: iter {i + 1} / {num_microbatches_remaining} iters")
# if not forward_only:
# rank_print(f"len(input_tensors) = {len(input_tensors)}, len(output_tensors) = {len(output_tensors)}")
_logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
last_iteration = i == (num_microbatches_remaining - 1)
cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
if forward_only:
# rank_print(f"steady, no backward: `send_forward` start")
_logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
if not last_iteration:
_logger.debug("receive fwd (last iteration)")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
# rank_print(f"steady, no backward: `send_forward` finish")
else:
# rank_print("L.124 steady, backward: `send_forward_recv_backward` start")
_logger.debug("send fwd & receive bwd")
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
# rank_print("L.124 steady, backward: `send_forward_recv_backward` finish")
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
......@@ -141,35 +141,30 @@ def forward_backward_pipelining_without_interleaving(
if last_iteration:
input_tensor = None
# rank_print(f"L.142 steady backward last iteration: `send_backward` start")
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
# rank_print(f"L.142 steady backward last iteration: `send_backward` finish")
else:
# rank_print(f"L.146 steady backward: `send_backward_recv_forward` start")
_logger.debug("send bwd and receive fwd")
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape=tensor_shape)
# rank_print(f"L.146 steady backward: `send_backward_recv_forward` finish")
# rank_print(f"steady: exit")
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
_logger.info("Cooldown phase")
if not forward_only:
# rank_print(f"cooldownk: {num_warmup_microbatches} iters")
for i in range(num_warmup_microbatches):
# rank_print(f"cooldown iter: {i + 1} / {num_warmup_microbatches}")
_logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
# rank_print(f"cooldown waiting for grad tensor")
_logger.debug("receive bwd")
output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)
# rank_print(f"cooldown received grad tensor")
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
)
# rank_print(f"cooldown sending grad tensor")
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
# rank_print(f"cooldownk exit")
return losses_reduced
......@@ -68,6 +68,22 @@ def setup_microbatch_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size)
def _reconfigure_microbatch_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
) -> None:
if torch.distributed.get_rank() == 0:
import warnings
warnings.warn("This function is only for unittest")
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size)
def get_micro_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size
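A hedged sketch of how a test might rebuild the calculator with a rampup schedule; the numbers are illustrative and mirror the dynamic batch size test below, where `rampup_batch_size` is `[start_global_batch_size, batch_size_increment, ramp_samples]`:
_reconfigure_microbatch_calculator(
    rank=0,
    rampup_batch_size=[32, 32, 1000],
    global_batch_size=512,
    micro_batch_size=2,
    data_parallel_size=1,
)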
......@@ -14,9 +14,11 @@
# limitations under the License.
import os
import random
from typing import Optional, Union, List
import numpy
import torch
import torch.nn as nn
from apex import transformer
from apex.transformer.testing import global_vars
......@@ -25,6 +27,40 @@ from apex.transformer.testing import global_vars
TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
# note (mkozuki): `pre_process` and `post_process` are placeholders until the interleaving schedule test comes.
class MyLayer(nn.Module):
def __init__(self, hidden_size: int, pre_process: bool, post_process: bool):
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
return self.layer(x)
class MyModel(nn.Module):
def __init__(self, hidden_size: int, pre_process: bool = False, post_process: bool = False) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = MyLayer(hidden_size=hidden_size, pre_process=pre_process, post_process=post_process)
self.input_tensor = None
def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
self.input_tensor = input_tensor
def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
if self.input_tensor is None:
return self.layer(x)
return self.layer(self.input_tensor)
def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
return MyModel(hidden_size, pre_process, post_process)
class IdentityLayer(torch.nn.Module):
def __init__(self, size, scale=1.0):
super(IdentityLayer, self).__init__()
......@@ -37,17 +37,27 @@ def get_args():
return _GLOBAL_ARGS
def get_num_microbatches():
def get_num_microbatches() -> int:
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def get_current_global_batch_size():
def get_current_global_batch_size() -> int:
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
def update_num_microbatches(consumed_samples, consistency_check=True):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
consistency_check)
def update_num_microbatches(consumed_samples: int, *, consistency_check: bool = True) -> None:
"""Update the number of microbatches upon the number of consumed samples.
.. note::
This function has no effect unless ``rampup_batch_size`` is set.
Args:
consumed_samples: The number of consumed samples so far. Basically this is equal to
:math:`num_iter * global_batch_size`.
consistency_check: If :obj:`True`, sanity-check the consumed samples, i.e., verify that
``consumed_samples`` is divisible by :math:`micro_batch_size \times data_parallel_size`.
"""
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)
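A hedged driver-loop sketch (assuming the calculator was configured with `rampup_batch_size=[32, 32, 1000]` and `global_batch_size=512`, the effective global batch size ramps from 32 to 512 in steps of 32 over roughly the first 1000 consumed samples):
consumed_samples = 0
for _ in range(20):  # the number of iterations here is arbitrary
    update_num_microbatches(consumed_samples, consistency_check=False)
    # get_current_global_batch_size() reflects the rampup schedule.
    consumed_samples += get_current_global_batch_size()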
# def get_tokenizer():
......@@ -34,14 +34,3 @@ def gather_split_1d_tensor(tensor):
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
return gathered
# TODO(mkozuki): Rewrite this using `logging`.
def rank_print(msg):
"""Print the given msg with rank information"""
print(
f"tensor rank: {parallel_state.get_tensor_model_parallel_rank()}"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}, "
f"data rank: {parallel_state.get_data_parallel_rank()} | {msg}"
)
from typing import Tuple, List
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import (
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.log_util import get_transformer_logger, set_logging_level
from apex.transformer.testing.commons import model_provider_func
from apex.transformer._data import MegatronPretrainingRandomSampler
from apex.transformer._data import MegatronPretrainingSampler
# note(mkozuki): To see warmup, steady, cooldown iterations, uncomment the line below
# set_logging_level("INFO")
_logger = get_transformer_logger("pipeline_parallel_test")
# note(mkozuki): To see if local batch size increases, uncomment the line below
# _logger.setLevel("INFO")
global_vars.set_global_variables(
args_defaults={"global_batch_size": 512, "rampup_batch_size": [32, 32, 1000],},
ignore_unknown_args=True,
)
RAMPUP_BATCH_SIZE = []
NUM_ITERATIONS = 20
NUM_SAMPLES = 16384 // 2
batch_size, micro_batch_size = None, None
HIDDEN_SIZE = 16
def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
return [(torch.randn(HIDDEN_SIZE), torch.randn(HIDDEN_SIZE // 2)) for _ in range(num_samples)]
def process_batch(batch):
if isinstance(batch, (list, tuple)):
x = batch[0]
else:
x = batch
return x
def fwd_step_func(micro_batch, model):
x = process_batch(micro_batch)
y = model(x)
# note (mkozuki): This function is not elegant, but it is enough for now
# to sanity-check the ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"avg": averaged_loss}
return y, loss_func
# Run forward & backward with dynamic batch size.
def run_interleaved_with_dynamic_batch_size(
pipeline_model_parallel_size: int, forward_only: bool, BatchSamplerCls,
) -> None:
args = global_vars.get_args()
_reconfigure_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
1, # args.data_parallel_size,
)
virtual_pipeline_model_parallel_size = 2
# NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is required for the interleaving schedule.
# In Megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
# used throughout, but this test uses a custom model, so it is safe to set it directly here.
parallel_state.initialize_model_parallel(
1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size
)
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
print_separator(f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}")
model = build_model(
model_provider_func,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=HIDDEN_SIZE,
)
assert isinstance(model, list)
assert len(model) == virtual_pipeline_model_parallel_size
optimizer = torch.optim.Adam(_get_params_for_weight_decay_optimization(model))
initial_local_minibatch_size = get_num_microbatches() * micro_batch_size
dataset = Dataset(NUM_SAMPLES)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_sampler=BatchSamplerCls(
NUM_SAMPLES,
0,
initial_local_minibatch_size,
parallel_state.get_data_parallel_rank(),
parallel_state.get_data_parallel_world_size(),
),
)
data_iter = iter(data_loader)
def get_num_samples(batch):
if isinstance(batch, torch.Tensor):
return len(batch)
assert isinstance(batch, (list, tuple))
return [get_num_samples(b) for b in batch]
tensor_shape = [micro_batch_size, HIDDEN_SIZE]
consumed_samples = 0
for i in range(NUM_ITERATIONS):
update_num_microbatches(consumed_samples, consistency_check=False)
local_batch_size = get_num_microbatches() * micro_batch_size
data_iter._index_sampler.local_minibatch_size = local_batch_size
local_mini_batch = next(data_iter)
_logger.info(
f"iter: {i} / {NUM_ITERATIONS} "
f"local batchsize: {get_num_samples(local_mini_batch)} "
f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}"
)
_forward_backward_pipelining_with_interleaving(
fwd_step_func,
local_mini_batch,
model,
forward_only=forward_only,
tensor_shape=tensor_shape,
)
consumed_samples += (
parallel_state.get_data_parallel_world_size()
* get_num_microbatches()
* micro_batch_size
)
if not forward_only:
for m in model:
for p in m.parameters():
if p.grad is None:
raise RuntimeError("grad not found")
else:
optimizer.zero_grad(set_to_none=True)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
n_tests = 0
failures = []
initialize_distributed()
world_size = torch.distributed.get_world_size()
args = global_vars.get_args()
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
1, # args.data_parallel_size,
)
for BatchSamplerCls in (MegatronPretrainingSampler, MegatronPretrainingRandomSampler):
for forward_only in (False, True):
n_tests += 1
pipeline_model_parallel_size = world_size
try:
run_interleaved_with_dynamic_batch_size(
pipeline_model_parallel_size, forward_only, BatchSamplerCls,
)
except Exception as e:
msg = (
f"\tforward_only: {forward_only}\n"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
f"{str(e)}"
)
raise RuntimeError(msg)
finally:
parallel_state.destroy_model_parallel()
print_separator("TEST RESULT")
if failures:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("\n".join(failures))
msg = f"{len(failures)} / {n_tests} cases failed"
raise RuntimeError(msg)
else:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("### PASS!")
from functools import partial
import logging
from typing import List
import torch
......@@ -18,9 +19,11 @@ from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.utils import rank_print
from apex.transformer.log_util import get_transformer_logger, set_logging_level
set_logging_level(logging.NOTSET)
_logger = get_transformer_logger("megatron_gpt_pipeline_test")
global_vars.set_global_variables()
N_VOCAB = 8192
......@@ -77,14 +80,14 @@ def run_gpt(pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=N
model = build_model(
gpt_model_provider, True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size)
# rank_print("building model")
_logger.debug("building model")
assert isinstance(model, list)
assert len(model) == (virtual_pipeline_model_parallel_size or 1)
_param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups)
if parallel_state.is_pipeline_last_stage():
# rank_print("checking `word_embeddings` existence")
_logger.debug("checking `word_embeddings` existence")
for m in model:
assert hasattr(m, "word_embeddings")
......@@ -93,19 +96,19 @@ def run_gpt(pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=N
batch = generate_batch(args.global_batch_size, args.seq_length)
else:
batch = [generate_batch(args.global_batch_size, args.seq_length) for _ in range(virtual_pipeline_model_parallel_size)]
# rank_print("preparing batch")
_logger.debug("preparing batch")
if virtual_pipeline_model_parallel_size is None:
fwd_bwd_func = forward_backward_pipelining_without_interleaving
else:
fwd_bwd_func = _forward_backward_pipelining_with_interleaving
# rank_print(f"selecting forward_backward func: {fwd_bwd_func}")
_logger.debug(f"selecting forward_backward func: {fwd_bwd_func}")
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
# rank_print(f"`tensor_shape`: {tensor_shape}")
_logger.debug(f"`tensor_shape`: {tensor_shape}")
fwd_bwd_func(forward_step, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
# rank_print(TEST_SUCCESS_MESSAGE)
_logger.debug(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__":
......@@ -126,7 +129,7 @@ if __name__ == "__main__":
# TODO(mkozuki): handle exception correctly, but for now, lazily commenting out as
# this won't get kicked by CI
except Exception as e:
# rank_print(str(e))
_logger.debug(str(e))
pass
finally:
parallel_state.destroy_model_parallel()
......@@ -3,6 +3,7 @@ from typing import Optional, Union, List
import torch
import torch.nn as nn
import apex
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
......@@ -17,9 +18,11 @@ from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.utils import rank_print
from apex.transformer.log_util import get_transformer_logger, set_logging_level
# set_logging_level("INFO")
_logger = get_transformer_logger("pipeline_parallel_test")
global_vars.set_global_variables()
......@@ -133,7 +136,6 @@ def forward_backward_func_template(
fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
if not forward_only:
# rank_print("grad check")
for m in model:
for p in m.parameters():
if p.grad is None:
......
......@@ -10,6 +10,7 @@ DENY_TEST = [
]
MULTIGPU_TEST = [
"pipeline_parallel_test",
"dynamic_batchsize_test",
]