Commit 513d7d8e authored by Abhinav Khattar

add enable autocast


Signed-off-by: Abhinav Khattar <aklife97@gmail.com>
parent 2699f93e
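
The behavioral core of the change is visible in the forward_step hunk below: previously the forward call entered torch.autocast("cuda") only when the caller had already enabled autocast (checked via torch.is_autocast_enabled()); after this commit the new enable_autocast argument decides explicitly. A minimal sketch of that pattern, where run_forward and its arguments are illustrative names rather than the actual function touched by this commit:

    from contextlib import nullcontext
    import torch

    def run_forward(forward_step_func, data_iterator, model, enable_autocast=False):
        # Before this commit: autocast was entered only if the surrounding code had
        # already enabled it (torch.is_autocast_enabled()).
        # After this commit: the caller opts in explicitly via enable_autocast.
        context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
        with context_manager:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        return output_tensor, loss_func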
@@ -90,6 +90,9 @@ def get_forward_backward_func():
         collect_non_loss_data: TODO
+        enable_autocast (optional, default=False): If True, runs the
+            forward_step_func call inside torch.autocast context
     """
     pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
     if pipeline_model_parallel_size > 1:
@@ -166,7 +169,8 @@ def forward_step(forward_step_func,
                  input_tensor,
                  forward_data_store,
                  timers,
-                 collect_non_loss_data=False):
+                 collect_non_loss_data=False,
+                 enable_autocast=False):
     """Forward step for passed-in model.
     If first stage, input tensor is obtained from data_iterator, otherwise
@@ -184,7 +188,7 @@ def forward_step(forward_step_func,
     set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
     set_input_tensor(input_tensor)
-    context_manager = torch.autocast("cuda") if torch.is_autocast_enabled() else nullcontext()
+    context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
     with context_manager:
         output_tensor, loss_func = forward_step_func(data_iterator, model)
@@ -296,7 +300,8 @@ def forward_backward_no_pipelining(*,
                                    sequence_parallel: bool = False,  # unused
                                    forward_only: bool = False,
                                    timers: Callable = None,
-                                   collect_non_loss_data: bool = False):
+                                   collect_non_loss_data: bool = False,
+                                   enable_autocast: bool = False):
     """Run forward and backward passes with no pipeline parallelism
     (no inter-stage communication).
@@ -320,7 +325,7 @@ def forward_backward_no_pipelining(*,
         for i in range(num_microbatches - 1):
             output_tensor = forward_step(forward_step_func, data_iterator,
                                          model, num_microbatches, input_tensor, forward_data_store,
-                                         timers, collect_non_loss_data)
+                                         timers, collect_non_loss_data, enable_autocast)
             if not forward_only:
                 backward_step(grad_scaler, input_tensor, output_tensor,
                               output_tensor_grad, model_type, timers)
@@ -329,7 +334,7 @@ def forward_backward_no_pipelining(*,
     # synchronize gradients).
     output_tensor = forward_step(forward_step_func, data_iterator,
                                  model, num_microbatches, input_tensor, forward_data_store,
-                                 timers, collect_non_loss_data)
+                                 timers, collect_non_loss_data, enable_autocast)
     if not forward_only:
         backward_step(grad_scaler, input_tensor, output_tensor,
@@ -350,7 +355,8 @@ def forward_backward_pipelining_with_interleaving(*,
                                                   sequence_parallel: bool = False,
                                                   forward_only: bool = False,
                                                   timers: Callable = None,
-                                                  collect_non_loss_data: bool = False):
+                                                  collect_non_loss_data: bool = False,
+                                                  enable_autocast: bool = False):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.
@@ -440,7 +446,8 @@ def forward_backward_pipelining_with_interleaving(*,
                                      input_tensor,
                                      forward_data_store,
                                      timers,
-                                     collect_non_loss_data)
+                                     collect_non_loss_data,
+                                     enable_autocast)
         output_tensors[model_chunk_id].append(output_tensor)
         # if forward-only, no need to save tensors for a backward pass
@@ -731,7 +738,8 @@ def forward_backward_pipelining_without_interleaving(*,
                                                      sequence_parallel: bool = False,
                                                      forward_only: bool = False,
                                                      timers: Callable = None,
-                                                     collect_non_loss_data: bool = False):
+                                                     collect_non_loss_data: bool = False,
+                                                     enable_autocast: bool = False):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.
@@ -775,7 +783,9 @@ def forward_backward_pipelining_without_interleaving(*,
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
-        output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,input_tensor, forward_data_store,timers, collect_non_loss_data)
+        output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
+                                     input_tensor, forward_data_store,
+                                     timers, collect_non_loss_data, enable_autocast)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
         if not forward_only:
@@ -795,7 +805,7 @@ def forward_backward_pipelining_without_interleaving(*,
         output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
                                      input_tensor, forward_data_store,
-                                     timers, collect_non_loss_data)
+                                     timers, collect_non_loss_data, enable_autocast)
         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, timers=timers)
...
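
From the caller's side, the flag is threaded through every schedule (no-pipelining, interleaved, and non-interleaved 1F1B). A hedged usage sketch follows; the import path and the keyword arguments other than enable_autocast are assumptions based on this excerpt, not the schedule functions' complete signatures:

    # Hypothetical caller-side sketch; get_forward_backward_func is the function
    # documented in the first hunk above, the remaining names are placeholders
    # supplied by the training loop.
    from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

    def train_step(forward_step_func, data_iterator, model, num_microbatches):
        forward_backward_func = get_forward_backward_func()
        return forward_backward_func(
            forward_step_func=forward_step_func,   # (data_iterator, model) -> (output_tensor, loss_func)
            data_iterator=data_iterator,
            model=model,
            num_microbatches=num_microbatches,
            forward_only=False,
            collect_non_loss_data=False,
            enable_autocast=True,                  # new flag: run forward_step_func under torch.autocast("cuda")
        )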