Commit 955dd5de authored by mibaumgartner's avatar mibaumgartner
Browse files

restrict lt version, update to latest lt version

parent ddb4d304
......@@ -104,11 +104,11 @@ class LightningBaseModule(pl.LightningModule):
"""
return torch.zeros(*self.example_input_array_shape)
def summarize(self, mode: Optional[str]) -> Optional[ModelSummary]:
def summarize(self, *args, **kwargs) -> Optional[ModelSummary]:
"""
Save model summary as txt
"""
summary = super().summarize(mode=mode)
summary = super().summarize(*args, **kwargs)
save_txt(summary, "./network")
return summary
......
......@@ -283,7 +283,7 @@ def _train(
weights_summary='full',
plugins=plugins,
terminate_on_nan=True, # TODO: make modular
move_metrics_to_cpu=True,
move_metrics_to_cpu=False,
**trainer_kwargs
)
trainer.fit(module, datamodule=datamodule)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment