OpenDAS / ColossalAI · Commits

Commit 292c81ed (unverified)
Authored Feb 08, 2023 by Fazzie-Maqianli, committed by GitHub on Feb 08, 2023

fix/transformer-verison (#2581)

Parent: d3480396
Showing 4 changed files with 19 additions and 24 deletions (+19 / -24)
examples/images/diffusion/README.md (+1 / -1)
examples/images/diffusion/environment.yaml (+1 / -1)
examples/images/diffusion/requirements.txt (+1 / -1)
examples/images/dreambooth/train_dreambooth_colossalai.py (+16 / -21)
examples/images/diffusion/README.md

@@ -52,7 +52,7 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la
 ```
 conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
-pip install transformers==4.19.2 diffusers invisible-watermark
+pip install transformers diffusers invisible-watermark
 ```
 #### Step 2: install lightning
examples/images/diffusion/environment.yaml

@@ -18,7 +18,7 @@ dependencies:
   - test-tube>=0.7.5
   - streamlit==1.12.1
   - einops==0.3.0
-  - transformers==4.19.2
+  - transformers
   - webdataset==0.2.5
   - kornia==0.6
   - open_clip_torch==2.0.2
examples/images/diffusion/requirements.txt

@@ -9,7 +9,7 @@ omegaconf==2.1.1
 test-tube>=0.7.5
 streamlit>=0.73.1
 einops==0.3.0
-transformers==4.19.2
+transformers
 webdataset==0.2.5
 open-clip-torch==2.7.0
 gradio==3.11
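The three dependency files above drop the transformers==4.19.2 pin and leave transformers unpinned. A minimal sketch (not part of the commit) for checking which versions an unpinned environment actually resolved to, using only the standard library:

    # Sketch only: report the resolved versions of the packages touched by this change.
    from importlib.metadata import PackageNotFoundError, version

    for pkg in ("transformers", "diffusers"):
        try:
            print(f"{pkg}=={version(pkg)}")
        except PackageNotFoundError:
            print(f"{pkg} is not installed")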
examples/images/dreambooth/train_dreambooth_colossalai.py

@@ -10,7 +10,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
@@ -133,9 +133,13 @@ def parse_args(input_args=None):
         default="cpu",
         help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
-    parser.add_argument("--center_crop",
-                        action="store_true",
-                        help="Whether to center crop images before resizing to resolution")
+    parser.add_argument(
+        "--center_crop",
+        default=False,
+        action="store_true",
+        help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+              " cropped. The images will be resized to the resolution first before cropping."),
+    )
     parser.add_argument("--train_batch_size",
                         type=int,
                         default=4,
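The new --center_crop help text matches the upstream diffusers DreamBooth script, where the flag selects center cropping instead of random cropping after resizing. A minimal sketch of that preprocessing pattern; the helper name and transform chain are illustrative, not copied from this file:

    # Illustrative only: how a --center_crop flag typically drives image preprocessing.
    from torchvision import transforms

    def build_image_transforms(resolution: int, center_crop: bool) -> transforms.Compose:
        return transforms.Compose([
            # Resize first, then crop to a square of side `resolution`.
            transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])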
@@ -149,13 +153,6 @@ def parse_args(input_args=None):
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
-    )
     parser.add_argument(
         "--gradient_checkpointing",
         action="store_true",
@@ -356,7 +353,6 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
 def main(args):
     if args.seed is None:
         colossalai.launch_from_torch(config={})
     else:
@@ -410,7 +406,8 @@ def main(args):
             repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
         else:
             repo_name = args.hub_model_id
-        repo = Repository(args.output_dir, clone_from=repo_name)
+        create_repo(repo_name, exist_ok=True, token=args.hub_token)
+        repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

         with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
             if "step_*" not in gitignore:
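The commit imports create_repo and calls it before constructing Repository, so the remote Hub repo is created explicitly (and the token is also passed to the local clone) rather than relying on Repository to create it implicitly. A self-contained sketch of that pattern, assuming huggingface_hub's legacy Repository API; the helper name is illustrative:

    # Sketch of the push-to-Hub setup the commit switches to (illustrative helper).
    from huggingface_hub import Repository, create_repo

    def init_output_repo(output_dir: str, repo_name: str, hub_token: str) -> Repository:
        # Ensure the remote repo exists; exist_ok avoids failing on a rerun.
        create_repo(repo_name, exist_ok=True, token=hub_token)
        # Clone it locally so checkpoints written to output_dir can be pushed later.
        return Repository(output_dir, clone_from=repo_name, token=hub_token)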
@@ -469,9 +466,8 @@ def main(args):
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

-    assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size
+        args.learning_rate = args.learning_rate * args.train_batch_size * world_size

     unet = gemini_zero_dpp(unet, args.placement)
@@ -529,7 +525,7 @@ def main(args):
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
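With gradient accumulation removed, one optimizer update corresponds to one dataloader batch, so the step bookkeeping simplifies. A small sketch of the math under that assumption (function and variable names are illustrative, not from the script):

    # Sketch: derive total optimization steps when every batch is one update.
    import math

    def compute_schedule(batches_per_epoch, num_train_epochs, max_train_steps=None):
        num_update_steps_per_epoch = math.ceil(batches_per_epoch)
        if max_train_steps is None:
            max_train_steps = num_train_epochs * num_update_steps_per_epoch
        return num_update_steps_per_epoch, max_train_steps

    # e.g. 800 batches per epoch for 2 epochs -> (800, 1600)
    print(compute_schedule(800, 2))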
@@ -537,8 +533,8 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps,
+        num_training_steps=args.max_train_steps,
     )
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
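After this change, the warmup and total step counts reach get_scheduler unscaled. A runnable sketch with a toy optimizer; the scheduler name and step counts are placeholders for the script's args:

    # Sketch only: get_scheduler called with unscaled warmup/total steps.
    import torch
    from diffusers.optimization import get_scheduler

    param = torch.nn.Parameter(torch.zeros(1))        # toy parameter
    optimizer = torch.optim.AdamW([param], lr=5e-6)   # toy optimizer

    lr_scheduler = get_scheduler(
        "constant_with_warmup",   # stands in for args.lr_scheduler
        optimizer=optimizer,
        num_warmup_steps=100,     # stands in for args.lr_warmup_steps
        num_training_steps=800,   # stands in for args.max_train_steps
    )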
@@ -553,14 +549,14 @@ def main(args):
     text_encoder.to(get_current_device(), dtype=weight_dtype)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

     # Train!
-    total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps
+    total_batch_size = args.train_batch_size * world_size

     logger.info("***** Running training *****", ranks=[0])
     logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
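The effective batch size logged below is now just the per-device batch size times the number of processes. A one-line arithmetic check, using the argparse default above and a hypothetical 8-GPU run:

    # Sketch: effective (total) train batch size without gradient accumulation.
    train_batch_size = 4   # per device; the script's default
    world_size = 8         # hypothetical number of GPUs/processes
    total_batch_size = train_batch_size * world_size
    print(total_batch_size)  # 32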
@@ -568,7 +564,6 @@ def main(args):
     logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
     logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
-    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])

     # Only show the progress bar once on each machine.