Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9b5e7ce2
Commit
9b5e7ce2
authored
Jun 08, 2023
by
Maruyama_Aya
Browse files
modify shell for check
parent
730a092b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
0 deletions
+7
-0
examples/images/dreambooth/colossalai.sh
examples/images/dreambooth/colossalai.sh
+1
-0
examples/images/dreambooth/test_ci.sh
examples/images/dreambooth/test_ci.sh
+1
-0
examples/images/dreambooth/train_dreambooth_colossalai.py
examples/images/dreambooth/train_dreambooth_colossalai.py
+5
-0
No files found.
examples/images/dreambooth/colossalai.sh
View file @
9b5e7ce2
...
@@ -14,4 +14,5 @@ torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
...
@@ -14,4 +14,5 @@ torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
--lr_scheduler
=
"constant"
\
--lr_scheduler
=
"constant"
\
--lr_warmup_steps
=
0
\
--lr_warmup_steps
=
0
\
--num_class_images
=
200
\
--num_class_images
=
200
\
--test_run
=
True
\
--placement
=
"auto"
\
--placement
=
"auto"
\
examples/images/dreambooth/test_ci.sh
View file @
9b5e7ce2
...
@@ -19,6 +19,7 @@ for plugin in "gemini"; do
...
@@ -19,6 +19,7 @@ for plugin in "gemini"; do
--learning_rate
=
5e-6
\
--learning_rate
=
5e-6
\
--lr_scheduler
=
"constant"
\
--lr_scheduler
=
"constant"
\
--lr_warmup_steps
=
0
\
--lr_warmup_steps
=
0
\
--test_run
=
True
\
--num_class_images
=
200
\
--num_class_images
=
200
\
--placement
=
"auto"
# "cuda"
--placement
=
"auto"
# "cuda"
done
done
examples/images/dreambooth/train_dreambooth_colossalai.py
View file @
9b5e7ce2
...
@@ -198,6 +198,7 @@ def parse_args(input_args=None):
...
@@ -198,6 +198,7 @@ def parse_args(input_args=None):
parser
.
add_argument
(
"--max_grad_norm"
,
default
=
1.0
,
type
=
float
,
help
=
"Max gradient norm."
)
parser
.
add_argument
(
"--max_grad_norm"
,
default
=
1.0
,
type
=
float
,
help
=
"Max gradient norm."
)
parser
.
add_argument
(
"--push_to_hub"
,
action
=
"store_true"
,
help
=
"Whether or not to push the model to the Hub."
)
parser
.
add_argument
(
"--push_to_hub"
,
action
=
"store_true"
,
help
=
"Whether or not to push the model to the Hub."
)
parser
.
add_argument
(
"--hub_token"
,
type
=
str
,
default
=
None
,
help
=
"The token to use to push to the Model Hub."
)
parser
.
add_argument
(
"--hub_token"
,
type
=
str
,
default
=
None
,
help
=
"The token to use to push to the Model Hub."
)
parser
.
add_argument
(
"--test_run"
,
default
=
False
,
help
=
"Whether to use a smaller dataset for test run."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--hub_model_id"
,
"--hub_model_id"
,
type
=
str
,
type
=
str
,
...
@@ -267,6 +268,7 @@ class DreamBoothDataset(Dataset):
...
@@ -267,6 +268,7 @@ class DreamBoothDataset(Dataset):
class_prompt
=
None
,
class_prompt
=
None
,
size
=
512
,
size
=
512
,
center_crop
=
False
,
center_crop
=
False
,
test
=
False
,
):
):
self
.
size
=
size
self
.
size
=
size
self
.
center_crop
=
center_crop
self
.
center_crop
=
center_crop
...
@@ -277,6 +279,8 @@ class DreamBoothDataset(Dataset):
...
@@ -277,6 +279,8 @@ class DreamBoothDataset(Dataset):
raise
ValueError
(
"Instance images root doesn't exists."
)
raise
ValueError
(
"Instance images root doesn't exists."
)
self
.
instance_images_path
=
list
(
Path
(
instance_data_root
).
iterdir
())
self
.
instance_images_path
=
list
(
Path
(
instance_data_root
).
iterdir
())
if
test
:
self
.
instance_images_path
=
self
.
instance_images_path
[:
10
]
self
.
num_instance_images
=
len
(
self
.
instance_images_path
)
self
.
num_instance_images
=
len
(
self
.
instance_images_path
)
self
.
instance_prompt
=
instance_prompt
self
.
instance_prompt
=
instance_prompt
self
.
_length
=
self
.
num_instance_images
self
.
_length
=
self
.
num_instance_images
...
@@ -509,6 +513,7 @@ def main(args):
...
@@ -509,6 +513,7 @@ def main(args):
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
size
=
args
.
resolution
,
size
=
args
.
resolution
,
center_crop
=
args
.
center_crop
,
center_crop
=
args
.
center_crop
,
test
=
args
.
test_run
)
)
def
collate_fn
(
examples
):
def
collate_fn
(
examples
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment