"src/graph/sampling/vscode:/vscode.git/clone" did not exist on "bcd37684268a919f25aa5b9eb88f4e59aca1e7b4"
Commit aa9cae27 authored by Deepak Narayanan

Small notes in comments in response to Jared's comments

parent dd079406
@@ -205,6 +205,10 @@ class DynamicLossScaler:
         return grad_in
 
     def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
+        # If output_tensor_grad is None, this is the last stage, and
+        # output_tensor is actually the loss and needs to be scaled.
+        # Otherwise, output_tensor does not need to be scaled again since
+        # output_tensor_grad is already scaled.
         if output_tensor_grad is None:
             scaled_output_tensor = output_tensor * self.loss_scale
         else:
...
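The added comment documents the scaling convention used across pipeline stages: only the last stage holds the actual loss and scales it; earlier stages receive a gradient that was already scaled downstream. A minimal standalone sketch of that logic follows, for illustration only; the function name scaled_backward and its signature are assumptions, not the repository's API.

import torch

def scaled_backward(output_tensor, loss_scale, output_tensor_grad=None,
                    retain_graph=False):
    if output_tensor_grad is None:
        # Last pipeline stage: output_tensor is the loss itself, so scale it
        # before backpropagating to keep fp16 gradients representable.
        scaled_output_tensor = output_tensor * loss_scale
        torch.autograd.backward(scaled_output_tensor, retain_graph=retain_graph)
    else:
        # Earlier stages: the gradient arriving from the next stage was
        # already scaled at the last stage, so no further scaling is applied.
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad,
                                retain_graph=retain_graph)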
@@ -361,14 +361,17 @@ def train_step(forward_step_func, data_iterator,
     # Compute number of microbatches in a minibatch.
     num_microbatches_in_minibatch = args.num_microbatches_in_minibatch
-    # TODO: Switch to the following schedule when async communication is supported
-    # so that we can facilitate mroe memory-efficient training.
+    # TODO: Switch to the following schedule to facilitate more
+    # memory-efficient training.
     # num_warmup_microbatches = \
     #     (torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) -
     #      torch.distributed.get_rank(group=mpu.get_pipeline_model_parallel_group()) - 1)
     # num_warmup_microbatches = min(
     #     num_warmup_microbatches,
     #     num_microbatches_in_minibatch)
+    # For now, perform training without warmup. Perform forward
+    # passes for all microbatches, then backward passes for all
+    # microbatches.
     num_warmup_microbatches = num_microbatches_in_minibatch
 
     input_tensors = []
...
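The new note and the commented-out code describe two ways of choosing the number of warmup microbatches. A sketch of both options is below, assuming hypothetical pipeline_world_size and pipeline_rank arguments in place of the mpu and torch.distributed calls; it is an illustration of the scheduling idea, not code from the commit.

def get_num_warmup_microbatches(num_microbatches_in_minibatch,
                                pipeline_world_size, pipeline_rank,
                                memory_efficient=False):
    if memory_efficient:
        # Schedule from the commented-out code: each stage runs only enough
        # warmup forward passes to fill the pipeline ahead of it before it
        # starts interleaving backward passes, which caps activation memory.
        num_warmup = pipeline_world_size - pipeline_rank - 1
        return min(num_warmup, num_microbatches_in_minibatch)
    # Current schedule: forward passes for all microbatches, then backward
    # passes for all microbatches, so every microbatch counts as warmup.
    return num_microbatches_in_minibatch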