Unverified Commit 533f2ef2 authored by ZHANG Zhi, committed by GitHub

Fix PyTorchImageClassificationTrainer's training (#3339)

* Fix PyTorchImageClassificationTrainer's training

The current training loop only computes the loss and its gradients; it never calls the optimizer, so the weights are never updated. As a result, the model is not actually trained, and its accuracy on the Web UI remains unchanged.

* Add intermediate reports for every epoch

Currently, the intermediate and final reports are identical and are displayed only once, after all epochs have finished. This is probably not what users expect: intermediate reports should provide the validation results after each epoch.
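
The first fix can be illustrated with a small, dependency-free sketch. It assumes a single scalar weight and a hand-written squared-error gradient; `Param`, `SGD`, `backward`, and `train` are hypothetical stand-ins written for this example, not PyTorch or NNI classes. They only mimic two PyTorch behaviours: gradients accumulate until they are zeroed, and weights change only when `step()` is called.

```python
# Illustrative stand-ins only; none of these names come from PyTorch or NNI.

class Param:
    """A scalar weight whose gradient accumulates across backward calls."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class SGD:
    """Minimal optimizer exposing zero_grad() and step()."""
    def __init__(self, param, lr):
        self.param = param
        self.lr = lr

    def zero_grad(self):
        self.param.grad = 0.0

    def step(self):
        self.param.value -= self.lr * self.param.grad


def backward(param, x, y):
    """Accumulate d/dw of the squared error (w*x - y)**2 into param.grad."""
    param.grad += 2.0 * (param.value * x - y) * x


def train(param, optimizer, data, use_optimizer):
    for x, y in data:
        if use_optimizer:
            optimizer.zero_grad()   # drop gradients left over from the last batch
        backward(param, x, y)       # compute gradients only
        if use_optimizer:
            optimizer.step()        # apply the update to the weight
    return param.value


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)] * 20  # samples of y = 2x

# Loop shaped like the code before the fix: gradients are computed, never applied.
w_old = Param(0.0)
old_result = train(w_old, SGD(w_old, lr=0.05), data, use_optimizer=False)

# Loop shaped like the code after the fix.
w_new = Param(0.0)
new_result = train(w_new, SGD(w_new, lr=0.05), data, use_optimizer=True)

print(old_result, new_result)
```

Without the optimizer calls the weight stays at its initial value while the accumulated gradient keeps growing, which matches the unchanged accuracy described in the commit message. With them, the weight approaches 2.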
parent 8175f280
@@ -145,12 +145,15 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
     def _train(self):
         for i, batch in enumerate(self._train_dataloader):
+            self._optimizer.zero_grad()
             loss = self.training_step(batch, i)
             loss.backward()
+            self._optimizer.step()
     def fit(self) -> None:
         for _ in range(self._trainer_kwargs['max_epochs']):
             self._train()
+            self._validate()
         # assuming val_acc here
         nni.report_final_result(self._validate()['val_acc'])
@@ -204,6 +207,7 @@ class PyTorchMultiModelTrainer(BaseTrainer):
         max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
         for _ in range(max_epochs):
             self._train()
+            self._validate()
         nni.report_final_result(self._validate())
     def _train(self):
...
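
The per-epoch reporting pattern that the `fit` changes aim for can be sketched as follows. This is a sketch under assumptions: the diff does not show how validation produces an intermediate report, so the reporting is made explicit here through two callbacks. `fit`, `train_one_epoch`, `validate`, `report_intermediate`, and `report_final` are hypothetical names for this example, not NNI's API. Unlike the diff, which validates once more for the final report, the sketch reuses the last epoch's metric.

```python
# Hypothetical stand-ins: these names are illustrative, not part of NNI.

def fit(max_epochs, train_one_epoch, validate, report_intermediate, report_final):
    """Train for max_epochs, reporting a validation metric after every epoch."""
    metric = None
    for epoch in range(max_epochs):
        train_one_epoch(epoch)
        metric = validate(epoch)
        report_intermediate(metric)   # one report per epoch
    report_final(metric)              # one report after the last epoch


intermediate, final = [], []
state = {"acc": 0.0}


def train_one_epoch(epoch):
    state["acc"] += 0.25              # pretend each epoch improves accuracy


def validate(epoch):
    return state["acc"]


fit(3, train_one_epoch, validate, intermediate.append, final.append)
print(intermediate, final)
```

Collecting the reports in plain lists keeps the example runnable without an experiment running. In a real trial the two callbacks would be replaced by whatever reporting functions the framework provides.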