Commit 0a159036 authored by hepj's avatar hepj
Browse files

Fix the in-place operation on the loss in the ViT model (replace `loss /= accum_iter` with an out-of-place division so autograd does not see an in-place modification)

parent c0f05c10
...@@ -53,7 +53,9 @@ def train_one_epoch(model: torch.nn.Module, ...@@ -53,7 +53,9 @@ def train_one_epoch(model: torch.nn.Module,
print("Loss is {}, stopping training".format(loss_value)) print("Loss is {}, stopping training".format(loss_value))
sys.exit(1) sys.exit(1)
loss /= accum_iter loss_new= loss/accum_iter
loss=loss_new
loss_scaler(loss, optimizer, parameters=model.parameters(), loss_scaler(loss, optimizer, parameters=model.parameters(),
update_grad=(data_iter_step + 1) % accum_iter == 0) update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0: if (data_iter_step + 1) % accum_iter == 0:
...@@ -79,4 +81,4 @@ def train_one_epoch(model: torch.nn.Module, ...@@ -79,4 +81,4 @@ def train_one_epoch(model: torch.nn.Module,
# gather the stats from all processes # gather the stats from all processes
metric_logger.synchronize_between_processes() metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger) print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()} return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment