Commit 630811b2 authored by Abhinav Khattar's avatar Abhinav Khattar
Browse files

keep previous autocast check


Signed-off-by: default avatarAbhinav Khattar <aklife97@gmail.com>
parent 513d7d8e
...@@ -188,7 +188,7 @@ def forward_step(forward_step_func, ...@@ -188,7 +188,7 @@ def forward_step(forward_step_func,
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor") set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor) set_input_tensor(input_tensor)
context_manager = torch.autocast("cuda") if enable_autocast else nullcontext() context_manager = torch.autocast("cuda") if (enable_autocast or torch.is_autocast_enabled()) else nullcontext()
with context_manager: with context_manager:
output_tensor, loss_func = forward_step_func(data_iterator, model) output_tensor, loss_func = forward_step_func(data_iterator, model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment