Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9e110299
Unverified
Commit
9e110299
authored
Dec 06, 2022
by
Suraj Patil
Committed by
GitHub
Dec 06, 2022
Browse files
[dreambooth] make collate_fn global (#1547)
make collate_fn global
parent
c2283310
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
29 deletions
+31
-29
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+31
-29
No files found.
examples/dreambooth/train_dreambooth.py
View file @
9e110299
...
...
@@ -304,9 +304,10 @@ class DreamBoothDataset(Dataset):
example
[
"instance_images"
]
=
self
.
image_transforms
(
instance_image
)
example
[
"instance_prompt_ids"
]
=
self
.
tokenizer
(
self
.
instance_prompt
,
padding
=
"do_not_pad"
,
truncation
=
True
,
padding
=
"max_length"
,
max_length
=
self
.
tokenizer
.
model_max_length
,
return_tensors
=
"pt"
,
).
input_ids
if
self
.
class_data_root
:
...
...
@@ -316,14 +317,37 @@ class DreamBoothDataset(Dataset):
example
[
"class_images"
]
=
self
.
image_transforms
(
class_image
)
example
[
"class_prompt_ids"
]
=
self
.
tokenizer
(
self
.
class_prompt
,
padding
=
"do_not_pad"
,
truncation
=
True
,
padding
=
"max_length"
,
max_length
=
self
.
tokenizer
.
model_max_length
,
return_tensors
=
"pt"
,
).
input_ids
return
example
def
collate_fn
(
examples
,
with_prior_preservation
=
False
):
input_ids
=
[
example
[
"instance_prompt_ids"
]
for
example
in
examples
]
pixel_values
=
[
example
[
"instance_images"
]
for
example
in
examples
]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if
with_prior_preservation
:
input_ids
+=
[
example
[
"class_prompt_ids"
]
for
example
in
examples
]
pixel_values
+=
[
example
[
"class_images"
]
for
example
in
examples
]
pixel_values
=
torch
.
stack
(
pixel_values
)
pixel_values
=
pixel_values
.
to
(
memory_format
=
torch
.
contiguous_format
).
float
()
input_ids
=
torch
.
cat
(
input_ids
,
dim
=
0
)
batch
=
{
"input_ids"
:
input_ids
,
"pixel_values"
:
pixel_values
,
}
return
batch
class
PromptDataset
(
Dataset
):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
...
...
@@ -514,34 +538,12 @@ def main(args):
center_crop
=
args
.
center_crop
,
)
def
collate_fn
(
examples
):
input_ids
=
[
example
[
"instance_prompt_ids"
]
for
example
in
examples
]
pixel_values
=
[
example
[
"instance_images"
]
for
example
in
examples
]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if
args
.
with_prior_preservation
:
input_ids
+=
[
example
[
"class_prompt_ids"
]
for
example
in
examples
]
pixel_values
+=
[
example
[
"class_images"
]
for
example
in
examples
]
pixel_values
=
torch
.
stack
(
pixel_values
)
pixel_values
=
pixel_values
.
to
(
memory_format
=
torch
.
contiguous_format
).
float
()
input_ids
=
tokenizer
.
pad
(
{
"input_ids"
:
input_ids
},
padding
=
"max_length"
,
max_length
=
tokenizer
.
model_max_length
,
return_tensors
=
"pt"
,
).
input_ids
batch
=
{
"input_ids"
:
input_ids
,
"pixel_values"
:
pixel_values
,
}
return
batch
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
args
.
train_batch_size
,
shuffle
=
True
,
collate_fn
=
collate_fn
,
num_workers
=
1
train_dataset
,
batch_size
=
args
.
train_batch_size
,
shuffle
=
True
,
collate_fn
=
lambda
examples
:
collate_fn
(
examples
,
args
.
with_prior_preservation
),
num_workers
=
1
,
)
# Scheduler and math around the number of training steps.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment