Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
9e110299
Unverified
Commit
9e110299
authored
Dec 06, 2022
by
Suraj Patil
Committed by
GitHub
Dec 06, 2022
Browse files
[dreambooth] make collate_fn global (#1547)
make collate_fn global
parent
c2283310
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
29 deletions
+31
-29
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+31
-29
No files found.
examples/dreambooth/train_dreambooth.py
View file @
9e110299
...
@@ -304,9 +304,10 @@ class DreamBoothDataset(Dataset):
...
@@ -304,9 +304,10 @@ class DreamBoothDataset(Dataset):
example
[
"instance_images"
]
=
self
.
image_transforms
(
instance_image
)
example
[
"instance_images"
]
=
self
.
image_transforms
(
instance_image
)
example
[
"instance_prompt_ids"
]
=
self
.
tokenizer
(
example
[
"instance_prompt_ids"
]
=
self
.
tokenizer
(
self
.
instance_prompt
,
self
.
instance_prompt
,
padding
=
"do_not_pad"
,
truncation
=
True
,
truncation
=
True
,
padding
=
"max_length"
,
max_length
=
self
.
tokenizer
.
model_max_length
,
max_length
=
self
.
tokenizer
.
model_max_length
,
return_tensors
=
"pt"
,
).
input_ids
).
input_ids
if
self
.
class_data_root
:
if
self
.
class_data_root
:
...
@@ -316,14 +317,37 @@ class DreamBoothDataset(Dataset):
...
@@ -316,14 +317,37 @@ class DreamBoothDataset(Dataset):
example
[
"class_images"
]
=
self
.
image_transforms
(
class_image
)
example
[
"class_images"
]
=
self
.
image_transforms
(
class_image
)
example
[
"class_prompt_ids"
]
=
self
.
tokenizer
(
example
[
"class_prompt_ids"
]
=
self
.
tokenizer
(
self
.
class_prompt
,
self
.
class_prompt
,
padding
=
"do_not_pad"
,
truncation
=
True
,
truncation
=
True
,
padding
=
"max_length"
,
max_length
=
self
.
tokenizer
.
model_max_length
,
max_length
=
self
.
tokenizer
.
model_max_length
,
return_tensors
=
"pt"
,
).
input_ids
).
input_ids
return
example
return
example
def
collate_fn
(
examples
,
with_prior_preservation
=
False
):
input_ids
=
[
example
[
"instance_prompt_ids"
]
for
example
in
examples
]
pixel_values
=
[
example
[
"instance_images"
]
for
example
in
examples
]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if
with_prior_preservation
:
input_ids
+=
[
example
[
"class_prompt_ids"
]
for
example
in
examples
]
pixel_values
+=
[
example
[
"class_images"
]
for
example
in
examples
]
pixel_values
=
torch
.
stack
(
pixel_values
)
pixel_values
=
pixel_values
.
to
(
memory_format
=
torch
.
contiguous_format
).
float
()
input_ids
=
torch
.
cat
(
input_ids
,
dim
=
0
)
batch
=
{
"input_ids"
:
input_ids
,
"pixel_values"
:
pixel_values
,
}
return
batch
class
PromptDataset
(
Dataset
):
class
PromptDataset
(
Dataset
):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
...
@@ -514,34 +538,12 @@ def main(args):
...
@@ -514,34 +538,12 @@ def main(args):
center_crop
=
args
.
center_crop
,
center_crop
=
args
.
center_crop
,
)
)
def
collate_fn
(
examples
):
input_ids
=
[
example
[
"instance_prompt_ids"
]
for
example
in
examples
]
pixel_values
=
[
example
[
"instance_images"
]
for
example
in
examples
]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if
args
.
with_prior_preservation
:
input_ids
+=
[
example
[
"class_prompt_ids"
]
for
example
in
examples
]
pixel_values
+=
[
example
[
"class_images"
]
for
example
in
examples
]
pixel_values
=
torch
.
stack
(
pixel_values
)
pixel_values
=
pixel_values
.
to
(
memory_format
=
torch
.
contiguous_format
).
float
()
input_ids
=
tokenizer
.
pad
(
{
"input_ids"
:
input_ids
},
padding
=
"max_length"
,
max_length
=
tokenizer
.
model_max_length
,
return_tensors
=
"pt"
,
).
input_ids
batch
=
{
"input_ids"
:
input_ids
,
"pixel_values"
:
pixel_values
,
}
return
batch
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
args
.
train_batch_size
,
shuffle
=
True
,
collate_fn
=
collate_fn
,
num_workers
=
1
train_dataset
,
batch_size
=
args
.
train_batch_size
,
shuffle
=
True
,
collate_fn
=
lambda
examples
:
collate_fn
(
examples
,
args
.
with_prior_preservation
),
num_workers
=
1
,
)
)
# Scheduler and math around the number of training steps.
# Scheduler and math around the number of training steps.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment