Commit a9d3dc9f authored by Marta's avatar Marta
Browse files

dont wait until 2nd epoch to measure perf

parent f1a9d2b4
......@@ -43,12 +43,10 @@ class PerformanceLoggingCallback(Callback):
self.timestamps.append(time.time())
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
if trainer.current_epoch == 1:
self.do_step()
self.do_step()
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
if trainer.current_epoch == 1:
self.do_step()
self.do_step()
def process_performance_stats(self, deltas):
def _round3(val):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment