Commit 8119fb04 authored by chenxi226

Multi-epoch training on 8 GPUs now runs to completion

parent bb715355
Pipeline #3517 failed
@@ -98,12 +98,17 @@ class BaseSWA(StochasticWeightAveraging):
             _scheduler = {"scheduler": _scheduler}
             self._swa_scheduler.update(_scheduler)
-        if trainer.lr_schedulers:
-            lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
-            rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
-            trainer.lr_schedulers[0] = self._swa_scheduler
+        if trainer.lr_scheduler_configs:
+            lr_scheduler = trainer.lr_scheduler_configs[0].scheduler
+            rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler['scheduler']}")
+            # Update the existing scheduler config in place
+            trainer.lr_scheduler_configs[0].scheduler = self._swa_scheduler["scheduler"]
+            trainer.lr_scheduler_configs[0].interval = self._swa_scheduler["interval"]
         else:
-            trainer.lr_schedulers.append(self._swa_scheduler)
+            # Newer PL represents schedulers as LRSchedulerConfig entries; append one for SWA
+            from pytorch_lightning.core.optimizer import LRSchedulerConfig
+            swa_config = LRSchedulerConfig(**self._swa_scheduler)
+            trainer.lr_scheduler_configs.append(swa_config)
         self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
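For context on the hunk above: PyTorch Lightning 1.6 replaced the flat trainer.lr_schedulers list of dicts with trainer.lr_scheduler_configs, a list of LRSchedulerConfig dataclasses, which is why both branches are rewritten. A minimal sketch of that mapping; the dataclass below is a hypothetical stand-in, not the real PL import, whose exact module path varies across versions:

    # Sketch only: how a pre-1.6 scheduler dict maps onto the dataclass
    # that trainer.lr_scheduler_configs holds in PL >= 1.6.
    from dataclasses import dataclass
    from typing import Any

    import torch

    @dataclass
    class LRSchedulerConfigSketch:  # hypothetical stand-in for PL's LRSchedulerConfig
        scheduler: Any
        interval: str = "epoch"  # "epoch" or "step"
        frequency: int = 1

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)

    old_style = {"scheduler": sched, "interval": "epoch"}  # pre-1.6: list of dicts
    new_style = LRSchedulerConfigSketch(**old_style)       # 1.6+: list of dataclasses
    print(new_style.interval)  # "epoch"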
@@ -252,6 +252,28 @@ def _train(
     callbacks.append(checkpoint_cb)
     callbacks.append(LearningRateMonitor(logging_interval="epoch"))
+    # ========== Final fix for the SWA state_dict error ==========
+    import pytorch_lightning.callbacks.stochastic_weight_avg as swa_module
+    from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
+
+    def fixed_state_dict(self) -> dict:
+        # Safely unwrap the real scheduler (handles the dict format used by BaseSWA)
+        if isinstance(self._swa_scheduler, dict):
+            scheduler = self._swa_scheduler.get("scheduler", None)
+        else:
+            scheduler = self._swa_scheduler
+        sch_state = scheduler.state_dict() if scheduler is not None else None
+        return {
+            "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
+            "latest_update_epoch": self._latest_update_epoch,
+            "scheduler_state": sch_state,
+            "average_model_state": None if self._average_model is None else self._average_model.state_dict(),
+        }
+
+    StochasticWeightAveraging.state_dict = fixed_state_dict
+
     OmegaConf.save(cfg, str(Path(os.getcwd()) / "config.yaml"))
     OmegaConf.save(cfg, str(Path(os.getcwd()) / "config_resolved.yaml"), resolve=True)
     save_pickle(plan, train_dir / "plan.pkl")  # backup plan
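Why the patch above is needed: BaseSWA stores self._swa_scheduler as a dict, while the stock StochasticWeightAveraging.state_dict (which fixed_state_dict mirrors key for key) presumably calls .state_dict() on that attribute directly, so checkpointing raises an AttributeError. A standalone sketch of that failure mode:

    # Minimal repro of the failure the monkey-patch works around.
    swa_scheduler = {"scheduler": None, "interval": "epoch"}  # dict, as BaseSWA stores it

    try:
        swa_scheduler.state_dict()  # what the unpatched code path effectively does
    except AttributeError as err:
        print(err)  # 'dict' object has no attribute 'state_dict'

    # The patched version unwraps the dict before touching .state_dict():
    scheduler = swa_scheduler.get("scheduler") if isinstance(swa_scheduler, dict) else swa_scheduler
    state = scheduler.state_dict() if scheduler is not None else None
    print(state)  # None here, since no real scheduler was attached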
@@ -287,6 +309,8 @@ def _train(
         plugins=plugins,
         # terminate_on_nan=True,  # TODO: make modular
         # move_metrics_to_cpu=False,
+        # stochastic_weight_avg=False,
+        strategy="ddp_find_unused_parameters_true",  # <--- added: lets DDP tolerate unused parameters
         **trainer_kwargs
     )
     trainer.fit(module, datamodule=datamodule)
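On the strategy added in the hunk above: "ddp_find_unused_parameters_true" is the registered shorthand for DDP with find_unused_parameters=True, which stops DDP from erroring out when some parameters (e.g. the averaged SWA copies) receive no gradient in a step. An equivalent, more explicit spelling, sketched under the assumption of pytorch_lightning >= 2.0 (adjust the import for older versions):

    # Explicit form of strategy="ddp_find_unused_parameters_true".
    import pytorch_lightning as pl
    from pytorch_lightning.strategies import DDPStrategy

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=8,
        # Marks parameters that get no gradient in a step as already reduced,
        # at the cost of an extra graph traversal per backward pass.
        strategy=DDPStrategy(find_unused_parameters=True),
    )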