Commit 3db2063b authored by Jared Casper

Merge branch 'enable_autocast' into 'main'

Add ability to enable autocast through schedules

See merge request ADLR/megatron-lm!569
parents 2699f93e 98bb40d9
@@ -90,6 +90,9 @@ def get_forward_backward_func():
         collect_non_loss_data: TODO
+        enable_autocast (optional, default=False): If True, runs the
+            forward_step_func call inside torch.autocast context
     """
     pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
     if pipeline_model_parallel_size > 1:
@@ -166,7 +169,8 @@ def forward_step(forward_step_func,
                  input_tensor,
                  forward_data_store,
                  timers,
-                 collect_non_loss_data=False):
+                 collect_non_loss_data=False,
+                 enable_autocast=False):
     """Forward step for passed-in model.
     If first stage, input tensor is obtained from data_iterator, otherwise
@@ -184,7 +188,7 @@ def forward_step(forward_step_func,
     set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
     set_input_tensor(input_tensor)
-    context_manager = torch.autocast("cuda") if torch.is_autocast_enabled() else nullcontext()
+    context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
     with context_manager:
         output_tensor, loss_func = forward_step_func(data_iterator, model)
@@ -296,7 +300,8 @@ def forward_backward_no_pipelining(*,
                                    sequence_parallel: bool = False,  # unused
                                    forward_only: bool = False,
                                    timers: Callable = None,
-                                   collect_non_loss_data: bool = False):
+                                   collect_non_loss_data: bool = False,
+                                   enable_autocast: bool = False):
     """Run forward and backward passes with no pipeline parallelism
     (no inter-stage communication).
@@ -320,7 +325,7 @@ def forward_backward_no_pipelining(*,
         for i in range(num_microbatches - 1):
             output_tensor = forward_step(forward_step_func, data_iterator,
                                          model, num_microbatches, input_tensor, forward_data_store,
-                                         timers, collect_non_loss_data)
+                                         timers, collect_non_loss_data, enable_autocast)
             if not forward_only:
                 backward_step(grad_scaler, input_tensor, output_tensor,
                               output_tensor_grad, model_type, timers)
@@ -329,7 +334,7 @@ def forward_backward_no_pipelining(*,
     # synchronize gradients).
     output_tensor = forward_step(forward_step_func, data_iterator,
                                  model, num_microbatches, input_tensor, forward_data_store,
-                                 timers, collect_non_loss_data)
+                                 timers, collect_non_loss_data, enable_autocast)
     if not forward_only:
         backward_step(grad_scaler, input_tensor, output_tensor,
@@ -350,7 +355,8 @@ def forward_backward_pipelining_with_interleaving(*,
                                                   sequence_parallel: bool = False,
                                                   forward_only: bool = False,
                                                   timers: Callable = None,
-                                                  collect_non_loss_data: bool = False):
+                                                  collect_non_loss_data: bool = False,
+                                                  enable_autocast: bool = False):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.
@@ -440,7 +446,8 @@ def forward_backward_pipelining_with_interleaving(*,
                                      input_tensor,
                                      forward_data_store,
                                      timers,
-                                     collect_non_loss_data)
+                                     collect_non_loss_data,
+                                     enable_autocast)
        output_tensors[model_chunk_id].append(output_tensor)
        # if forward-only, no need to save tensors for a backward pass
@@ -731,7 +738,8 @@ def forward_backward_pipelining_without_interleaving(*,
                                                      sequence_parallel: bool = False,
                                                      forward_only: bool = False,
                                                      timers: Callable = None,
-                                                     collect_non_loss_data: bool = False):
+                                                     collect_non_loss_data: bool = False,
+                                                     enable_autocast: bool = False):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.
@@ -775,7 +783,9 @@ def forward_backward_pipelining_without_interleaving(*,
    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
-        output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches, input_tensor, forward_data_store, timers, collect_non_loss_data)
+        output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
+                                     input_tensor, forward_data_store,
+                                     timers, collect_non_loss_data, enable_autocast)
        send_forward(output_tensor, send_tensor_shapes, timers=timers)
        if not forward_only:
@@ -795,7 +805,7 @@ def forward_backward_pipelining_without_interleaving(*,
        output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
                                     input_tensor, forward_data_store,
-                                     timers, collect_non_loss_data)
+                                     timers, collect_non_loss_data, enable_autocast)
        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, timers=timers)
...
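
The functional core of the change sits in forward_step(): instead of checking torch.is_autocast_enabled(), the schedules now honor an explicit enable_autocast flag, which every schedule variant threads through to each forward_step() call. Below is a minimal standalone sketch of that pattern; the wrapper name is illustrative and not part of this MR:

    from contextlib import nullcontext

    import torch

    def forward_with_optional_autocast(forward_step_func, data_iterator, model,
                                       enable_autocast=False):
        # Illustrative sketch of the pattern this MR adds to forward_step():
        # wrap only the user-supplied forward_step_func call in
        # torch.autocast("cuda") when the flag is set; otherwise run it
        # under a no-op context manager.
        context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
        with context_manager:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        return output_tensor, loss_func

Callers opt in by passing enable_autocast=True to whichever schedule function get_forward_backward_func() returns; with the default of False, the forward call runs under nullcontext().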