Commit 98bb40d9 authored by Abhinav Khattar

address review comments


Signed-off-by: Abhinav Khattar <aklife97@gmail.com>
parent 630811b2
@@ -90,7 +90,7 @@ def get_forward_backward_func():
         collect_non_loss_data: TODO
-        enable_autocat (optional, default=False): If True, runs the
+        enable_autocast (optional, default=False): If True, runs the
             forward_step_func call inside torch.autocast context
     """
@@ -188,7 +188,7 @@ def forward_step(forward_step_func,
     set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
     set_input_tensor(input_tensor)
-    context_manager = torch.autocast("cuda") if (enable_autocast or torch.is_autocast_enabled()) else nullcontext()
+    context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
     with context_manager:
         output_tensor, loss_func = forward_step_func(data_iterator, model)
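For reference, the pattern in the second hunk can be shown in isolation: conditionally run a forward pass under torch.autocast("cuda"), falling back to contextlib.nullcontext() as a no-op when autocast is disabled. This is a minimal sketch only; the run_forward helper, toy model, and batch below are hypothetical stand-ins rather than Megatron-LM's actual forward_step_func, and a CUDA device is assumed.

# Minimal sketch of the conditional-autocast pattern used in the diff above.
# run_forward, the toy model, and the batch are hypothetical stand-ins.
from contextlib import nullcontext

import torch

def run_forward(forward_fn, batch, enable_autocast=False):
    # When enable_autocast is True, eligible ops inside the block run in
    # reduced precision; otherwise nullcontext() makes the `with` a no-op.
    context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
    with context_manager:
        return forward_fn(batch)

if __name__ == "__main__":
    model = torch.nn.Linear(16, 4).cuda()
    batch = torch.randn(8, 16, device="cuda")
    out = run_forward(lambda x: model(x), batch, enable_autocast=True)
    print(out.dtype)  # typically torch.float16 under CUDA autocast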