Unverified Commit a960fe8c authored by Masaki Kozuki, committed by GitHub

allow for `None` batch (#1280)

* have get_kth_microbatch deal with None batch

* broadcast based on tensor parallel rank

* dtype

* remove unnecessary .cuda()

Processes with tensor parallel rank != 0 don't need to prepare any `torch.utils.data.DataLoader` instances, which means the `batch` argument of the `get_kth_microbatch` function can be `None`; the current implementation, however, doesn't allow for that.
parent 2a4ab177
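In practice, only tensor-parallel rank 0 has to construct data loaders after this change. The sketch below is not part of this commit; it assumes `parallel_state` has already been initialized the usual way and only illustrates the calling pattern the relaxed annotations allow.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from apex.transformer import parallel_state

if parallel_state.get_tensor_model_parallel_rank() == 0:
    # Only tensor-parallel rank 0 needs to own a DataLoader and pull local minibatches.
    loader = DataLoader(TensorDataset(torch.arange(64).view(8, 8)), batch_size=4)
    batch = next(iter(loader))
else:
    # Every other tensor-parallel rank can now hand the schedules (and hence
    # get_kth_microbatch) a plain None instead of a dummy minibatch.
    batch = None
```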
@@ -18,7 +18,7 @@ _logger = get_transformer_logger(__name__)
 Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
 LossFunc = Callable[[torch.Tensor], torch.Tensor]
-FwdStepFunc = Callable[[Batch, torch.nn.Module], Tuple[torch.Tensor, LossFunc]]
+FwdStepFunc = Callable[[Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]]
 def build_model(
@@ -147,7 +147,7 @@ def _get_params_for_weight_decay_optimization(
 def forward_step(
     forward_step_func: FwdStepFunc,
-    batch: Batch,
+    batch: Optional[Batch],
     model: torch.nn.Module,
     input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
     losses_reduced: List[torch.Tensor],
......
@@ -23,7 +23,7 @@ _logger = get_transformer_logger(__name__)
 # TODO(mkozuki): Reduce cyclomatic complexity
 def _forward_backward_pipelining_with_interleaving(
     forward_step_func: FwdStepFunc,
-    batch: List[Batch],
+    batch: List[Optional[Batch]],
     model: List[torch.nn.Module],
     *,
     forward_only: bool,
......
@@ -153,7 +153,7 @@ def send_backward_recv_forward(
 def forward_backward_pipelining_without_interleaving(
     forward_step_func: FwdStepFunc,
-    batch: Batch,
+    batch: Optional[Batch],
     model: Union[torch.nn.Module, List[torch.nn.Module]],
     *,
     forward_only: bool,
@@ -230,7 +230,7 @@ def forward_backward_pipelining_without_interleaving(
         _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
         _logger.debug("receive fwd")
         input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
-        cur_microbatch = get_kth_microbatch(batch, i)
+        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
         output_tensor = forward_step(
             forward_step_func,
             cur_microbatch,
@@ -262,7 +262,7 @@ def forward_backward_pipelining_without_interleaving(
         _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
         last_iteration: bool = i == (num_microbatches_remaining - 1)
-        cur_microbatch: torch.Tensor = get_kth_microbatch(batch, i + num_warmup_microbatches)
+        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
         output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
             forward_step_func,
             cur_microbatch,
......
@@ -119,12 +119,14 @@ def _split_batch_into_microbatch(
 # TODO(mkozuki): Support non-tensor local minibatches?
-def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
+def get_kth_microbatch(batch: Optional[List[torch.Tensor]], k: int) -> List[torch.Tensor]:
     """Create a list of microbatches from a list of local minibatches.
     This function creates a list of `k`th microbatches from a list of local minibatches.
     `a local minibatch` consists of `global_batch_size / data_parallel_size` samples.
     """
+    if batch is None:
+        return batch
     micro_batch_size = get_micro_batch_size()
     start = k * micro_batch_size
     end = start + micro_batch_size
......
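For reference, the slicing that `get_kth_microbatch` performs on a non-`None` batch can be sketched in isolation as below. This is not the apex implementation: the micro-batch size is passed explicitly instead of being read from apex's global state, and the helper name is made up for illustration.

```python
from typing import List, Optional
import torch

def kth_microbatch_sketch(
    batch: Optional[List[torch.Tensor]], k: int, micro_batch_size: int
) -> Optional[List[torch.Tensor]]:
    # None fast path added by this commit: ranks that own no data get None back.
    if batch is None:
        return None
    start = k * micro_batch_size
    end = start + micro_batch_size
    # The k-th microbatch is the same sample slice taken from every tensor
    # in the local minibatch.
    return [x[start:end].contiguous() for x in batch]

minibatch = [torch.arange(8).view(8, 1), torch.ones(8)]
print(kth_microbatch_sketch(minibatch, k=1, micro_batch_size=2))  # samples 2 and 3
print(kth_microbatch_sketch(None, k=1, micro_batch_size=2))       # None
```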
@@ -10,7 +10,7 @@ from apex.transformer.pipeline_parallel.schedules import get_forward_backward_fu
 from apex.transformer.pipeline_parallel.schedules.common import build_model
 from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
-from apex.transformer.testing.standalone_bert import bert_model_provider
+from apex.transformer.testing.standalone_bert import bert_model_provider
 from apex.transformer.testing import global_vars
 from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
 from apex.transformer.testing.commons import initialize_distributed
@@ -49,7 +49,7 @@ def generate_fancy_data_labels(sequence_len, batch_size):
     global inds
     global masks
     global MANUAL_SEED
-    temps = list()
+    temps = []
     for i in range(batch_size):
         if inds is None or data_idx >= len(inds):
             # hack as use of RNG will fall out of sync due to pipelines being different
@@ -67,23 +67,27 @@ def generate_fancy_data_labels(sequence_len, batch_size):
             data_idx_ = data_idx
         offset = inds[data_idx_] #* SEQUENCE_LEN
         data_idx += 1
         curr = fancy_data[offset:offset+sequence_len].clone().detach()
         temps.append(curr)
     temp = torch.stack(temps, dim=0).cuda()
     mask = masks[data_idx//batch_size]
-    mask_not = torch.logical_not(mask)
+    mask_not = torch.logical_not(mask).long()
     data = mask * temp + mask_not*124
     label = temp
-    return (data, label, mask_not)
+    if parallel_state.get_tensor_model_parallel_rank() == 0:
+        data_dict = {"text": data, "label": label, "mask_not": mask_not}
+    else:
+        data_dict = None
+    keys = ["text", "label", "mask_not"]
+    dtype = torch.int64
+    broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
+    return (broadcasted_data["text"].long(), broadcasted_data["label"].long(), broadcasted_data["mask_not"])
 easy_data = None
 def fwd_step_func(batch, model):
     data, label, loss_mask = batch
-    data = data.cuda()
-    label = label.cuda()
-    loss_mask = loss_mask.cuda()
     y = model(data, torch.ones_like(data), lm_labels=label)
     def loss_func(output_tensor):
......
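The test change above boils down to the standard Megatron-style broadcast pattern: tensor-parallel rank 0 builds the data dictionary, every other rank passes `None`, and `tensor_parallel.broadcast_data` hands identical tensors back to all ranks in the group. Below is a standalone sketch of that pattern with stand-in random tensors; it assumes the model/tensor parallel groups are already initialized and CUDA is available, and the function name is made up for illustration.

```python
import torch
from apex.transformer import parallel_state, tensor_parallel

def build_and_broadcast_batch(sequence_len: int = 16, batch_size: int = 4):
    keys = ["text", "label", "mask_not"]
    if parallel_state.get_tensor_model_parallel_rank() == 0:
        # Rank 0 materializes the actual data (random stand-ins here).
        data_dict = {
            "text": torch.randint(0, 1000, (batch_size, sequence_len)),
            "label": torch.randint(0, 1000, (batch_size, sequence_len)),
            "mask_not": torch.randint(0, 2, (batch_size, sequence_len)),
        }
    else:
        # Other tensor-parallel ranks contribute nothing...
        data_dict = None
    # ...and receive the same int64 tensors via a broadcast within the
    # tensor-parallel group.
    broadcasted = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
    return broadcasted["text"], broadcasted["label"], broadcasted["mask_not"]
```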