"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "135567f18edc0bf02d515d5c76cc736d1ebddad3"
Unverified Commit d7f33a73 authored by Yuge Zhang, committed by GitHub

Fix typecheck for pytorch lightning v1.8 (#5207)

parent 7cb66c97
@@ -388,7 +388,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         return optim_conf

-    def setup(self, stage=None):
+    def setup(self, stage: str = cast(str, None)):  # add default value to be backward-compatible
         # redirect the access to trainer/log to this module
         # but note that we might be missing other attributes,
         # which could potentially be a problem
@@ -400,7 +400,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         return self.model.setup(stage)

-    def teardown(self, stage=None):
+    def teardown(self, stage: str = cast(str, None)):
         return self.model.teardown(stage)

     def configure_architecture_optimizers(self) -> list[optim.Optimizer] | optim.Optimizer | None:
@@ -492,7 +492,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
             self.model.lr_scheduler_step(cast(Any, scheduler), cast(int, opt_idx), None)
         except AttributeError:
             # lightning < 1.6
-            for lr_scheduler in self.trainer.lr_schedulers:
+            for lr_scheduler in self.trainer.lr_schedulers:  # type: ignore
                 if lr_scheduler['reduce_on_plateau']:
                     warnings.warn('Reduce-lr-on-plateau is not supported in NAS. It will be ignored.', UserWarning)
                 if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency']:
...
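The try/except plus "# type: ignore" above follows a common version-fallback pattern; a hedged sketch (the newer attribute name is assumed from pytorch-lightning's 1.6 API change, not taken from this diff):

    def get_schedulers(trainer):
        # Prefer the newer attribute; fall back for lightning < 1.6.
        # The legacy attribute is absent from current type stubs, so the
        # typechecker flags it even though the access is valid at runtime;
        # "# type: ignore" silences exactly that line.
        try:
            return list(trainer.lr_scheduler_configs)  # PL >= 1.6
        except AttributeError:
            return list(trainer.lr_schedulers)  # type: ignore  # PL < 1.6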
@@ -275,6 +275,8 @@ class GumbelDartsLightningModule(DartsLightningModule):
     def on_train_epoch_start(self):
         if self.use_temp_anneal:
+            if self.trainer.max_epochs is None:
+                raise ValueError('Please set max_epochs for trainer when using temperature annealing.')
             self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
             self.temp = max(self.temp, self.min_temp)
...
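For reference, the annealing line interpolates linearly from init_temp at epoch 0 down to min_temp at epoch max_epochs, which is why max_epochs must not be None. An illustrative standalone version (function name is ours, not from the commit):

    def annealed_temp(epoch: int, max_epochs: int, init_temp: float, min_temp: float) -> float:
        # Linear schedule: init_temp at epoch 0, min_temp at epoch == max_epochs;
        # the max() clamp guards against epoch overshooting max_epochs.
        temp = (1 - epoch / max_epochs) * (init_temp - min_temp) + min_temp
        return max(temp, min_temp)

    assert abs(annealed_temp(0, 10, 1.0, 0.33) - 1.0) < 1e-9
    assert abs(annealed_temp(5, 10, 1.0, 0.33) - 0.665) < 1e-9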