"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "cee3aa0dd40eaab8e84ab947a5c896efc150428b"
Unverified Commit a960fe8c authored by Masaki Kozuki, committed by GitHub

allow for `None` batch (#1280)

* have get_kth_microbatch deal with None batch

* broadcast based on tensor parallel rank

* dtype

* remove unnecessary .cuda()

Processes with tensor parallel rank != 0 don't need to prepare `torch.utils.data.DataLoader` instances, which means the `batch` argument of the `get_kth_microbatch` function can be `None`; the current implementation, however, doesn't allow for it.
parent 2a4ab177
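
In practice, the change enables a calling pattern like the following minimal sketch (illustrative only: `make_batch` and `fwd_step` are hypothetical names, not part of this commit, and the snippet assumes torch.distributed and apex's parallel state are already initialized):

import torch

from apex.transformer import parallel_state, tensor_parallel

def make_batch(dataloader_iter):
    # Only tensor-parallel rank 0 owns a DataLoader; every other rank
    # hands `None` to the pipeline schedule, which this commit makes legal.
    if parallel_state.get_tensor_model_parallel_rank() == 0:
        return next(dataloader_iter)  # e.g. (tokens, labels), both torch.long
    return None

def fwd_step(batch, model):
    # `batch` is None on tensor-parallel ranks != 0; broadcast_data fans
    # rank 0's tensors out to the whole tensor-parallel group.
    keys = ["tokens", "labels"]
    data_dict = None if batch is None else {"tokens": batch[0], "labels": batch[1]}
    data = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
    output = model(data["tokens"])

    def loss_func(output_tensor):
        return torch.nn.functional.cross_entropy(
            output_tensor.float(), data["labels"].view(-1)
        )

    return output, loss_func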
@@ -18,7 +18,7 @@ _logger = get_transformer_logger(__name__)
 Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
 LossFunc = Callable[[torch.Tensor], torch.Tensor]
-FwdStepFunc = Callable[[Batch, torch.nn.Module], Tuple[torch.Tensor, LossFunc]]
+FwdStepFunc = Callable[[Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]]


 def build_model(

@@ -147,7 +147,7 @@ def _get_params_for_weight_decay_optimization(
 def forward_step(
     forward_step_func: FwdStepFunc,
-    batch: Batch,
+    batch: Optional[Batch],
     model: torch.nn.Module,
     input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
     losses_reduced: List[torch.Tensor],
@@ -23,7 +23,7 @@ _logger = get_transformer_logger(__name__)
 # TODO(mkozuki): Reduce cyclomatic complexity
 def _forward_backward_pipelining_with_interleaving(
     forward_step_func: FwdStepFunc,
-    batch: List[Batch],
+    batch: List[Optional[Batch]],
     model: List[torch.nn.Module],
     *,
     forward_only: bool,
@@ -153,7 +153,7 @@ def send_backward_recv_forward(
 def forward_backward_pipelining_without_interleaving(
     forward_step_func: FwdStepFunc,
-    batch: Batch,
+    batch: Optional[Batch],
     model: Union[torch.nn.Module, List[torch.nn.Module]],
     *,
     forward_only: bool,

@@ -230,7 +230,7 @@ def forward_backward_pipelining_without_interleaving(
         _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
         _logger.debug("receive fwd")
         input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
-        cur_microbatch = get_kth_microbatch(batch, i)
+        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
         output_tensor = forward_step(
             forward_step_func,
             cur_microbatch,

@@ -262,7 +262,7 @@ def forward_backward_pipelining_without_interleaving(
         _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
         last_iteration: bool = i == (num_microbatches_remaining - 1)
-        cur_microbatch: torch.Tensor = get_kth_microbatch(batch, i + num_warmup_microbatches)
+        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
         output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
             forward_step_func,
             cur_microbatch,
@@ -119,12 +119,14 @@ def _split_batch_into_microbatch(
 # TODO(mkozuki): Support non-tensor local minibatches?
-def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
+def get_kth_microbatch(batch: Optional[List[torch.Tensor]], k: int) -> List[torch.Tensor]:
     """Create a list of microbatches from a list of local minibatches.

     This function creates a list of `k`th microbatches from a list of local minibatches.
     `a local minibatch` consists of `global_batch_size / data_parallel_size` samples.
     """
+    if batch is None:
+        return batch
     micro_batch_size = get_micro_batch_size()
     start = k * micro_batch_size
     end = start + micro_batch_size
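For reference, the slicing `get_kth_microbatch` performs, as a standalone sketch (`micro_batch_size` is an explicit parameter here for illustration; the real helper reads it from `get_micro_batch_size()`):

import torch

def kth_microbatch(batch, k, micro_batch_size):
    # Mirrors get_kth_microbatch: take samples [k * mbs, (k + 1) * mbs)
    # from every tensor in the local minibatch.
    if batch is None:  # the early return this commit adds
        return None
    start = k * micro_batch_size
    return [x[start:start + micro_batch_size] for x in batch]

minibatch = [torch.arange(8)]           # a local minibatch of 8 samples
print(kth_microbatch(minibatch, 1, 2))  # [tensor([2, 3])]
print(kth_microbatch(None, 1, 2))       # None, instead of raising TypeError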
@@ -10,7 +10,7 @@ from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
 from apex.transformer.pipeline_parallel.schedules.common import build_model
 from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
 from apex.transformer.testing.standalone_bert import bert_model_provider
 from apex.transformer.testing import global_vars
 from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
 from apex.transformer.testing.commons import initialize_distributed
@@ -49,7 +49,7 @@ def generate_fancy_data_labels(sequence_len, batch_size):
     global inds
     global masks
     global MANUAL_SEED
-    temps = list()
+    temps = []
     for i in range(batch_size):
         if inds is None or data_idx >= len(inds):
             # hack as use of RNG will fall out of sync due to pipelines being different

@@ -67,23 +67,27 @@ def generate_fancy_data_labels(sequence_len, batch_size):
             data_idx_ = data_idx
             offset = inds[data_idx_]  #* SEQUENCE_LEN
             data_idx += 1
         curr = fancy_data[offset:offset+sequence_len].clone().detach()
         temps.append(curr)
     temp = torch.stack(temps, dim=0).cuda()
     mask = masks[data_idx//batch_size]
-    mask_not = torch.logical_not(mask)
+    mask_not = torch.logical_not(mask).long()
     data = mask * temp + mask_not*124
     label = temp
-    return (data, label, mask_not)
+    if parallel_state.get_tensor_model_parallel_rank() == 0:
+        data_dict = {"text": data, "label": label, "mask_not": mask_not}
+    else:
+        data_dict = None
+    keys = ["text", "label", "mask_not"]
+    dtype = torch.int64
+    broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
+    return (broadcasted_data["text"].long(), broadcasted_data["label"].long(), broadcasted_data["mask_not"])


 easy_data = None


 def fwd_step_func(batch, model):
     data, label, loss_mask = batch
-    data = data.cuda()
-    label = label.cuda()
-    loss_mask = loss_mask.cuda()
     y = model(data, torch.ones_like(data), lm_labels=label)

     def loss_func(output_tensor):
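
The broadcast contract the updated test relies on, sketched under the assumption that torch.distributed and apex's model-parallel state are already initialized: tensor-parallel rank 0 supplies the dict, every other rank passes `None`, and all ranks in the group receive identical tensors.

import torch
from apex.transformer import parallel_state, tensor_parallel

keys = ["text", "label", "mask_not"]
if parallel_state.get_tensor_model_parallel_rank() == 0:
    # Rank 0 owns the data; shapes and values are broadcast from here.
    data_dict = {k: torch.randint(0, 100, (4, 16), dtype=torch.long) for k in keys}
else:
    data_dict = None  # non-zero ranks contribute nothing

broadcasted = tensor_parallel.broadcast_data(keys, data_dict, torch.long)
assert sorted(broadcasted.keys()) == sorted(keys)  # same tensors on every rank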