Commit f17a3933 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

Merge branch 'main' into lmcafee/embed-standalone

parents 804ed2e6 fd5469aa
...@@ -142,10 +142,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, ...@@ -142,10 +142,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
if recv_prev: if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor( tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_() tensor_recv_prev).view(tensor_shape).requires_grad_()
tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
requires_grad = True,
keep_graph = False)
if recv_next: if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor( tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_() tensor_recv_next).view(tensor_shape).requires_grad_()
tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
requires_grad = True,
keep_graph = False)
return tensor_recv_prev, tensor_recv_next return tensor_recv_prev, tensor_recv_next
......
...@@ -28,6 +28,7 @@ from megatron.model import DistributedDataParallel as LocalDDP ...@@ -28,6 +28,7 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module from megatron.model import Float16Module
from megatron.model import ModelType from megatron.model import ModelType
def get_forward_backward_func(): def get_forward_backward_func():
args = get_args() args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1: if mpu.get_pipeline_model_parallel_world_size() > 1:
...@@ -66,24 +67,29 @@ def get_forward_backward_func(): ...@@ -66,24 +67,29 @@ def get_forward_backward_func():
forward_backward_func = forward_backward_no_pipelining forward_backward_func = forward_backward_no_pipelining
return forward_backward_func return forward_backward_func
def free_output_tensor(output_tensors, deallocate_pipeline_outputs): def deallocate_output_tensor(out):
'''Pseudo-free (i.e., set to scalar) the output tensor's '.data' field. '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'. only useful for its '.grad_fn' field, and not its '.data'.
''' '''
if not deallocate_pipeline_outputs or output_tensors is None: if out is None:
return return
if isinstance(output_tensors, torch.Tensor): assert isinstance(out, torch.Tensor), \
output_tensors = [output_tensors] "expected Tensor, found %s." % type(out).__name__
for output_tensor in output_tensors: assert out._base is None, \
output_tensor.data = torch.cuda.FloatTensor([0]) "counter-productive to free a view of another tensor."
out.data = torch.empty(
(1,),
device = out.device,
dtype = out.dtype,
)
def custom_backward(output, grad_output): def custom_backward(output, grad_output):
'''Directly call C++ autograd engine. '''Directly call C++ autograd engine.
To make the 'free_output_tensor' (above) optimization work, the C++ To make the 'deallocate_output_tensor' (above) optimization work, the C++
autograd engine must be called directly, bypassing Pytorch's autograd engine must be called directly, bypassing Pytorch's
torch.autograd.backward. Pytorch's 'backward' checks that the output and torch.autograd.backward. Pytorch's 'backward' checks that the output and
grad have the same shape, while C++'s 'backward' does not. grad have the same shape, while C++'s 'backward' does not.
...@@ -114,6 +120,7 @@ def custom_backward(output, grad_output): ...@@ -114,6 +120,7 @@ def custom_backward(output, grad_output):
allow_unreachable=True, allow_unreachable=True,
accumulate_grad=True, accumulate_grad=True,
) )
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced): def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
"""Forward step for passed-in model. """Forward step for passed-in model.
...@@ -188,11 +195,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad): ...@@ -188,11 +195,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# Backward pass. # Backward pass.
if output_tensor_grad[0] is None: if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor[0]) output_tensor = optimizer.scale_loss(output_tensor[0])
if args.deallocate_pipeline_outputs: custom_backward(output_tensor[0], output_tensor_grad[0])
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
torch.autograd.backward(output_tensor[0],
grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor. # Collect the grad of the input_tensor.
input_tensor_grad = [None] input_tensor_grad = [None]
...@@ -400,8 +403,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -400,8 +403,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor, recv_prev=recv_prev, output_tensor, recv_prev=recv_prev,
tensor_shape=tensor_shape, tensor_shape=tensor_shape,
timers=timers) timers=timers)
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
input_tensors[next_forward_model_chunk_id].append(input_tensor) input_tensors[next_forward_model_chunk_id].append(input_tensor)
deallocate_output_tensor(output_tensor)
# Run 1F1B in steady state. # Run 1F1B in steady state.
for k in range(num_microbatches_remaining): for k in range(num_microbatches_remaining):
...@@ -465,7 +468,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -465,7 +468,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor, input_tensor_grad, output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next, recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape, timers=timers) tensor_shape=tensor_shape, timers=timers)
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs) deallocate_output_tensor(output_tensor)
# Put input_tensor and output_tensor_grad in data structures in the # Put input_tensor and output_tensor_grad in data structures in the
# right location. # right location.
...@@ -641,7 +644,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -641,7 +644,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if not forward_only: if not forward_only:
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs) deallocate_output_tensor(output_tensor[0])
# Before running 1F1B, need to receive first forward tensor. # Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to # If all microbatches are run in warmup / cooldown phase, then no need to
...@@ -670,7 +673,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -670,7 +673,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# Add input_tensor and output_tensor to end of list. # Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs) deallocate_output_tensor(output_tensor[0])
# Pop input_tensor and output_tensor from the start of the list for # Pop input_tensor and output_tensor from the start of the list for
# the backward pass. # the backward pass.
......
...@@ -308,7 +308,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap ...@@ -308,7 +308,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
args.accumulate_allreduce_grads_in_fp32, args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_local_ddp) args.use_contiguous_buffers_in_local_ddp)
for model_module in model] for model_module in model]
# broad cast params from data parallel src rank to other data parallel ranks
if args.data_parallel_random_init:
for model_module in model:
model_module.broadcast_params()
else: else:
raise NotImplementedError('Unknown DDP implementation specified: ' raise NotImplementedError('Unknown DDP implementation specified: '
'{}. Exiting.'.format(args.DDP_impl)) '{}. Exiting.'.format(args.DDP_impl))
......
...@@ -21,7 +21,7 @@ from functools import partial ...@@ -21,7 +21,7 @@ from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0 from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType from megatron.model import ModelType
from megatron.model.vit_model import VitModel from megatron.model.vision.classification import VitClassificationModel
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
...@@ -31,9 +31,9 @@ def model_provider(pre_process=True, post_process=True): ...@@ -31,9 +31,9 @@ def model_provider(pre_process=True, post_process=True):
print_rank_0("building VIT model ...") print_rank_0("building VIT model ...")
args = get_args() args = get_args()
model = VitModel(num_classes=args.num_classes, model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process, pre_process=pre_process,
post_process=post_process) post_process=post_process)
return model return model
def get_batch(data_iterator): def get_batch(data_iterator):
...@@ -82,7 +82,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -82,7 +82,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
print_rank_0( print_rank_0(
"> building train, validation, and test datasets " "for VIT ..." "> building train, validation, and test datasets " "for VIT ..."
) )
train_ds, valid_ds = build_train_valid_datasets(data_path=args.data_path) train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...") print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None return train_ds, valid_ds, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment