yaoyuping / nnDetection · Commits

Commit 8119fb04
authored Apr 21, 2026 by chenxi226

    Multi-round training on 8 GPUs can now run to completion

Parent: bb715355
Pipeline #3517: failed
Showing 2 changed files with 34 additions and 5 deletions
nndet/training/swa.py   +10 -5
scripts/train.py        +24 -0
nndet/training/swa.py

@@ -98,12 +98,17 @@ class BaseSWA(StochasticWeightAveraging):
         _scheduler = {"scheduler": _scheduler}
         self._swa_scheduler.update(_scheduler)
-        if trainer.lr_schedulers:
-            lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
-            rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
-            trainer.lr_schedulers[0] = self._swa_scheduler
+        if trainer.lr_scheduler_configs:
+            lr_scheduler = trainer.lr_scheduler_configs[0].scheduler
+            rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler['scheduler']}")
+            # Update the existing scheduler config in place
+            trainer.lr_scheduler_configs[0].scheduler = self._swa_scheduler["scheduler"]
+            trainer.lr_scheduler_configs[0].interval = self._swa_scheduler["interval"]
         else:
-            trainer.lr_schedulers.append(self._swa_scheduler)
+            # Newer PL versions register schedulers via an LRSchedulerConfig
+            from pytorch_lightning.core.optimizer import LRSchedulerConfig
+            swa_config = LRSchedulerConfig(**self._swa_scheduler)
+            trainer.lr_scheduler_configs.append(swa_config)
         self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
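The hunk above migrates BaseSWA from the pre-1.6 `trainer.lr_schedulers` list of dicts to the `trainer.lr_scheduler_configs` API that PyTorch Lightning introduced in 1.6. A minimal version-tolerant sketch of the same swap, with a hypothetical helper name `swap_in_swa_scheduler` (not part of nnDetection), assuming `swa_scheduler` is the dict BaseSWA builds:

```python
def swap_in_swa_scheduler(trainer, swa_scheduler: dict) -> None:
    """Sketch: replace the first LR scheduler with the SWA one across PL versions."""
    if getattr(trainer, "lr_scheduler_configs", None):
        # PL >= 1.6: configs are objects with .scheduler / .interval attributes
        cfg = trainer.lr_scheduler_configs[0]
        cfg.scheduler = swa_scheduler["scheduler"]
        cfg.interval = swa_scheduler["interval"]
    elif getattr(trainer, "lr_schedulers", None):
        # PL < 1.6: schedulers are plain dicts in a list
        trainer.lr_schedulers[0] = swa_scheduler
```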
scripts/train.py

@@ -251,6 +251,28 @@ def _train(
     checkpoint_cb.CHECKPOINT_NAME_LAST = 'model_last'
     callbacks.append(checkpoint_cb)
     callbacks.append(LearningRateMonitor(logging_interval="epoch"))
+
+    # ========== Ultimate fix for the SWA error ==========
+    import pytorch_lightning.callbacks.stochastic_weight_avg as swa_module
+    from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
+
+    def fixed_state_dict(self) -> dict:
+        # Safely fetch the real scheduler (also handles the dict format)
+        if isinstance(self._swa_scheduler, dict):
+            scheduler = self._swa_scheduler.get("scheduler", None)
+        else:
+            scheduler = self._swa_scheduler
+        sch_state = scheduler.state_dict() if scheduler is not None else None
+        return {
+            "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
+            "latest_update_epoch": self._latest_update_epoch,
+            "scheduler_state": sch_state,
+            "average_model_state": None if self._average_model is None else self._average_model.state_dict(),
+        }
+
+    StochasticWeightAveraging.state_dict = fixed_state_dict
+
     OmegaConf.save(cfg, str(Path(os.getcwd()) / "config.yaml"))
     OmegaConf.save(cfg, str(Path(os.getcwd()) / "config_resolved.yaml"), resolve=True)
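This monkeypatch replaces `StochasticWeightAveraging.state_dict`, which otherwise assumes `_swa_scheduler` is a scheduler object and fails at checkpoint time once BaseSWA has wrapped it in a dict. A quick hypothetical sanity check of the patched method (attribute names taken from the patch itself; this test does not ship with the repo and assumes the patch above has already been applied):

```python
import torch
from pytorch_lightning.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(swa_lrs=1e-3)
swa._swa_scheduler = {"scheduler": None, "interval": "epoch"}  # dict form from BaseSWA
swa._latest_update_epoch = -1
swa.n_averaged = torch.tensor(3, dtype=torch.long)
swa._average_model = None

state = swa.state_dict()
assert state["n_averaged"] == 3
assert state["scheduler_state"] is None  # no real scheduler attached yet
```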
@@ -287,6 +309,8 @@ def _train(
         plugins=plugins,
         # terminate_on_nan=True, # TODO: make modular
         # move_metrics_to_cpu=False,
         # stochastic_weight_avg=False,
+        strategy="ddp_find_unused_parameters_true",  # <--- add this line
         **trainer_kwargs
     )
     trainer.fit(module, datamodule=datamodule)
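The string "ddp_find_unused_parameters_true" is a PyTorch Lightning 2.x strategy alias. If that alias is unavailable in your PL version, the same behavior can be configured explicitly; a sketch under that assumption:

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

# Explicit equivalent of strategy="ddp_find_unused_parameters_true".
# find_unused_parameters=True lets DDP tolerate parameters that receive
# no gradient in a step, at the cost of an extra graph traversal per step.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=8,
    strategy=DDPStrategy(find_unused_parameters=True),
)
```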