Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
a29ea36d
Unverified
Commit
a29ea36d
authored
Jul 13, 2023
by
junming huang
Committed by
GitHub
Jul 12, 2023
Browse files
Update train_unconditional.py (#3899)
increase the time of timeout when using big dataset or high resolution
parent
af48bf20
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
1 deletion
+4
-1
examples/unconditional_image_generation/train_unconditional.py
...les/unconditional_image_generation/train_unconditional.py
+4
-1
No files found.
examples/unconditional_image_generation/train_unconditional.py
View file @
a29ea36d
...
@@ -4,6 +4,7 @@ import logging
...
@@ -4,6 +4,7 @@ import logging
import
math
import
math
import
os
import
os
import
shutil
import
shutil
from
datetime
import
timedelta
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Optional
from
typing
import
Optional
...
@@ -11,7 +12,7 @@ import accelerate
...
@@ -11,7 +12,7 @@ import accelerate
import
datasets
import
datasets
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
accelerate
import
Accelerator
from
accelerate
import
Accelerator
,
InitProcessGroupKwargs
from
accelerate.logging
import
get_logger
from
accelerate.logging
import
get_logger
from
accelerate.utils
import
ProjectConfiguration
from
accelerate.utils
import
ProjectConfiguration
from
datasets
import
load_dataset
from
datasets
import
load_dataset
...
@@ -286,11 +287,13 @@ def main(args):
...
@@ -286,11 +287,13 @@ def main(args):
logging_dir
=
os
.
path
.
join
(
args
.
output_dir
,
args
.
logging_dir
)
logging_dir
=
os
.
path
.
join
(
args
.
output_dir
,
args
.
logging_dir
)
accelerator_project_config
=
ProjectConfiguration
(
project_dir
=
args
.
output_dir
,
logging_dir
=
logging_dir
)
accelerator_project_config
=
ProjectConfiguration
(
project_dir
=
args
.
output_dir
,
logging_dir
=
logging_dir
)
kwargs
=
InitProcessGroupKwargs
(
timeout
=
timedelta
(
seconds
=
7200
))
#a big number for high resolution or big dataset
accelerator
=
Accelerator
(
accelerator
=
Accelerator
(
gradient_accumulation_steps
=
args
.
gradient_accumulation_steps
,
gradient_accumulation_steps
=
args
.
gradient_accumulation_steps
,
mixed_precision
=
args
.
mixed_precision
,
mixed_precision
=
args
.
mixed_precision
,
log_with
=
args
.
logger
,
log_with
=
args
.
logger
,
project_config
=
accelerator_project_config
,
project_config
=
accelerator_project_config
,
kwargs_handlers
=
[
kwargs
],
)
)
if
args
.
logger
==
"tensorboard"
:
if
args
.
logger
==
"tensorboard"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment