ModelZoo / DragNoise_pytorch · Commit 66e662c1
authored Dec 17, 2024 by bailuo

init & optimize

Pipeline #2116 failed with stages in 0 seconds · Changes: 62 · Pipelines: 1

Showing 2 changed files with 624 additions and 0 deletions:

    utils/lora_utils.py   +313  -0
    utils/ui_utils.py     +311  -0
utils/lora_utils.py · new file, mode 100755
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
from PIL import Image
import os
import numpy as np
from einops import rearrange
import torch
import torch.nn.functional as F
from torchvision import transforms

from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.17.0")


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    text_inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    return text_inputs


def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
    text_input_ids = input_ids.to(text_encoder.device)

    if text_encoder_use_attention_mask:
        attention_mask = attention_mask.to(text_encoder.device)
    else:
        attention_mask = None

    prompt_embeds = text_encoder(
        text_input_ids,
        attention_mask=attention_mask,
    )
    prompt_embeds = prompt_embeds[0]
    return prompt_embeds


# model_path: path of the model
# image: the input image, not yet pre-processed
# save_lora_path: the path to save the LoRA weights
# prompt: the user-input prompt
# lora_step: number of LoRA training steps
# lora_lr: learning rate for LoRA training
# lora_rank: the rank of the LoRA
# save_interval: the frequency of saving LoRA checkpoints
def train_lora(
    image,
    prompt,
    model_path,
    vae_path,
    save_lora_path,
    lora_step,
    lora_lr,
    lora_batch_size,
    lora_rank,
    progress,
    save_interval=-1,
):
    # initialize accelerator
    accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision='fp16')
    set_seed(0)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        subfolder="tokenizer",
        revision=None,
        use_fast=False,
    )
    # initialize the model
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
    text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None)
    text_encoder = text_encoder_cls.from_pretrained(model_path, subfolder="text_encoder", revision=None)
    if vae_path == "default":
        vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", revision=None)
    else:
        vae = AutoencoderKL.from_pretrained(vae_path)
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", revision=None)

    # set device and dtype
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    unet.to(device, dtype=torch.float16)
    vae.to(device, dtype=torch.float16)
    text_encoder.to(device, dtype=torch.float16)

    # per-processor base ranks; the UNet has 32 attention processors in total:
    # down: 4+4+4, up: 6+6+6, mid: 1+1
    lora_rank_list = [4, 4, 4, 4, 8, 8, 8, 8, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
                      8, 8, 8, 8, 8, 8, 4, 4, 4, 4, 4, 4, 32, 32]
    lora_rank_inx = 0

    # initialize UNet LoRA
    unet_lora_attn_procs = {}
    for name, attn_processor in unet.attn_processors.items():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks")

        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
            lora_attn_processor_class = LoRAAttnAddedKVProcessor
        else:
            lora_attn_processor_class = (
                LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
            )
        # NOTE: the per-processor rank from lora_rank_list overrides the lora_rank argument
        lora_rank = lora_rank_list[lora_rank_inx] * 2
        unet_lora_attn_procs[name] = lora_attn_processor_class(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
        )
        lora_rank_inx = lora_rank_inx + 1

    unet.set_attn_processor(unet_lora_attn_procs)
    unet_lora_layers = AttnProcsLayers(unet.attn_processors)

    # Optimizer creation
    params_to_optimize = (unet_lora_layers.parameters())
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=lora_lr,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

    lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=lora_step,
        num_cycles=1,
        power=1.0,
    )

    # prepare accelerator
    unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
    optimizer = accelerator.prepare_optimizer(optimizer)
    lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)

    # initialize text embeddings
    with torch.no_grad():
        text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None)
        text_embedding = encode_prompt(
            text_encoder,
            text_inputs.input_ids,
            text_inputs.attention_mask,
            text_encoder_use_attention_mask=False
        )
        text_embedding = text_embedding.repeat(lora_batch_size, 1, 1)

    # initialize latent distribution
    image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    for step in progress.tqdm(range(lora_step), desc="training LoRA"):
        unet.train()
        image_batch = []
        for _ in range(lora_batch_size):
            image_transformed = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16)
            image_transformed = image_transformed.unsqueeze(dim=0)
            image_batch.append(image_transformed)

        # repeat the transformed image to enable multi-batch training
        image_batch = torch.cat(image_batch, dim=0)

        latents_dist = vae.encode(image_batch).latent_dist
        model_input = latents_dist.sample() * vae.config.scaling_factor
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(model_input)
        bsz, channels, height, width = model_input.shape
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
        )
        timesteps = timesteps.long()

        # Add noise to the model input according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

        # Predict the noise residual
        model_pred = unet(noisy_model_input, timesteps, text_embedding).sample

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(model_input, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        if save_interval > 0 and (step + 1) % save_interval == 0:
            save_lora_path_intermediate = os.path.join(save_lora_path, str(step + 1))
            if not os.path.isdir(save_lora_path_intermediate):
                os.mkdir(save_lora_path_intermediate)
            # unwrap_model only strips modules added for distributed training,
            # so there is no need to call accelerator.unwrap_model here
            LoraLoaderMixin.save_lora_weights(
                save_directory=save_lora_path_intermediate,
                unet_lora_layers=unet_lora_layers,
                text_encoder_lora_layers=None,
            )

    # save the trained LoRA
    LoraLoaderMixin.save_lora_weights(
        save_directory=save_lora_path,
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=None,
    )
    return
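For reference, train_lora can also be exercised outside the Gradio UI. Below is a minimal sketch, assuming a CUDA device and a local SD 1.x checkpoint layout; the image path, output directory, checkpoint name, and the ScriptProgress shim are illustrative stand-ins, not part of this commit:

    import os
    import numpy as np
    from PIL import Image
    from tqdm import tqdm

    # train_lora drives progress through a Gradio-style object exposing .tqdm();
    # this shim stands in for gr.Progress() when running from a plain script.
    class ScriptProgress:
        def tqdm(self, iterable, desc=None):
            return tqdm(iterable, desc=desc)

    # hypothetical input image and output directory -- adjust to your setup
    image = np.array(Image.open("examples/source.png").convert("RGB"))
    os.makedirs("./lora_tmp", exist_ok=True)

    train_lora(
        image=image,
        prompt="a photo of a cat",
        model_path="runwayml/stable-diffusion-v1-5",  # any SD 1.x checkpoint layout
        vae_path="default",           # "default" reuses the checkpoint's own VAE
        save_lora_path="./lora_tmp",
        lora_step=80,
        lora_lr=5e-4,
        lora_batch_size=1,
        lora_rank=16,                 # overridden per-processor by lora_rank_list
        progress=ScriptProgress(),
        save_interval=-1,             # <= 0 disables intermediate checkpoints
    )

The resulting pytorch_lora_weights file under save_lora_path is what run_drag in utils/ui_utils.py later loads via model.unet.load_attn_procs(lora_path).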
utils/ui_utils.py · new file, mode 100755
import copy
import os
import cv2
import numpy as np
import gradio as gr
from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace

import datetime
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import torch
import torch.nn.functional as F

from diffusers import DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
from drag_pipeline import DragPipeline

from torchvision.utils import save_image
from pytorch_lightning import seed_everything

from .drag_utils import drag_diffusion_update
from .lora_utils import train_lora
from .attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
from .freeu_utils import register_free_upblock2d, register_free_crossattn_upblock2d


# -------------- general UI functionality --------------
def clear_all(length=480):
    return gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        [], None, None

def clear_all_gen(length=480):
    return gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        [], None, None, None

def mask_image(image, mask, color=[255, 0, 0], alpha=0.5):
    """ Overlay mask on image for visualization purposes.
    Args:
        image (H, W, 3) or (H, W): input image
        mask (H, W): mask to be overlaid
        color: the color of the overlaid mask
        alpha: the transparency of the mask
    """
    out = deepcopy(image)
    img = deepcopy(image)
    img[mask == 1] = color
    out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
    return out

def store_img(img, length=512):
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    height, width, _ = image.shape
    image = Image.fromarray(image)
    image = exif_transpose(image)
    image = image.resize((length, int(length * height / width)), PIL.Image.BILINEAR)
    mask = cv2.resize(mask, (length, int(length * height / width)), interpolation=cv2.INTER_NEAREST)
    image = np.array(image)

    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()
    # when a new image is uploaded, `selected_points` should be empty
    return image, [], masked_img, mask

# once the user uploads an image, the original image is stored in `original_image`;
# the same image is displayed in `input_image` for point-clicking purposes
def store_img_gen(img):
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    image = Image.fromarray(image)
    image = exif_transpose(image)
    image = np.array(image)
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()
    # when a new image is uploaded, `selected_points` should be empty
    return image, [], masked_img, mask

# the user clicks the image to select points, which are drawn on the image
def get_points(img, sel_pix, evt: gr.SelectData):
    # collect the selected point
    sel_pix.append(evt.index)
    # draw points
    points = []
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            # draw a red circle at the handle point
            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
        else:
            # draw a blue circle at the target point
            cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
        points.append(tuple(point))
        # draw an arrow from the handle point to the target point
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
            points = []
    return img if isinstance(img, np.ndarray) else np.array(img)

# clear all handle/target points
def undo_points(original_image, mask):
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = original_image.copy()
    return masked_img, []
# ------------------------------------------------------

# ----------- dragging user-input image utils -----------
def train_lora_interface(original_image,
                         prompt,
                         model_path,
                         vae_path,
                         lora_path,
                         lora_step,
                         lora_lr,
                         lora_batch_size,
                         lora_rank,
                         progress=gr.Progress()):
    train_lora(
        original_image,
        prompt,
        model_path,
        vae_path,
        lora_path,
        lora_step,
        lora_lr,
        lora_batch_size,
        lora_rank,
        progress)
    return "Training LoRA Done!"

def preprocess_image(image, device):
    image = torch.from_numpy(image).float() / 127.5 - 1  # [-1, 1]
    image = rearrange(image, "h w c -> 1 c h w")
    image = image.to(device)
    return image

def run_drag(source_image,
             image_with_clicks,
             mask,
             prompt,
             points,
             inversion_strength,
             end_step,
             lam,
             latent_lr,
             n_pix_step,
             model_path,
             vae_path,
             lora_path,
             start_step,
             start_layer,
             save_dir="./results"):
    # initialize model
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                              beta_schedule="scaled_linear", clip_sample=False,
                              set_alpha_to_one=False, steps_offset=1)
    model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
    # call this function to override the UNet forward function,
    # so that intermediate features are returned after forward
    model.modify_unet_forward()

    # set vae
    if vae_path != "default":
        model.vae = AutoencoderKL.from_pretrained(
            vae_path
        ).to(model.vae.device, model.vae.dtype)

    # initialize parameters
    seed = 42  # random seed used by a lot of people for unknown reason
    seed_everything(seed)

    args = SimpleNamespace()
    args.prompt = prompt
    args.points = points
    args.n_inference_step = 50
    args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
    args.guidance_scale = 1.0

    args.unet_feature_idx = [3]

    args.r_m = 1
    args.r_p = 3
    args.lam = lam
    args.end_step = end_step

    args.lr = latent_lr
    args.n_pix_step = n_pix_step

    full_h, full_w = source_image.shape[:2]
    args.sup_res_h = int(0.5 * full_h)
    args.sup_res_w = int(0.5 * full_w)

    print(args)

    source_image = preprocess_image(source_image, device)
    image_with_clicks = preprocess_image(image_with_clicks, device)

    # set LoRA
    if lora_path == "":
        print("applying default parameters")
        model.unet.set_default_attn_processor()
    else:
        print("applying lora: " + lora_path)
        model.unet.load_attn_procs(lora_path)

    # invert the source image
    # the latent code resolution is quite small, only 64*64
    invert_code = model.invert(source_image,
                               prompt,
                               guidance_scale=args.guidance_scale,
                               num_inference_steps=args.n_inference_step,
                               num_actual_inference_steps=args.n_actual_inference_step)

    mask = torch.from_numpy(mask).float() / 255.
    mask[mask > 0.0] = 1.0
    mask = rearrange(mask, "h w -> 1 1 h w").cuda()
    mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")

    handle_points = []
    target_points = []
    # here, the points are in (x, y) coordinates
    for idx, point in enumerate(points):
        cur_point = torch.tensor([point[1] / full_h * args.sup_res_h,
                                  point[0] / full_w * args.sup_res_w])
        cur_point = torch.round(cur_point)
        if idx % 2 == 0:
            handle_points.append(cur_point)
        else:
            target_points.append(cur_point)
    print('handle points:', handle_points)
    print('target points:', target_points)

    init_code = invert_code
    init_code_orig = deepcopy(init_code)
    model.scheduler.set_timesteps(args.n_inference_step)
    t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]

    # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
    # update according to the given supervision
    updated_init_code, h_feature, h_features = drag_diffusion_update(
        model, init_code, t, handle_points, target_points, mask, args)
    n_move = len(h_features)

    gen_img_list = []
    gen_image = model(
        prompt=args.prompt,
        h_feature=h_feature,
        end_step=args.end_step,
        batch_size=2,
        latents=torch.cat([init_code_orig, updated_init_code], dim=0),
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.n_inference_step,
        num_actual_inference_steps=args.n_actual_inference_step
    )[1].unsqueeze(dim=0)

    # resize gen_image to the size of source_image
    # we do this because the shape of gen_image is rounded to multiples of 8
    gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')
    copy_gen = copy.deepcopy(gen_image)
    gen_img_list.append(copy_gen)

    # concatenate the original image, user editing instructions, and synthesized image
    # (kept for reference; only gen_image is written to disk below)
    save_result = torch.cat([
        source_image * 0.5 + 0.5,
        torch.ones((1, 3, full_h, 25)).cuda(),
        image_with_clicks * 0.5 + 0.5,
        torch.ones((1, 3, full_h, 25)).cuda(),
        gen_image[0:1]
    ], dim=-1)

    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)
    save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
    save_image(gen_image, os.path.join(save_dir, save_prefix + '.png'))

    out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
    out_image = (out_image * 255).astype(np.uint8)
    return out_image
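The app script that composes these handlers is not part of this commit, so the following is only a minimal wiring sketch under Gradio 3.x (the version whose gr.Image.update API this file uses). All component names, labels, and default values are illustrative assumptions:

    import gradio as gr

    with gr.Blocks() as demo:
        # states threaded through the handlers above
        selected_points = gr.State([])   # alternating handle/target clicks
        original_image = gr.State(None)  # cached upload from store_img
        mask = gr.State(None)            # user-drawn editable region

        input_image = gr.Image(tool="sketch", label="Draw mask & click points")
        prompt = gr.Textbox(label="Prompt")
        lora_status = gr.Textbox(label="LoRA status")
        # hypothetical defaults; any SD 1.x checkpoint layout should work
        model_path = gr.Textbox(value="runwayml/stable-diffusion-v1-5", label="Model path")
        vae_path = gr.Textbox(value="default", label="VAE path")
        lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
        lora_step = gr.Number(value=80, precision=0, label="LoRA steps")
        lora_lr = gr.Number(value=5e-4, label="LoRA lr")
        lora_batch_size = gr.Number(value=1, precision=0, label="LoRA batch size")
        lora_rank = gr.Number(value=16, precision=0, label="LoRA rank")
        train_btn = gr.Button("Train LoRA")
        undo_btn = gr.Button("Undo points")

        # store_img caches the (resized) upload and resets the click list
        input_image.edit(store_img, [input_image],
                         [original_image, selected_points, input_image, mask])
        # get_points records alternating handle/target points and draws them
        input_image.select(get_points, [input_image, selected_points], [input_image])
        undo_btn.click(undo_points, [original_image, mask],
                       [input_image, selected_points])
        train_btn.click(train_lora_interface,
                        [original_image, prompt, model_path, vae_path, lora_path,
                         lora_step, lora_lr, lora_batch_size, lora_rank],
                        [lora_status])

    demo.queue().launch()

run_drag would be wired to an "Edit" button in the same way, taking the cached original_image, the image with drawn clicks, the mask and point states, plus the drag hyperparameters (inversion_strength, end_step, lam, latent_lr, n_pix_step, ...), and writing its output to a result gr.Image.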