Unverified Commit 35336133 authored by Masaki Kozuki, committed by GitHub

[POC] Support Megatron-LM's `rampup_batch_size` argument (#1212)

* init logging use

* fix

* clean up

* fp32 p2p comm

* init

* Dynamic global batch size with `MegatronPretrainingSampler`

I couldn't make this script work with `MegatronPretrainingRandomSampler` because the random sampler seems to have
some requirements on the global batch size, total number of samples, local minibatch size, etc. that I'm not yet familiar with.

* revive original pipeline parallel test

* update MULTIGPU_TEST: add dynamic batchsize test

* run MegatronPretrainingRandomSampler

* fix comment

* fix

* update

* cosmetic

* add note

* Apply 2 suggestion(s) to 2 file(s)

* change following https://github.com/NVIDIA/apex/pull/1210

* fix
parent 25bfcb91
import logging
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
if torch.distributed.is_available():
from . import parallel
...@@ -18,3 +21,19 @@ from . import optimizers
from . import normalization
from . import pyprof
from . import transformer
# Logging utilities mainly for apex.transformer module
class RankInfoFormatter(logging.Formatter):
def format(self, record):
from apex.transformer.parallel_state import get_rank_info
record.rank_info = get_rank_info()
return super().format(record)
_library_root_logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - %(name)s - %(levelname)s - %(rank_info)s - %(message)s"))
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False
from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler
from apex.transformer._data._batchsampler import MegatronPretrainingSampler
__all__ = [
"MegatronPretrainingRandomSampler",
"MegatronPretrainingSampler",
]
"""BatchSampler implementations for POC of dynamic batch size or rampup_batch_size support.
Implementations are based on https://github.com/NVIDIA/Megatron-LM/blob/bcd605f8570ebeeb0436c115ebbfafc3c5a40ae5/megatron/data/data_samplers.py.
""" # NOQA
import abc
import torch
__all__ = [
"MegatronPretrainingSampler",
"MegatronPretrainingRandomSampler",
]
class _Base:
"""Base class for Megatron style BatchSampler."""
@abc.abstractmethod
def __len__(self) -> int:
...
@abc.abstractmethod
def __iter__(self):
...
@property
@abc.abstractmethod
def local_minibatch_size(self) -> int:
...
@local_minibatch_size.setter
@abc.abstractmethod
def local_minibatch_size(self, new_local_minibatch_size) -> None:
...
class MegatronPretrainingSampler(_Base):
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
):
# Sanity checks.
if total_samples <= 0:
raise RuntimeError(f"no sample to consume: {total_samples}")
if consumed_samples >= total_samples:
raise RuntimeError(f"no samples left to consume: {consumed_samples}, {total_samples}")
if local_minibatch_size <= 0:
raise RuntimeError(f"local minibatch size must be greater than 0: {local_minibatch_size}")
if data_parallel_size <= 0:
raise RuntimeError(f"data parallel size must be greater than 0: {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise RuntimeError(f"data_parallel_rank should be smaller than data parallel size: {data_parallel_rank}, {data_parallel_size}")
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * data_parallel_size
self.drop_last = drop_last
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.local_minibatch_size
end_idx = start_idx + self.local_minibatch_size
return start_idx, end_idx
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
batch = []
# The last partial batch will be dropped unless drop_last is set to False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.local_minibatch_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class MegatronPretrainingRandomSampler(_Base):
"""Megatron style Random Batch Sampler.
Major difference is that `__iter__` yields a local minibatch, not a microbatch.
A local minibatch consists of `global_batch_size / data_parallel_size`
Args:
total_samples: The number of data samples, i.e. ``len(dataset)``.
consumed_samples: The number of samples already consumed in pretraining.
local_minibatch_size: The number of data in each batch returned from `__iter__`. Basically
`local_minibatch_size = global_batch_size / data_parallel_size`.
data_parallel_rank:
data_parallel_size:
"""
def __init__(
self,
total_samples: int,
consumed_samples: int,
local_minibatch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
) -> None:
if total_samples <= 0:
raise ValueError(f"no sample to consume: total_samples of {total_samples}")
if local_minibatch_size <= 0:
raise ValueError(f"Invalid local_minibatch_size: {local_minibatch_size}")
if data_parallel_size <= 0:
raise ValueError(f"Invalid data_parallel_size: {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise ValueError(
f"data_parallel_rank should be smaller than data parallel size: {data_parallel_rank} < {data_parallel_size}"
)
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self._local_minibatch_size = local_minibatch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
self.last_batch_size = self.total_samples % self.local_minibatch_times_data_parallel_size
def __len__(self) -> int:
return self.total_samples
@property
def local_minibatch_size(self) -> int:
return self._local_minibatch_size
@local_minibatch_size.setter
def local_minibatch_size(self, new_local_minibatch_size) -> None:
self._local_minibatch_size = new_local_minibatch_size
self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
# note(mkozuki): might be better to uncomment
# assert current_epoch_samples % (self.data_parallel_size * apex.transformer.pipeline_parallel.utils.get_micro_batch_size()) == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.local_minibatch_times_data_parallel_size) * self.local_minibatch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# The last batch will be dropped if it is not complete.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.local_minibatch_size:
self.consumed_samples += self.local_minibatch_times_data_parallel_size
yield batch
batch = []
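For orientation, here is a minimal usage sketch of the samplers above, condensed from the dynamic batch size test added later in this commit. The toy dataset, the concrete sizes, a single data-parallel rank, and the default single-process `DataLoader` are illustrative assumptions, not part of the patch.

import torch
from apex.transformer._data import MegatronPretrainingSampler

dataset = [torch.randn(16) for _ in range(1024)]  # toy dataset
sampler = MegatronPretrainingSampler(
    total_samples=len(dataset),
    consumed_samples=0,
    local_minibatch_size=8,   # global_batch_size / data_parallel_size at the start of ramp-up
    data_parallel_rank=0,
    data_parallel_size=1,
)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
data_iter = iter(loader)

first = next(data_iter)            # tensor of shape (8, 16)
sampler.local_minibatch_size = 16  # e.g. after `rampup_batch_size` grows the global batch size
second = next(data_iter)           # subsequent local minibatches now hold 16 samples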
from typing import Optional
import logging
import os
import threading
def get_transformer_logger(name: str) -> logging.Logger:
name_wo_ext = os.path.splitext(name)[0]
return logging.getLogger(name_wo_ext)
def set_logging_level(verbosity) -> None:
"""Change the logging severity of apex's library root logger.

Args:
verbosity: A logging level such as ``logging.INFO`` or its string name such as ``"INFO"``.
"""
from apex import _library_root_logger
_library_root_logger.setLevel(verbosity)
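A brief usage note (a sketch, not part of the patch): `set_logging_level` adjusts the `apex` library root logger configured in `apex/__init__.py`, and loggers created with `get_transformer_logger` inside `apex.transformer` modules are children of that logger, so their records are rendered by the `RankInfoFormatter` handler. The test scripts later in this commit enable the new phase/iteration messages roughly like this; the module name below is hypothetical.

import logging
from apex.transformer.log_util import get_transformer_logger, set_logging_level

# Reveal the schedules' _logger.info(...) calls; the string "INFO" works as well.
set_logging_level(logging.INFO)

# Loggers under the "apex" namespace inherit the handler whose format is
# "%(asctime)s - %(name)s - %(levelname)s - %(rank_info)s - %(message)s".
_logger = get_transformer_logger("apex.transformer.pipeline_parallel.example")  # hypothetical name
_logger.info("warmup phase")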
...@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
from typing import Tuple
import torch
# TODO (mkozuki): Consider dissecting utils as this utils import is here
...@@ -164,6 +166,18 @@ def initialize_model_parallel(
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
def get_rank_info() -> Tuple[int, int, int]:
"""Returns a tuple of (tensor, pipeline, data)-parallel-rank for logger."""
if model_parallel_is_initialized():
return (
get_tensor_model_parallel_rank(),
get_pipeline_model_parallel_rank(),
# get_virtual_pipeline_model_parallel_rank(),
get_data_parallel_rank(),
)
return (0, 0, 0)
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
......
...@@ -16,6 +16,7 @@
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import warnings
import torch
...@@ -68,8 +69,6 @@ def _run_p2pops(
req.wait()
# NOTE (mkozuki): Leaving `params_dytpe` as it is for future development in PyTorch, especially APEX O2 style AMP.
# But as of v1.10, basically all tensors are torch.float32 except for output tensors of `autocast` compatible layers.
def _communicate(
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
...@@ -118,14 +117,21 @@ def _communicate(
tensor_chunk_shape = (reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
else:
tensor_chunk_shape = tensor_shape
dtype = params_dtype or torch.float
if fp32_residual_connection:
dtype = torch.float
# NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
# FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
# It might be possible if we restrict model architecture.
# dtype = params_dtype or torch.float
# if fp32_residual_connection:
# dtype = torch.float
# if dtype_ is not None:
# dtype = dtype_
# requires_grad = False
if dtype_ != torch.float32 or params_dtype is not None:
if torch.distributed.get_rank() == 0:
warnings.warn("Tensor P2P communications are executed in FP32")
dtype = torch.float32
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(
......
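The dtype handling above implements the "fp32 p2p comm" item from the commit message: instead of guessing per-tensor dtypes across pipeline stages, the P2P buffers are pinned to FP32 and a warning is emitted once on rank 0. Below is a small hedged illustration of the underlying problem; it is not part of the patch and assumes a CUDA device is available.

import torch

# Under autocast, different ops produce different dtypes, so a downstream pipeline
# stage cannot infer the activation dtype it is about to receive in the general case.
linear = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")
with torch.cuda.amp.autocast():
    y = linear(x)                  # matmul-backed ops run in float16
    z = torch.softmax(y, dim=-1)   # softmax is autocast to float32
print(y.dtype, z.dtype)            # torch.float16 torch.float32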
...@@ -9,6 +9,13 @@ from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
__all__ = ["forward_backward_no_pipelining"]
_logger = get_transformer_logger(__name__)
@contextmanager
...@@ -19,9 +26,6 @@ def placeholder_handler():
pass
# TODO (mkozuki): Confirm this will be used or not.
# TODO (mkozuki): Fix if necessary. Currently I'm seeing failure if `not forward_only` and
# the last `backward_step` seems to fail. However, note the possibility of my test script is wrong.
def forward_backward_no_pipelining(
forward_step_func: FwdStepFunc,
batch: Batch,
...@@ -64,18 +68,24 @@ def forward_backward_no_pipelining(
num_micro_batches = get_num_microbatches()
with context_handler():
for i in range(num_micro_batches - 1):
_logger.info(f"Iter {i} of {num_micro_batches - 1}")
cur_micro_batch = get_kth_microbatch(batch, i)
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
_logger.info("Cooldown")
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad)
return losses_reduced
...@@ -9,7 +9,13 @@ from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.log_util import get_transformer_logger
__all__ = ["_forward_backward_pipelining_with_interleaving"]
_logger = get_transformer_logger(__name__)
# TODO (mkozuki): Reduce cyclomatic complexity
...@@ -82,13 +88,11 @@ def _forward_backward_pipelining_with_interleaving(
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
_logger.info(
f"num_microbatches: {num_microbatches}, "
f"num_warmup_microbatches: {num_warmup_microbatches}, "
f"num_microbatches_remaining: {num_microbatches_remaining}"
)
###################################################################################################################
# Helper function definitions.
...@@ -155,8 +159,9 @@ def _forward_backward_pipelining_with_interleaving(
###################################################################################################################
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
_logger.info("Warmup phase")
for k in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
output_tensor = forward_step_helper(k, curr_iters)
# Determine if tensor should be received from previous stage.
...@@ -167,20 +172,21 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
_logger.debug(f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage():
_logger.debug("Pipeline last stage, not sending tensor downstream")
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
_logger.debug("send fwd&bwd and receive fwd&bwd")
(
input_tensor,
output_tensor_grad,
...@@ -193,17 +199,17 @@ def _forward_backward_pipelining_with_interleaving(
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
_logger.debug("send fwd and receive fwd")
input_tensor = p2p_communication.send_forward_recv_forward(output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
###################################################################################################################
# Run 1F1B in steady state.
###################################################################################################################
_logger.info("Steady phase")
for k in range(num_microbatches_remaining):
# Forward pass.
_logger.debug(f"steady phase iter {k} / {num_microbatches_remaining}")
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k, curr_iters)
...@@ -223,6 +229,7 @@ def _forward_backward_pipelining_with_interleaving(
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
_logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
...@@ -258,6 +265,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev = False
# Communicate tensors.
_logger.debug("send fwd&bwd and receive fwd&bwd")
(
input_tensor,
output_tensor_grad,
...@@ -279,10 +287,12 @@ def _forward_backward_pipelining_with_interleaving(
###################################################################################################################
# Run cooldown backward passes (flush out pipeline).
###################################################################################################################
_logger.info("Cooldown phase")
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
for k in range(num_microbatches_remaining, num_microbatches):
_logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
......
...@@ -10,7 +10,13 @@ from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
__all__ = ["forward_backward_pipelining_without_interleaving"]
_logger = get_transformer_logger(__name__)
def forward_backward_pipelining_without_interleaving(
...@@ -61,12 +67,10 @@ def forward_backward_pipelining_without_interleaving(
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
_logger.info(
f"num_microbatches: {num_microbatches}, "
f"num_warmup_microbatches: {num_warmup_microbatches}, "
f"num_microbatches_remaining: {num_microbatches_remaining}"
)
# Input, output tensors only need to be saved when doing backward passes
...@@ -80,52 +84,48 @@ def forward_backward_pipelining_without_interleaving(
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
_logger.info("Warmup")
for i in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
_logger.debug("receive fwd")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
cur_microbatch = get_kth_microbatch(batch, i)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
_logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
_logger.debug("recv_forward before steady state start")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
###################################################################################################################
# Run 1F1B in steady state.
###################################################################################################################
# rank_print(f"steady: {num_microbatches_remaining} iters") _logger.info("Steady phase")
for i in range(num_microbatches_remaining): for i in range(num_microbatches_remaining):
# rank_print(f"steady: iter {i + 1} / {num_microbatches_remaining} iters") _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
# if not forward_only:
# rank_print(f"len(input_tensors) = {len(input_tensors)}, len(output_tensors) = {len(output_tensors)}")
last_iteration = i == (num_microbatches_remaining - 1) last_iteration = i == (num_microbatches_remaining - 1)
cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches) cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced) output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
if forward_only: if forward_only:
# rank_print(f"steady, no backward: `send_forward` start") _logger.debug("send fwd")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape) p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
if not last_iteration: if not last_iteration:
_logger.debug("receive fwd (last iteration)")
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape) input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
# rank_print(f"steady, no backward: `send_forward` finish")
else: else:
# rank_print("L.124 steady, backward: `send_forward_recv_backward` start") _logger.debug("send fwd & receive bwd")
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape) output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
# rank_print("L.124 steady, backward: `send_forward_recv_backward` finish")
# Add input_tensor and output_tensor to end of list. # Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
...@@ -141,35 +141,30 @@ def forward_backward_pipelining_without_interleaving(
if last_iteration:
input_tensor = None
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
else:
_logger.debug("send bwd and receive fwd")
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape=tensor_shape)
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
_logger.info("Cooldown phase")
if not forward_only:
for i in range(num_warmup_microbatches):
_logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
_logger.debug("receive bwd")
output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad
)
_logger.debug("send bwd")
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
return losses_reduced
...@@ -68,6 +68,22 @@ def setup_microbatch_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size)
def _reconfigure_microbatch_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
) -> None:
if torch.distributed.get_rank() == 0:
import warnings
warnings.warn("This function is only for unittest")
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size)
def get_micro_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size
......
...@@ -14,9 +14,11 @@
# limitations under the License.
import os
import random
from typing import Optional, Union, List
import numpy
import torch
import torch.nn as nn
from apex import transformer
from apex.transformer.testing import global_vars
...@@ -25,6 +27,40 @@ from apex.transformer.testing import global_vars
TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
# note (mkozuki): `pre_process` and `post_process` are placeholders until an interleaving schedule test is added.
class MyLayer(nn.Module):
def __init__(self, hidden_size: int, pre_process: bool, post_process: bool):
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
return self.layer(x)
class MyModel(nn.Module):
def __init__(self, hidden_size: int, pre_process: bool = False, post_process: bool = False) -> None:
super().__init__()
self.pre_process = pre_process
self.post_process = post_process
self.layer = MyLayer(hidden_size=hidden_size, pre_process=pre_process, post_process=post_process)
self.input_tensor = None
def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
self.input_tensor = input_tensor
def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
if self.input_tensor is None:
return self.layer(x)
return self.layer(self.input_tensor)
def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
return MyModel(hidden_size, pre_process, post_process)
class IdentityLayer(torch.nn.Module):
def __init__(self, size, scale=1.0):
super(IdentityLayer, self).__init__()
......
...@@ -37,17 +37,27 @@ def get_args():
return _GLOBAL_ARGS
def get_num_microbatches() -> int:
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def get_current_global_batch_size() -> int:
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
def update_num_microbatches(consumed_samples: int, *, consistency_check: bool = True) -> None:
"""Update the number of microbatches based on the number of consumed samples.

.. note::
This function has no effect unless ``rampup_batch_size`` is set.

Args:
consumed_samples: The number of samples consumed so far. Basically this is equal to
:math:`num_iter * global_batch_size`.
consistency_check: If :obj:`True`, sanity-check ``consumed_samples``, i.e. check that it is
divisible by :math:`micro_batch_size \times data_parallel_size`.
"""
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)
# def get_tokenizer():
......
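To tie the pieces together: if I read Megatron-LM's convention correctly, `rampup_batch_size` is a triple of (start global batch size, batch size increment, ramp-up samples), so with the values used in the test below the global batch size starts at 32 and grows by 32 until it reaches `global_batch_size=512`. The driver loop then looks roughly as follows, a condensed sketch of the dynamic batch size test added later in this commit, assuming a single data-parallel rank and illustrative sizes; it is not a drop-in training loop.

from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches

MICRO_BATCH_SIZE = 2      # illustrative values
NUM_ITERATIONS = 20
DATA_PARALLEL_SIZE = 1

setup_microbatch_calculator(
    0,                    # rank
    [32, 32, 1000],       # rampup_batch_size: start, increment, ramp-up samples
    512,                  # global_batch_size
    MICRO_BATCH_SIZE,
    DATA_PARALLEL_SIZE,
)

consumed_samples = 0
for _ in range(NUM_ITERATIONS):
    update_num_microbatches(consumed_samples, consistency_check=False)
    local_minibatch_size = get_num_microbatches() * MICRO_BATCH_SIZE
    # ... fetch `local_minibatch_size` samples and run one pipelined forward/backward ...
    consumed_samples += DATA_PARALLEL_SIZE * get_num_microbatches() * MICRO_BATCH_SIZE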
...@@ -34,14 +34,3 @@ def gather_split_1d_tensor(tensor):
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
return gathered
# TODO(mkozuki): Rewrite this using `logging`.
def rank_print(msg):
"""Print the given msg with rank information"""
print(
f"tensor rank: {parallel_state.get_tensor_model_parallel_rank()}"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}, "
f"data rank: {parallel_state.get_data_parallel_rank()} | {msg}"
)
from typing import Tuple, List
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import (
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.log_util import get_transformer_logger, set_logging_level
from apex.transformer.testing.commons import model_provider_func
from apex.transformer._data import MegatronPretrainingRandomSampler
from apex.transformer._data import MegatronPretrainingSampler
# note(mkozuki): To see warmup, steady, cooldown iterations, uncomment the line below
# set_logging_level("INFO")
_logger = get_transformer_logger("pipeline_parallel_test")
# note(mkozuki): To see if local batch size increases, uncomment the line below
# _logger.setLevel("INFO")
global_vars.set_global_variables(
args_defaults={"global_batch_size": 512, "rampup_batch_size": [32, 32, 1000],},
ignore_unknown_args=True,
)
RAMPUP_BATCH_SIZE = []
NUM_ITERATIONS = 20
NUM_SAMPLES = 16384 // 2
batch_size, micro_batch_size = None, None
HIDDEN_SIZE = 16
def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
return [(torch.randn(HIDDEN_SIZE), torch.randn(HIDDEN_SIZE // 2)) for _ in range(num_samples)]
def process_batch(batch):
if isinstance(batch, (list, tuple)):
x = batch[0]
else:
x = batch
return x
def fwd_step_func(micro_batch, model):
x = process_batch(micro_batch)
y = model(x)
# note (mkozuki): I don't think this function is nice but I do think this is enough for now
# just to check the sanity of ported pipeline functions.
def loss_func(x):
loss = torch.sum(x)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"avg": averaged_loss}
return y, loss_func
# Run forward & backward with dynamic batch size.
def run_interleaved_with_dynamic_batch_size(
pipeline_model_parallel_size: int, forward_only: bool, BatchSamplerCls,
) -> None:
args = global_vars.get_args()
_reconfigure_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
1, # args.data_parallel_size,
)
virtual_pipeline_model_parallel_size = 2
# NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is required for the interleaving schedule.
# In Megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
# used ubiquitously, but this test uses a custom model, so it is safe to set the value directly here.
parallel_state.initialize_model_parallel(
1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size
)
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
print_separator(f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}")
model = build_model(
model_provider_func,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
hidden_size=HIDDEN_SIZE,
)
assert isinstance(model, list)
assert len(model) == virtual_pipeline_model_parallel_size
optimizer = torch.optim.Adam(_get_params_for_weight_decay_optimization(model))
initial_local_minibatch_size = get_num_microbatches() * micro_batch_size
dataset = Dataset(NUM_SAMPLES)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_sampler=BatchSamplerCls(
NUM_SAMPLES,
0,
initial_local_minibatch_size,
parallel_state.get_data_parallel_rank(),
parallel_state.get_data_parallel_world_size(),
),
)
data_iter = iter(data_loader)
def get_num_samples(batch):
if isinstance(batch, torch.Tensor):
return len(batch)
assert isinstance(batch, (list, tuple))
return [get_num_samples(b) for b in batch]
tensor_shape = [micro_batch_size, HIDDEN_SIZE]
consumed_samples = 0
for i in range(NUM_ITERATIONS):
update_num_microbatches(consumed_samples, consistency_check=False)
local_batch_size = get_num_microbatches() * micro_batch_size
data_iter._index_sampler.local_minibatch_size = local_batch_size
local_mini_batch = next(data_iter)
_logger.info(
f"iter: {i} / {NUM_ITERATIONS} "
f"local batchsize: {get_num_samples(local_mini_batch)} "
f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}"
)
_forward_backward_pipelining_with_interleaving(
fwd_step_func,
local_mini_batch,
model,
forward_only=forward_only,
tensor_shape=tensor_shape,
)
consumed_samples += (
parallel_state.get_data_parallel_world_size()
* get_num_microbatches()
* micro_batch_size
)
if not forward_only:
for m in model:
for p in m.parameters():
if p.grad is None:
raise RuntimeError("grad not found")
else:
optimizer.zero_grad(set_to_none=True)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
n_tests = 0
failures = []
initialize_distributed()
world_size = torch.distributed.get_world_size()
args = global_vars.get_args()
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
1, # args.data_parallel_size,
)
for BatchSamplerCls in (MegatronPretrainingSampler, MegatronPretrainingRandomSampler):
for forward_only in (False, True):
n_tests += 1
pipeline_model_parallel_size = world_size
try:
run_interleaved_with_dynamic_batch_size(
pipeline_model_parallel_size, forward_only, BatchSamplerCls,
)
except Exception as e:
msg = (
f"\tforward_only: {forward_only}\n"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
f"{str(e)}"
)
raise RuntimeError(msg)
finally:
parallel_state.destroy_model_parallel()
print_separator("TEST RESULT")
if failures:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("\n".join(failures))
msg = f"{len(failures)} / {n_tests} cases failed"
raise RuntimeError(msg)
else:
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print("### PASS!")
from functools import partial
import logging
from typing import List
import torch
...@@ -18,9 +19,11 @@ from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.log_util import get_transformer_logger, set_logging_level
set_logging_level(logging.NOTSET)
_logger = get_transformer_logger("megatron_gpt_pipeline_test")
global_vars.set_global_variables()
N_VOCAB = 8192
...@@ -77,14 +80,14 @@ def run_gpt(pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=N
model = build_model(
gpt_model_provider, True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size)
_logger.debug("building model")
assert isinstance(model, list)
assert len(model) == (1 or virtual_pipeline_model_parallel_size)
_param_groups = _get_params_for_weight_decay_optimization(model)
torch.optim.Adam(_param_groups)
if parallel_state.is_pipeline_last_stage():
_logger.debug("checking `word_embeddings` existence")
for m in model:
assert hasattr(m, "word_embeddings")
...@@ -93,19 +96,19 @@ def run_gpt(pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=N
batch = generate_batch(args.global_batch_size, args.seq_length)
else:
batch = [generate_batch(args.global_batch_size, args.seq_length) for _ in range(virtual_pipeline_model_parallel_size)]
_logger.debug("preparing batch")
if virtual_pipeline_model_parallel_size is None:
fwd_bwd_func = forward_backward_pipelining_without_interleaving
else:
fwd_bwd_func = _forward_backward_pipelining_with_interleaving
_logger.debug(f"selecting forward_backward func: {fwd_bwd_func}")
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
_logger.debug(f"`tensor_shape`: {tensor_shape}")
fwd_bwd_func(forward_step, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
_logger.debug(TEST_SUCCESS_MESSAGE)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -126,7 +129,7 @@ if __name__ == "__main__": ...@@ -126,7 +129,7 @@ if __name__ == "__main__":
# TODO(mkozuki): handle exception correctly, but for now, lazily commenting out as # TODO(mkozuki): handle exception correctly, but for now, lazily commenting out as
# this won't get kicked by CI # this won't get kicked by CI
except Exception as e: except Exception as e:
# rank_print(str(e)) _logger.debug(str(e))
pass pass
finally: finally:
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
...@@ -3,6 +3,7 @@ from typing import Optional, Union, List
import torch
import torch.nn as nn
import apex
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
...@@ -17,9 +18,11 @@ from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.log_util import get_transformer_logger, set_logging_level
# set_logging_level("INFO")
_logger = get_transformer_logger("pipeline_parallel_test")
global_vars.set_global_variables()
...@@ -133,7 +136,6 @@ def forward_backward_func_template(
fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
if not forward_only:
for m in model:
for p in m.parameters():
if p.grad is None:
......
...@@ -10,6 +10,7 @@ DENY_TEST = [
]
MULTIGPU_TEST = [
"pipeline_parallel_test",
"dynamic_batchsize_test",
]
......