Commit 513d7d8e authored by Abhinav Khattar

add enable autocast


Signed-off-by: Abhinav Khattar <aklife97@gmail.com>
parent 2699f93e
@@ -90,6 +90,9 @@ def get_forward_backward_func():
collect_non_loss_data: TODO
+ enable_autocast (optional, default=False): If True, runs the
+     forward_step_func call inside torch.autocast context
"""
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
if pipeline_model_parallel_size > 1:
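For context, the new flag is threaded through every schedule returned by `get_forward_backward_func()`. Below is a hedged caller sketch, not a definitive recipe: the import path and the full set of required keyword arguments (e.g. `dtype`, `grad_scaler`) vary across Megatron-LM versions, `parallel_state` is assumed to be initialized already, and the model/iterator stubs are illustrative only.

```python
# Hedged sketch: assumes megatron.core.parallel_state is already initialized;
# the import path and required kwargs may differ between Megatron-LM versions.
import torch
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

model = torch.nn.Linear(16, 16).cuda()
data_iterator = iter([torch.randn(4, 16, device="cuda") for _ in range(8)])

def loss_func(output_tensor):
    # Reduce the model output to a scalar loss plus a logging dict.
    loss = output_tensor.float().mean()
    return loss, {"loss": loss.detach()}

def forward_step_func(data_iterator, model):
    batch = next(data_iterator)
    # With enable_autocast=True, this body runs under torch.autocast("cuda").
    return model(batch), loss_func

forward_backward_func = get_forward_backward_func()
forward_data = forward_backward_func(
    forward_step_func=forward_step_func,
    data_iterator=data_iterator,
    model=model,
    num_microbatches=8,
    forward_only=True,     # skip backward_step in this sketch
    enable_autocast=True,  # the flag added by this commit
)
```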
@@ -166,7 +169,8 @@ def forward_step(forward_step_func,
input_tensor,
forward_data_store,
timers,
- collect_non_loss_data=False):
+ collect_non_loss_data=False,
+ enable_autocast=False):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
@@ -184,7 +188,7 @@ def forward_step(forward_step_func,
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
- context_manager = torch.autocast("cuda") if torch.is_autocast_enabled() else nullcontext()
+ context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
with context_manager:
output_tensor, loss_func = forward_step_func(data_iterator, model)
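The change above replaces the ambient `torch.is_autocast_enabled()` check with an explicit flag, so the scheduler itself decides whether to open an autocast region instead of merely mirroring one the caller already entered. The pattern can be exercised standalone; here is a minimal self-contained sketch (the `run_forward` helper and its names are illustrative, not part of this commit):

```python
from contextlib import nullcontext

import torch

def run_forward(model, batch, enable_autocast=False):
    # nullcontext() is a no-op context manager, so the non-autocast
    # path behaves exactly as it did before this change.
    context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
    with context_manager:
        return model(batch)

if torch.cuda.is_available():
    model = torch.nn.Linear(8, 8).cuda()
    batch = torch.randn(4, 8, device="cuda")
    out = run_forward(model, batch, enable_autocast=True)
    print(out.dtype)  # torch.float16 under CUDA autocast's default dtype
```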
@@ -296,7 +300,8 @@ def forward_backward_no_pipelining(*,
sequence_parallel: bool = False, # unused
forward_only: bool = False,
timers: Callable = None,
- collect_non_loss_data: bool = False):
+ collect_non_loss_data: bool = False,
+ enable_autocast: bool = False):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
@@ -320,7 +325,7 @@ def forward_backward_no_pipelining(*,
for i in range(num_microbatches - 1):
output_tensor = forward_step(forward_step_func, data_iterator,
model, num_microbatches, input_tensor, forward_data_store,
- timers, collect_non_loss_data)
+ timers, collect_non_loss_data, enable_autocast)
if not forward_only:
backward_step(grad_scaler, input_tensor, output_tensor,
output_tensor_grad, model_type, timers)
@@ -329,7 +334,7 @@ def forward_backward_no_pipelining(*,
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator,
model, num_microbatches, input_tensor, forward_data_store,
- timers, collect_non_loss_data)
+ timers, collect_non_loss_data, enable_autocast)
if not forward_only:
backward_step(grad_scaler, input_tensor, output_tensor,
@@ -350,7 +355,8 @@ def forward_backward_pipelining_with_interleaving(*,
sequence_parallel: bool = False,
forward_only: bool = False,
timers: Callable = None,
- collect_non_loss_data: bool = False):
+ collect_non_loss_data: bool = False,
+ enable_autocast: bool = False):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
@@ -440,7 +446,8 @@ def forward_backward_pipelining_with_interleaving(*,
input_tensor,
forward_data_store,
timers,
- collect_non_loss_data)
+ collect_non_loss_data,
+ enable_autocast)
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
@@ -731,7 +738,8 @@ def forward_backward_pipelining_without_interleaving(*,
sequence_parallel: bool = False,
forward_only: bool = False,
timers: Callable = None,
- collect_non_loss_data: bool = False):
+ collect_non_loss_data: bool = False,
+ enable_autocast: bool = False):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
@@ -775,7 +783,9 @@ def forward_backward_pipelining_without_interleaving(*,
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, dtype, timers=timers)
- output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,input_tensor, forward_data_store,timers, collect_non_loss_data)
+ output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
+                              input_tensor, forward_data_store,
+                              timers, collect_non_loss_data, enable_autocast)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only:
@@ -795,7 +805,7 @@ def forward_backward_pipelining_without_interleaving(*,
output_tensor = forward_step(forward_step_func, data_iterator, model, num_microbatches,
input_tensor, forward_data_store,
- timers, collect_non_loss_data)
+ timers, collect_non_loss_data, enable_autocast)
if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers)