Unverified commit 544a6b14 authored by wlf-darkmatter, committed by GitHub

Improve the ability to view the data in use during training (#1940)

* Improve the ability to view the data in use during training

Incorrectly labeled samples in a dataset can cause problems such as gradient explosion during training, and such labeling errors can be tracked down with a hook function.
**Problem**: When a for loop iterates over the PyTorch `DataLoader`, a hook has no way to see the batch currently being processed, so it cannot trace an error back to the offending samples.
**Solution**: Attach the current `data_batch` to the `runner` during the train process so that hook functions can access it.
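
For illustration, a minimal debugging hook (not part of this commit) could read the new `runner.data_batch` attribute; the `loss` key and the `img_metas` field below are assumptions about the model's `train_step` output and the batch layout:

```python
import torch
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class BadSampleLoggerHook(Hook):
    """Sketch: log the current batch's metadata when the loss becomes non-finite."""

    def after_train_iter(self, runner):
        # `runner.data_batch` is set just before `before_train_iter` and deleted
        # right after `after_train_iter`, so it is only valid inside the iter hooks.
        loss = runner.outputs.get('loss')  # assumes `train_step` returned a 'loss' entry
        if loss is not None and not torch.isfinite(loss):
            # 'img_metas' is an assumed key; use whatever your dataset puts in the batch
            # (file names, sample indices, ...).
            metas = runner.data_batch.get('img_metas')
            runner.logger.warning(
                f'Non-finite loss at iter {runner.iter}: suspicious batch {metas}')
```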

* Update epoch_based_runner.py

* Update iter_based_runner.py

* Restrict the scope of runner.data_batch

* Restrict the scope of runner.data_batch

* Restrict the scope of runner.data_batch
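
Such a hook (the name is illustrative, not part of mmcv or this commit) could then be attached to either runner through the existing `register_hook` API, for example:

```python
# Hypothetical usage: BadSampleLoggerHook is the sketch above, not a built-in mmcv hook.
runner.register_hook(BadSampleLoggerHook(), priority='VERY_LOW')
```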
parent 21bada32
@@ -45,10 +45,12 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('before_train_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
+            self.data_batch = data_batch
             self._inner_iter = i
             self.call_hook('before_train_iter')
             self.run_iter(data_batch, train_mode=True, **kwargs)
             self.call_hook('after_train_iter')
+            del self.data_batch
             self._iter += 1
 
         self.call_hook('after_train_epoch')
@@ -62,11 +64,12 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('before_val_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
+            self.data_batch = data_batch
             self._inner_iter = i
             self.call_hook('before_val_iter')
             self.run_iter(data_batch, train_mode=False)
             self.call_hook('after_val_iter')
+            del self.data_batch
         self.call_hook('after_val_epoch')
 
     def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
@@ -57,6 +57,7 @@ class IterBasedRunner(BaseRunner):
         self.data_loader = data_loader
         self._epoch = data_loader.epoch
         data_batch = next(data_loader)
+        self.data_batch = data_batch
         self.call_hook('before_train_iter')
         outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
         if not isinstance(outputs, dict):
@@ -65,6 +66,7 @@ class IterBasedRunner(BaseRunner):
             self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
         self.outputs = outputs
         self.call_hook('after_train_iter')
+        del self.data_batch
         self._inner_iter += 1
         self._iter += 1
@@ -74,6 +76,7 @@ class IterBasedRunner(BaseRunner):
         self.mode = 'val'
         self.data_loader = data_loader
         data_batch = next(data_loader)
+        self.data_batch = data_batch
         self.call_hook('before_val_iter')
         outputs = self.model.val_step(data_batch, **kwargs)
         if not isinstance(outputs, dict):
@@ -82,6 +85,7 @@ class IterBasedRunner(BaseRunner):
             self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
         self.outputs = outputs
         self.call_hook('after_val_iter')
+        del self.data_batch
         self._inner_iter += 1
 
     def run(self, data_loaders, workflow, max_iters=None, **kwargs):