ModelZoo / stablediffusion_v2.1_pytorch / Commits / 4007efdd

Commit 4007efdd, authored May 12, 2024 by lijian6

    Initial commit

Pipeline #994 canceled with stages
Changes: 138 | Pipelines: 1

Showing 20 changed files with 3782 additions and 0 deletions (+3782, -0)
ldm/modules/karlo/kakao/models/sr_256_1k.py                       +10    -0
ldm/modules/karlo/kakao/models/sr_64_256.py                       +88    -0
ldm/modules/karlo/kakao/modules/__init__.py                       +49    -0
ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py   +828   -0
ldm/modules/karlo/kakao/modules/diffusion/respace.py              +112   -0
ldm/modules/karlo/kakao/modules/nn.py                             +114   -0
ldm/modules/karlo/kakao/modules/resample.py                       +68    -0
ldm/modules/karlo/kakao/modules/unet.py                           +792   -0
ldm/modules/karlo/kakao/modules/xf.py                             +231   -0
ldm/modules/karlo/kakao/sampler.py                                +272   -0
ldm/modules/karlo/kakao/template.py                               +142   -0
ldm/modules/midas/__init__.py                                     +0     -0
ldm/modules/midas/api.py                                          +170   -0
ldm/modules/midas/midas/__init__.py                               +0     -0
ldm/modules/midas/midas/base_model.py                             +16    -0
ldm/modules/midas/midas/blocks.py                                 +342   -0
ldm/modules/midas/midas/dpt_depth.py                              +109   -0
ldm/modules/midas/midas/midas_net.py                              +76    -0
ldm/modules/midas/midas/midas_net_custom.py                       +129   -0
ldm/modules/midas/midas/transforms.py                             +234   -0
ldm/modules/karlo/kakao/models/sr_256_1k.py (new file, 0 → 100644)

# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------

from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive


class SupRes256to1kProgressive(SupRes64to256Progressive):
    pass  # no difference currently
ldm/modules/karlo/kakao/models/sr_64_256.py (new file, 0 → 100644)

# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------

import copy

import torch

from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion


class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module):
    """
    ImprovedSR model fine-tunes the pretrained DDPM-based SR model by using adversarial and perceptual losses.
    Specifically, the low-resolution sample is iteratively recovered by 6 steps with the frozen pretrained SR model.
    In the following additional one step, a separate fine-tuned model recovers high-frequency details.
    This approach greatly improves the fidelity of images of 256x256px, even with a small number of reverse steps.
    """

    def __init__(self, config):
        super().__init__()

        self._config = config
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
        )

        self.model_first_steps = SuperResUNetModel(
            in_channels=3,  # auto-changed to 6 inside the model
            model_channels=config.model.hparams.channels,
            out_channels=3,
            num_res_blocks=config.model.hparams.depth,
            attention_resolutions=(),  # no attention
            dropout=config.model.hparams.dropout,
            channel_mult=config.model.hparams.channels_multiple,
            resblock_updown=True,
            use_middle_attention=False,
        )
        self.model_last_step = SuperResUNetModel(
            in_channels=3,  # auto-changed to 6 inside the model
            model_channels=config.model.hparams.channels,
            out_channels=3,
            num_res_blocks=config.model.hparams.depth,
            attention_resolutions=(),  # no attention
            dropout=config.model.hparams.dropout,
            channel_mult=config.model.hparams.channels_multiple,
            resblock_updown=True,
            use_middle_attention=False,
        )

    @classmethod
    def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True):
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def get_sample_fn(self, timestep_respacing):
        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)
        return diffusion.p_sample_loop_progressive_for_improved_sr

    def forward(self, low_res, timestep_respacing="7", **kwargs):
        assert (
            timestep_respacing == "7"
        ), "a different respacing method may work, but is not guaranteed"

        sample_fn = self.get_sample_fn(timestep_respacing)
        sample_outputs = sample_fn(
            self.model_first_steps,
            self.model_last_step,
            shape=low_res.shape,
            clip_denoised=True,
            model_kwargs=dict(low_res=low_res),
            **kwargs,
        )
        for x in sample_outputs:
            sample = x["sample"]
            yield sample
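A minimal usage sketch of the 7-step improved SR sampler defined in this file (not part of the committed code). It assumes the repository root is on PYTHONPATH and that a config object with the fields referenced above (config.diffusion.* and config.model.hparams.*) is available; every value below is an illustrative placeholder rather than the checkpoint's real hyperparameters, and in practice the weights would come from load_from_checkpoint.

import torch
from omegaconf import OmegaConf  # any object exposing the same nested attributes works

from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel

# Placeholder config: field names follow the code above, values are illustrative only.
config = OmegaConf.create(
    {
        "diffusion": {
            "steps": 1000,
            "learn_sigma": False,
            "sigma_small": True,
            "noise_schedule": "linear",
            "use_kl": False,
            "predict_xstart": False,
            "rescale_learned_sigmas": True,
        },
        "model": {
            "hparams": {
                "channels": 192,
                "depth": 2,
                "dropout": 0.0,
                "channels_multiple": [1, 2, 4, 8],
            }
        },
    }
)

sr = ImprovedSupRes64to256ProgressiveModel(config).eval()
# Low-resolution conditioning input; the yielded samples share its shape
# (shape=low_res.shape above), so pass it at the resolution the checkpoint expects.
low_res = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    final = None
    for sample in sr(low_res, timestep_respacing="7"):
        final = sample  # 7 respaced steps: 6 with model_first_steps, 1 with model_last_step
    print(final.shape)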
ldm/modules/karlo/kakao/modules/__init__.py (new file, 0 → 100644)

# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------

from .diffusion import gaussian_diffusion as gd
from .diffusion.respace import (
    SpacedDiffusion,
    space_timesteps,
)


def create_gaussian_diffusion(
    steps,
    learn_sigma,
    sigma_small,
    noise_schedule,
    use_kl,
    predict_xstart,
    rescale_learned_sigmas,
    timestep_respacing,
):
    betas = gd.get_named_beta_schedule(noise_schedule, steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if not timestep_respacing:
        timestep_respacing = [steps]

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON
            if not predict_xstart
            else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
    )
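A short sketch of how the factory above maps its flags onto the diffusion classes (not part of the committed code; it assumes the repository root is importable and the argument values are illustrative).

from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion

diffusion = create_gaussian_diffusion(
    steps=1000,                      # length of the full training chain
    learn_sigma=True,                # -> ModelVarType.LEARNED_RANGE
    sigma_small=False,               # only consulted when learn_sigma is False
    noise_schedule="squaredcos_cap_v2",
    use_kl=False,
    predict_xstart=False,            # -> ModelMeanType.EPSILON
    rescale_learned_sigmas=False,    # with use_kl=False this selects LossType.MSE
    timestep_respacing="25",         # sample with 25 evenly respaced steps
)
print(diffusion.num_timesteps)       # 25 (respaced chain used for sampling)
print(diffusion.original_num_steps)  # 1000 (underlying training chain)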
ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py (new file, 0 → 100644)

# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------

import enum
import math

import numpy as np
import torch as th


def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    warmup_time = int(num_diffusion_timesteps * warmup_frac)
    betas[:warmup_time] = np.linspace(
        beta_start, beta_end, warmup_time, dtype=np.float64
    )
    return betas


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.
    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start**0.5,
                beta_end**0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "warmup10":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
    elif beta_schedule == "warmup50":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.
    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "squaredcos_cap_v2":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.
    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = (
        enum.auto()
    )  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self == LossType.KL or self == LossType.RESCALED_KL


class GaussianDiffusion(th.nn.Module):
    """
    Utilities for training and sampling diffusion models.
    Originally ported from this codebase:
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
    ):
        super(GaussianDiffusion, self).__init__()
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0)
        assert alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
        log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod)
        sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        posterior_log_variance_clipped = np.log(
            np.append(posterior_variance[1], posterior_variance[1:])
        )
        posterior_mean_coef1 = (
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
        )

        self.register_buffer("betas", th.from_numpy(betas), persistent=False)
        self.register_buffer(
            "alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False
        )
        self.register_buffer(
            "alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False
        )
        self.register_buffer(
            "alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False
        )
        self.register_buffer(
            "sqrt_alphas_cumprod", th.from_numpy(sqrt_alphas_cumprod), persistent=False
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            th.from_numpy(sqrt_one_minus_alphas_cumprod),
            persistent=False,
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod",
            th.from_numpy(log_one_minus_alphas_cumprod),
            persistent=False,
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod",
            th.from_numpy(sqrt_recip_alphas_cumprod),
            persistent=False,
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            th.from_numpy(sqrt_recipm1_alphas_cumprod),
            persistent=False,
        )
        self.register_buffer(
            "posterior_variance", th.from_numpy(posterior_variance), persistent=False
        )
        self.register_buffer(
            "posterior_log_variance_clipped",
            th.from_numpy(posterior_log_variance_clipped),
            persistent=False,
        )
        self.register_buffer(
            "posterior_mean_coef1",
            th.from_numpy(posterior_mean_coef1),
            persistent=False,
        )
        self.register_buffer(
            "posterior_mean_coef2",
            th.from_numpy(posterior_mean_coef2),
            persistent=False,
        )
    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.
        In other words, sample from q(x_t | x_0).
        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        **ignore_kwargs,
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.
        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.model_var_type == ModelVarType.LEARNED:
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                min_log = _extract_into_tensor(
                    self.posterior_log_variance_clipped, t, x.shape
                )
                max_log = _extract_into_tensor(th.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    th.cat([self.posterior_variance[1][None], self.betas[1:]]),
                    th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )
        else:
            raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }
    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.
        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.
        See condition_mean() for details on cond_fn.
        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.
        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(
                cond_fn, out, x, t, model_kwargs=model_kwargs
            )
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Generate samples from the model.
        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.
        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for idx, i in enumerate(indices):
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )
                yield out
                img = out["sample"]
    def p_sample_loop_progressive_for_improved_sr(
        self,
        model,
        model_aux,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Modified version of p_sample_loop_progressive for sampling from the improved sr model
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for idx, i in enumerate(indices):
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model_aux if len(indices) - 1 == idx else model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )
                yield out
                img = out["sample"]

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.
        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_next)
            + th.sqrt(1 - alpha_bar_next) * eps
        )

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Generate samples from the model using DDIM.
        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.
        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.
    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = arr.to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)
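An illustrative round trip through GaussianDiffusion with a toy epsilon-prediction network (not part of the committed code), to show how q_sample and p_sample_loop fit together. The tiny convolution below is a stand-in for the repository's UNet, and the schedule is kept short so it runs quickly on CPU; the repository root is assumed to be on PYTHONPATH.

import torch as th

from ldm.modules.karlo.kakao.modules.diffusion.gaussian_diffusion import (
    GaussianDiffusion,
    LossType,
    ModelMeanType,
    ModelVarType,
    get_named_beta_schedule,
)


class ToyEpsModel(th.nn.Module):
    """Stand-in noise predictor: maps x_t to a tensor of the same shape."""

    def __init__(self):
        super().__init__()
        self.net = th.nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x, t, **kwargs):
        return self.net(x)  # interpreted as predicted epsilon


diffusion = GaussianDiffusion(
    betas=get_named_beta_schedule("linear", 50),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)

model = ToyEpsModel()
x0 = th.rand(2, 3, 8, 8) * 2 - 1                       # fake images in [-1, 1]
t = th.randint(0, diffusion.num_timesteps, (2,))
x_t = diffusion.q_sample(x0, t)                        # forward process q(x_t | x_0)
sample = diffusion.p_sample_loop(model, (2, 3, 8, 8))  # full 50-step reverse process
print(x_t.shape, sample.shape)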
ldm/modules/karlo/kakao/modules/diffusion/respace.py (new file, 0 → 100644)

# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------

import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there are 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        elif section_counts == "fast27":
            steps = space_timesteps(num_timesteps, "10,10,3,2,2")
            # Help reduce DDIM artifacts from noisiest timesteps.
            steps.remove(num_timesteps - 1)
            steps.add(num_timesteps - 3)
            return steps
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        timestep_map = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                timestep_map.append(i)
        kwargs["betas"] = th.tensor(new_betas).numpy()
        super().__init__(**kwargs)
        self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False)

    def p_mean_variance(self, model, *args, **kwargs):
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        def wrapped(x, ts, **kwargs):
            ts_cpu = ts.detach().to("cpu")
            return model(
                x,
                self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype),
                **kwargs,
            )

        return wrapped
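A few concrete calls showing what space_timesteps returns for the argument forms described in its docstring (not part of the committed code; repository on PYTHONPATH assumed).

from ldm.modules.karlo.kakao.modules.diffusion.respace import space_timesteps

print(sorted(space_timesteps(1000, "ddim25")))     # 25 steps at a constant stride of 40
print(sorted(space_timesteps(1000, "7")))          # the 7-step schedule used by the improved SR model
print(sorted(space_timesteps(300, [10, 15, 20])))  # the docstring example: 10 + 15 + 20 steps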
ldm/modules/karlo/kakao/modules/nn.py (new file, 0 → 100644)

# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------

import math

import torch as th
import torch.nn as nn
import torch.nn.functional as F


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
        self.swish = swish

    def forward(self, x):
        y = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            y = F.silu(y)
        elif self.swish:
            y = y * F.sigmoid(y * float(self.swish))
        return y


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def normalization(channels, swish=0.0):
    """
    Make a standard normalization layer, with an optional swish activation.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period)
        * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
        / half
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))
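A quick sketch of the two helpers most callers touch, timestep_embedding and normalization (not part of the committed code; repository on PYTHONPATH assumed).

import torch as th

from ldm.modules.karlo.kakao.modules.nn import normalization, timestep_embedding

t = th.tensor([0.0, 10.0, 999.5])      # one (possibly fractional) timestep per batch element
emb = timestep_embedding(t, dim=128)   # [3, 128]: 64 cosine features then 64 sine features
print(emb.shape)

norm = normalization(64, swish=1.0)    # GroupNorm with 32 groups followed by SiLU
print(norm(th.randn(2, 64, 8, 8)).shape)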
ldm/modules/karlo/kakao/modules/resample.py (new file, 0 → 100644)

# ------------------------------------------------------------------------------------
# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------

from abc import abstractmethod

import torch as th


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.
    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(th.nn.Module):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.
    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.
        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.
        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / th.sum(w)
        indices = p.multinomial(batch_size, replacement=True)
        weights = 1 / (len(p) * p[indices])
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        super(UniformSampler, self).__init__()
        self.diffusion = diffusion
        self.register_buffer(
            "_weights", th.ones([diffusion.num_timesteps]), persistent=False
        )

    def weights(self):
        return self._weights
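A small sketch of drawing training timesteps with the uniform schedule sampler (not part of the committed code). Only num_timesteps is read from the diffusion object here, so a simple namespace stands in for a real GaussianDiffusion instance; the repository root is assumed to be importable.

from types import SimpleNamespace

import torch as th

from ldm.modules.karlo.kakao.modules.resample import create_named_schedule_sampler

diffusion = SimpleNamespace(num_timesteps=1000)
sampler = create_named_schedule_sampler("uniform", diffusion)
t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))
print(t)        # 8 timestep indices drawn uniformly from [0, 1000)
print(weights)  # all ones: 1 / (T * (1/T)) for the uniform sampler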
ldm/modules/karlo/kakao/modules/unet.py (new file, 0 → 100644)

# ------------------------------------------------------------------------------------
# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------

import math
from abc import abstractmethod

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .nn import (
    avg_pool_nd,
    conv_nd,
    linear,
    normalization,
    timestep_embedding,
    zero_module,
)
from .xf import LayerNorm


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, encoder_out=None, mask=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, AttentionBlock):
                x = layer(x, encoder_out, mask=mask)
            else:
                x = layer(x)
        return x
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels, swish=1.0),
            nn.Identity(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(
                self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0
            ),
            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
class ResBlockNoTimeEmbedding(nn.Module):
    """
    A residual block without time embedding
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
        **kwargs,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint

        self.in_layers = nn.Sequential(
            normalization(channels, swish=1.0),
            nn.Identity(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels, swish=1.0),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb=None):
        """
        Apply the block to a Tensor, NOT conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert emb is None

        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        encoder_channels=None,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels, swish=0.0)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention(self.num_heads)

        if encoder_channels is not None:
            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x, encoder_out=None, mask=None):
        b, c, *spatial = x.shape
        qkv = self.qkv(self.norm(x).view(b, c, -1))
        if encoder_out is not None:
            encoder_out = self.encoder_kv(encoder_out)
            h = self.attention(qkv, encoder_out, mask=mask)
        else:
            h = self.attention(qkv)
        h = self.proj_out(h)
        return x + h.reshape(b, c, *spatial)
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, encoder_kv=None, mask=None):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        if encoder_kv is not None:
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            k = th.cat([ek, k], dim=-1)
            v = th.cat([ev, v], dim=-1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
        if mask is not None:
            mask = F.pad(mask, (0, length), value=0.0)
            mask = (
                mask.unsqueeze(1)
                .expand(-1, self.n_heads, -1)
                .reshape(bs * self.n_heads, 1, -1)
            )
            weight = weight + mask
        weight = th.softmax(weight, dim=-1)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param clip_dim: dimension of clip feature.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param encoder_channels: use to make the dimension of query and kv same in AttentionBlock.
    :param use_time_embedding: use time embedding for condition.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        clip_dim=None,
        use_checkpoint=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        use_middle_attention=True,
        resblock_updown=False,
        encoder_channels=None,
        use_time_embedding=True,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.clip_dim = clip_dim
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.use_middle_attention = use_middle_attention
        self.use_time_embedding = use_time_embedding

        if self.use_time_embedding:
            time_embed_dim = model_channels * 4
            self.time_embed = nn.Sequential(
                linear(model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )

            if self.clip_dim is not None:
                self.clip_emb = nn.Linear(clip_dim, time_embed_dim)
        else:
            time_embed_dim = None

        CustomResidualBlock = (
            ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding
        )
        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    CustomResidualBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            encoder_channels=encoder_channels,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        CustomResidualBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            CustomResidualBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            *(
                AttentionBlock(
                    ch,
                    use_checkpoint=use_checkpoint,
                    num_heads=num_heads,
                    num_head_channels=num_head_channels,
                    encoder_channels=encoder_channels,
                ),
            )
            if self.use_middle_attention
            else tuple(),  # add AttentionBlock or not
            CustomResidualBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    CustomResidualBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            encoder_channels=encoder_channels,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        CustomResidualBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization
(
ch
,
swish
=
1.0
),
nn
.
Identity
(),
zero_module
(
conv_nd
(
dims
,
input_ch
,
out_channels
,
3
,
padding
=
1
)),
)
def
forward
(
self
,
x
,
timesteps
,
y
=
None
):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert
(
y
is
not
None
)
==
(
self
.
clip_dim
is
not
None
),
"must specify y if and only if the model is clip-rep-conditional"
hs
=
[]
if
self
.
use_time_embedding
:
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
if
self
.
clip_dim
is
not
None
:
emb
=
emb
+
self
.
clip_emb
(
y
)
else
:
emb
=
None
h
=
x
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
)
hs
.
append
(
h
)
h
=
self
.
middle_block
(
h
,
emb
)
for
module
in
self
.
output_blocks
:
h
=
th
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
)
return
self
.
out
(
h
)
class
SuperResUNetModel
(
UNetModel
):
"""
A UNetModel that performs super-resolution.
Expects an extra kwarg `low_res` to condition on a low-resolution image.
Assumes that the shape of low-resolution and the input should be the same.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
if
"in_channels"
in
kwargs
:
kwargs
=
dict
(
kwargs
)
kwargs
[
"in_channels"
]
=
kwargs
[
"in_channels"
]
*
2
else
:
# Curse you, Python. Or really, just curse positional arguments :|.
args
=
list
(
args
)
args
[
1
]
=
args
[
1
]
*
2
super
().
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
x
,
timesteps
,
low_res
=
None
,
**
kwargs
):
_
,
_
,
new_height
,
new_width
=
x
.
shape
assert
new_height
==
low_res
.
shape
[
2
]
and
new_width
==
low_res
.
shape
[
3
]
x
=
th
.
cat
([
x
,
low_res
],
dim
=
1
)
return
super
().
forward
(
x
,
timesteps
,
**
kwargs
)
class
PLMImUNet
(
UNetModel
):
"""
A UNetModel that conditions on text with a pretrained text encoder in CLIP.
:param text_ctx: number of text tokens to expect.
:param xf_width: width of the transformer.
:param clip_emb_mult: #extra tokens by projecting clip text feature.
:param clip_emb_type: type of condition (here, we fix clip image feature).
:param clip_emb_drop: dropout rato of clip image feature for cfg.
"""
def
__init__
(
self
,
text_ctx
,
xf_width
,
*
args
,
clip_emb_mult
=
None
,
clip_emb_type
=
"image"
,
clip_emb_drop
=
0.0
,
**
kwargs
,
):
self
.
text_ctx
=
text_ctx
self
.
xf_width
=
xf_width
self
.
clip_emb_mult
=
clip_emb_mult
self
.
clip_emb_type
=
clip_emb_type
self
.
clip_emb_drop
=
clip_emb_drop
if
not
xf_width
:
super
().
__init__
(
*
args
,
**
kwargs
,
encoder_channels
=
None
)
else
:
super
().
__init__
(
*
args
,
**
kwargs
,
encoder_channels
=
xf_width
)
# Project text encoded feat seq from pre-trained text encoder in CLIP
self
.
text_seq_proj
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
clip_dim
,
xf_width
),
LayerNorm
(
xf_width
),
)
# Project CLIP text feat
self
.
text_feat_proj
=
nn
.
Linear
(
self
.
clip_dim
,
self
.
model_channels
*
4
)
assert
clip_emb_mult
is
not
None
assert
clip_emb_type
==
"image"
assert
self
.
clip_dim
is
not
None
,
"CLIP representation dim should be specified"
self
.
clip_tok_proj
=
nn
.
Linear
(
self
.
clip_dim
,
self
.
xf_width
*
self
.
clip_emb_mult
)
if
self
.
clip_emb_drop
>
0
:
self
.
cf_param
=
nn
.
Parameter
(
th
.
empty
(
self
.
clip_dim
,
dtype
=
th
.
float32
))
def
proc_clip_emb_drop
(
self
,
feat
):
if
self
.
clip_emb_drop
>
0
:
bsz
,
feat_dim
=
feat
.
shape
assert
(
feat_dim
==
self
.
clip_dim
),
f
"CLIP input dim:
{
feat_dim
}
, model CLIP dim:
{
self
.
clip_dim
}
"
drop_idx
=
th
.
rand
((
bsz
,),
device
=
feat
.
device
)
<
self
.
clip_emb_drop
feat
=
th
.
where
(
drop_idx
[...,
None
],
self
.
cf_param
[
None
].
type_as
(
feat
),
feat
)
return
feat
def
forward
(
self
,
x
,
timesteps
,
txt_feat
=
None
,
txt_feat_seq
=
None
,
mask
=
None
,
y
=
None
):
bsz
=
x
.
shape
[
0
]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
emb
+
self
.
clip_emb
(
y
)
xf_out
=
self
.
text_seq_proj
(
txt_feat_seq
)
xf_out
=
xf_out
.
permute
(
0
,
2
,
1
)
emb
=
emb
+
self
.
text_feat_proj
(
txt_feat
)
xf_out
=
th
.
cat
(
[
self
.
clip_tok_proj
(
y
).
reshape
(
bsz
,
-
1
,
self
.
clip_emb_mult
),
xf_out
,
],
dim
=
2
,
)
mask
=
F
.
pad
(
mask
,
(
self
.
clip_emb_mult
,
0
),
value
=
True
)
mask
=
th
.
where
(
mask
,
0.0
,
float
(
"-inf"
))
h
=
x
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
,
xf_out
,
mask
=
mask
)
hs
.
append
(
h
)
h
=
self
.
middle_block
(
h
,
emb
,
xf_out
,
mask
=
mask
)
for
module
in
self
.
output_blocks
:
h
=
th
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
,
xf_out
,
mask
=
mask
)
h
=
self
.
out
(
h
)
return
h
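SuperResUNetModel above only doubles in_channels and channel-concatenates the low-resolution conditioning image with the noisy input before calling the shared UNet. A minimal sketch of that concatenation step, with illustrative shapes that are assumptions rather than values from this repository:

import torch as th

x_t = th.randn(2, 3, 256, 256)       # noisy high-resolution input at timestep t
low_res = th.randn(2, 3, 256, 256)   # conditioning image, same spatial size as x_t
assert x_t.shape[2:] == low_res.shape[2:]
x_in = th.cat([x_t, low_res], dim=1)  # 6-channel tensor fed to the doubled-in_channels UNet
print(x_in.shape)                     # torch.Size([2, 6, 256, 256])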
ldm/modules/karlo/kakao/modules/xf.py 0 → 100644
# ------------------------------------------------------------------------------------
# Adapted from the repos below:
# (a) Guided-Diffusion (https://github.com/openai/guided-diffusion)
# (b) CLIP ViT (https://github.com/openai/CLIP/)
# ------------------------------------------------------------------------------------

import math

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .nn import timestep_embedding


def convert_module_to_f16(param):
    """
    Convert primitive modules to float16.
    """
    if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        param.weight.data = param.weight.data.half()
        if param.bias is not None:
            param.bias.data = param.bias.data.half()


class LayerNorm(nn.LayerNorm):
    """
    Implementation that supports fp16 inputs but fp32 gains/biases.
    """

    def forward(self, x: th.Tensor):
        return super().forward(x.float()).to(x.dtype)


class MultiheadAttention(nn.Module):
    def __init__(self, n_ctx, width, heads):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(heads, n_ctx)

    def forward(self, x, mask=None):
        x = self.c_qkv(x)
        x = self.attention(x, mask=mask)
        x = self.c_proj(x)
        return x


class MLP(nn.Module):
    def __init__(self, width):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * 4)
        self.c_proj = nn.Linear(width * 4, width)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class QKVMultiheadAttention(nn.Module):
    def __init__(self, n_heads: int, n_ctx: int):
        super().__init__()
        self.n_heads = n_heads
        self.n_ctx = n_ctx

    def forward(self, qkv, mask=None):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.n_heads // 3
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
        q, k, v = th.split(qkv, attn_ch, dim=-1)
        weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale)
        wdtype = weight.dtype
        if mask is not None:
            weight = weight + mask[:, None, ...]
        weight = th.softmax(weight, dim=-1).type(wdtype)
        return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        n_ctx: int,
        width: int,
        heads: int,
    ):
        super().__init__()

        self.attn = MultiheadAttention(
            n_ctx,
            width,
            heads,
        )
        self.ln_1 = LayerNorm(width)
        self.mlp = MLP(width)
        self.ln_2 = LayerNorm(width)

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln_1(x), mask=mask)
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    n_ctx,
                    width,
                    heads,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x, mask=None):
        for block in self.resblocks:
            x = block(x, mask=mask)
        return x


class PriorTransformer(nn.Module):
    """
    A Causal Transformer that conditions on CLIP text embedding, text.

    :param text_ctx: number of text tokens to expect.
    :param xf_width: width of the transformer.
    :param xf_layers: depth of the transformer.
    :param xf_heads: heads in the transformer.
    :param xf_final_ln: use a LayerNorm after the output layer.
    :param clip_dim: dimension of clip feature.
    """

    def __init__(
        self,
        text_ctx,
        xf_width,
        xf_layers,
        xf_heads,
        xf_final_ln,
        clip_dim,
    ):
        super().__init__()

        self.text_ctx = text_ctx
        self.xf_width = xf_width
        self.xf_layers = xf_layers
        self.xf_heads = xf_heads
        self.clip_dim = clip_dim
        self.ext_len = 4

        self.time_embed = nn.Sequential(
            nn.Linear(xf_width, xf_width),
            nn.SiLU(),
            nn.Linear(xf_width, xf_width),
        )
        self.text_enc_proj = nn.Linear(clip_dim, xf_width)
        self.text_emb_proj = nn.Linear(clip_dim, xf_width)
        self.clip_img_proj = nn.Linear(clip_dim, xf_width)
        self.out_proj = nn.Linear(xf_width, clip_dim)
        self.transformer = Transformer(
            text_ctx + self.ext_len,
            xf_width,
            xf_layers,
            xf_heads,
        )
        if xf_final_ln:
            self.final_ln = LayerNorm(xf_width)
        else:
            self.final_ln = None

        self.positional_embedding = nn.Parameter(
            th.empty(1, text_ctx + self.ext_len, xf_width)
        )
        self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width)))

        nn.init.normal_(self.prd_emb, std=0.01)
        nn.init.normal_(self.positional_embedding, std=0.01)

    def forward(
        self,
        x,
        timesteps,
        text_emb=None,
        text_enc=None,
        mask=None,
        causal_mask=None,
    ):
        bsz = x.shape[0]
        mask = F.pad(mask, (0, self.ext_len), value=True)

        t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width))
        text_enc = self.text_enc_proj(text_enc)
        text_emb = self.text_emb_proj(text_emb)
        x = self.clip_img_proj(x)

        input_seq = [
            text_enc,
            text_emb[:, None, :],
            t_emb[:, None, :],
            x[:, None, :],
            self.prd_emb.to(x.dtype).expand(bsz, -1, -1),
        ]
        input = th.cat(input_seq, dim=1)
        input = input + self.positional_embedding.to(input.dtype)

        mask = th.where(mask, 0.0, float("-inf"))
        mask = (mask[:, None, :] + causal_mask).to(input.dtype)

        out = self.transformer(input, mask=mask)
        if self.final_ln is not None:
            out = self.final_ln(out)

        out = self.out_proj(out[:, -1])
        return out
ldm/modules/karlo/kakao/sampler.py 0 → 100644
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15
# ------------------------------------------------------------------------------------

from typing import Iterator

import torch
import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode

from .template import BaseSampler, CKPT_PATH


class T2ISampler(BaseSampler):
    """
    A sampler for text-to-image generation.

    :param root_dir: directory for model checkpoints.
    :param sampling_type: ["default", "fast"]
    """

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "default",
    ):
        super().__init__(root_dir, sampling_type)

    @classmethod
    def from_pretrained(
        cls,
        root_dir: str,
        clip_model_path: str,
        clip_stat_path: str,
        sampling_type: str = "default",
    ):
        model = cls(
            root_dir=root_dir,
            sampling_type=sampling_type,
        )
        model.load_clip(clip_model_path)
        model.load_prior(
            f"{CKPT_PATH['prior']}",
            clip_stat_path=clip_stat_path,
            prior_config="configs/karlo/prior_1B_vit_l.yaml",
        )
        model.load_decoder(
            f"{CKPT_PATH['decoder']}",
            decoder_config="configs/karlo/decoder_900M_vit_l.yaml",
        )
        model.load_sr_64_256(
            CKPT_PATH["sr_256"],
            sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml",
        )
        return model

    def preprocess(
        self,
        prompt: str,
        bsz: int,
    ):
        """Setup prompts & cfg scales"""
        prompts_batch = [prompt for _ in range(bsz)]

        prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")

        decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")

        """ Get CLIP text feature """
        clip_model = self._clip
        tokenizer = self._tokenizer
        max_txt_length = self._prior.model.text_ctx

        tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
        cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
        if not (cf_token.shape == tok.shape):
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)

        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)

        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
        txt_feat, txt_feat_seq = clip_model.encode_text(tok)

        return (
            prompts_batch,
            prior_cf_scales_batch,
            decoder_cf_scales_batch,
            txt_feat,
            txt_feat_seq,
            tok,
            mask,
        )

    def __call__(
        self,
        prompt: str,
        bsz: int,
        progressive_mode=None,
    ) -> Iterator[torch.Tensor]:
        assert progressive_mode in ("loop", "stage", "final")
        with torch.no_grad(), torch.cuda.amp.autocast():
            (
                prompts_batch,
                prior_cf_scales_batch,
                decoder_cf_scales_batch,
                txt_feat,
                txt_feat_seq,
                tok,
                mask,
            ) = self.preprocess(
                prompt,
                bsz,
            )

            """ Transform CLIP text feature into image feature """
            img_feat = self._prior(
                txt_feat,
                txt_feat_seq,
                mask,
                prior_cf_scales_batch,
                timestep_respacing=self._prior_sm,
            )

            """ Generate 64x64px images """
            images_64_outputs = self._decoder(
                txt_feat,
                txt_feat_seq,
                tok,
                mask,
                img_feat,
                cf_guidance_scales=decoder_cf_scales_batch,
                timestep_respacing=self._decoder_sm,
            )

            images_64 = None
            for k, out in enumerate(images_64_outputs):
                images_64 = out
                if progressive_mode == "loop":
                    yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
            if progressive_mode == "stage":
                yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)

            images_64 = torch.clamp(images_64, -1, 1)

            """ Upsample 64x64 to 256x256 """
            images_256 = TVF.resize(
                images_64,
                [256, 256],
                interpolation=InterpolationMode.BICUBIC,
                antialias=True,
            )
            images_256_outputs = self._sr_64_256(
                images_256, timestep_respacing=self._sr_sm
            )

            for k, out in enumerate(images_256_outputs):
                images_256 = out
                if progressive_mode == "loop":
                    yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
            if progressive_mode == "stage":
                yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)

            yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)


class PriorSampler(BaseSampler):
    """
    A sampler for text-to-image generation, but only the prior.

    :param root_dir: directory for model checkpoints.
    :param sampling_type: ["default", "fast"]
    """

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "default",
    ):
        super().__init__(root_dir, sampling_type)

    @classmethod
    def from_pretrained(
        cls,
        root_dir: str,
        clip_model_path: str,
        clip_stat_path: str,
        sampling_type: str = "default",
    ):
        model = cls(
            root_dir=root_dir,
            sampling_type=sampling_type,
        )
        model.load_clip(clip_model_path)
        model.load_prior(
            f"{CKPT_PATH['prior']}",
            clip_stat_path=clip_stat_path,
            prior_config="configs/karlo/prior_1B_vit_l.yaml",
        )
        return model

    def preprocess(
        self,
        prompt: str,
        bsz: int,
    ):
        """Setup prompts & cfg scales"""
        prompts_batch = [prompt for _ in range(bsz)]

        prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")

        decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")

        """ Get CLIP text feature """
        clip_model = self._clip
        tokenizer = self._tokenizer
        max_txt_length = self._prior.model.text_ctx

        tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
        cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
        if not (cf_token.shape == tok.shape):
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)

        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)

        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
        txt_feat, txt_feat_seq = clip_model.encode_text(tok)

        return (
            prompts_batch,
            prior_cf_scales_batch,
            decoder_cf_scales_batch,
            txt_feat,
            txt_feat_seq,
            tok,
            mask,
        )

    def __call__(
        self,
        prompt: str,
        bsz: int,
        progressive_mode=None,
    ) -> Iterator[torch.Tensor]:
        assert progressive_mode in ("loop", "stage", "final")
        with torch.no_grad(), torch.cuda.amp.autocast():
            (
                prompts_batch,
                prior_cf_scales_batch,
                decoder_cf_scales_batch,
                txt_feat,
                txt_feat_seq,
                tok,
                mask,
            ) = self.preprocess(
                prompt,
                bsz,
            )

            """ Transform CLIP text feature into image feature """
            img_feat = self._prior(
                txt_feat,
                txt_feat_seq,
                mask,
                prior_cf_scales_batch,
                timestep_respacing=self._prior_sm,
            )
            yield img_feat
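Both samplers are generators, so a caller has to iterate them. A minimal sketch of how T2ISampler might be driven, assuming the Karlo checkpoints and configs referenced in from_pretrained are already present under root_dir and that a CUDA device is available (the paths below are placeholders, not files shipped with this commit):

# Sketch only; checkpoint paths are hypothetical placeholders.
sampler = T2ISampler.from_pretrained(
    root_dir="checkpoints/karlo",        # placeholder
    clip_model_path="ViT-L-14.pt",       # placeholder
    clip_stat_path="ViT-L-14_stats.th",  # placeholder
    sampling_type="fast",
)
for images in sampler("a watercolor painting of a fox", bsz=2, progressive_mode="final"):
    pass  # with "final", a single [2, 3, 256, 256] tensor in [0, 1] is yielded at the end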
ldm/modules/karlo/kakao/template.py 0 → 100644
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------

import os
import logging
import torch

from omegaconf import OmegaConf

from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel


SAMPLING_CONF = {
    "default": {
        "prior_sm": "25",
        "prior_n_samples": 1,
        "prior_cf_scale": 4.0,
        "decoder_sm": "50",
        "decoder_cf_scale": 8.0,
        "sr_sm": "7",
    },
    "fast": {
        "prior_sm": "25",
        "prior_n_samples": 1,
        "prior_cf_scale": 4.0,
        "decoder_sm": "25",
        "decoder_cf_scale": 8.0,
        "sr_sm": "7",
    },
}

CKPT_PATH = {
    "prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
    "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
    "sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
}


class BaseSampler:
    _PRIOR_CLASS = PriorDiffusionModel
    _DECODER_CLASS = Text2ImProgressiveModel
    _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "fast",
    ):
        self._root_dir = root_dir

        sampling_type = SAMPLING_CONF[sampling_type]
        self._prior_sm = sampling_type["prior_sm"]
        self._prior_n_samples = sampling_type["prior_n_samples"]
        self._prior_cf_scale = sampling_type["prior_cf_scale"]

        assert self._prior_n_samples == 1

        self._decoder_sm = sampling_type["decoder_sm"]
        self._decoder_cf_scale = sampling_type["decoder_cf_scale"]

        self._sr_sm = sampling_type["sr_sm"]

    def __repr__(self):
        line = ""
        line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
        line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
        line += f"SR(64->256), sampling method: {self._sr_sm}"
        return line

    def load_clip(self, clip_path: str):
        clip = CustomizedCLIP.load_from_checkpoint(
            os.path.join(self._root_dir, clip_path)
        )
        clip = torch.jit.script(clip)
        clip.cuda()
        clip.eval()

        self._clip = clip
        self._tokenizer = CustomizedTokenizer()

    def load_prior(
        self,
        ckpt_path: str,
        clip_stat_path: str,
        prior_config: str = "configs/prior_1B_vit_l.yaml",
    ):
        logging.info(f"Loading prior: {ckpt_path}")

        config = OmegaConf.load(prior_config)
        clip_mean, clip_std = torch.load(
            os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
        )

        prior = self._PRIOR_CLASS.load_from_checkpoint(
            config,
            self._tokenizer,
            clip_mean,
            clip_std,
            os.path.join(self._root_dir, ckpt_path),
            strict=True,
        )
        prior.cuda()
        prior.eval()
        logging.info("done.")

        self._prior = prior

    def load_decoder(
        self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"
    ):
        logging.info(f"Loading decoder: {ckpt_path}")

        config = OmegaConf.load(decoder_config)
        decoder = self._DECODER_CLASS.load_from_checkpoint(
            config,
            self._tokenizer,
            os.path.join(self._root_dir, ckpt_path),
            strict=True,
        )
        decoder.cuda()
        decoder.eval()
        logging.info("done.")

        self._decoder = decoder

    def load_sr_64_256(
        self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"
    ):
        logging.info(f"Loading SR(64->256): {ckpt_path}")

        config = OmegaConf.load(sr_config)
        sr = self._SR256_CLASS.load_from_checkpoint(
            config, os.path.join(self._root_dir, ckpt_path), strict=True
        )
        sr.cuda()
        sr.eval()
        logging.info("done.")

        self._sr_64_256 = sr
ldm/modules/midas/__init__.py 0 → 100644
ldm/modules/midas/api.py 0 → 100644
# based on https://github.com/isl-org/MiDaS

import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose

from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
from ldm.modules.midas.midas.midas_net import MidasNet
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet


ISL_PATHS = {
    "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
    "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
    "midas_v21": "",
    "midas_v21_small": "",
}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def load_midas_transform(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load transform only
    if model_type == "dpt_large":  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    elif model_type == "midas_v21_small":
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    else:
        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return transform


def load_model(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load network
    model_path = ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3",
                               exportable=True, non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        assert False

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return model.eval(), transform


class MiDaSInference(nn.Module):
    MODEL_TYPES_TORCH_HUB = [
        "DPT_Large",
        "DPT_Hybrid",
        "MiDaS_small",
    ]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type):
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        model, _ = load_model(model_type)
        self.model = model
        self.model.train = disabled_train

    def forward(self, x):
        # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
        # NOTE: we expect that the correct transform has been called during dataloading.
        with torch.no_grad():
            prediction = self.model(x)
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=x.shape[2:],
                mode="bicubic",
                align_corners=False,
            )
        assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
        return prediction
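MiDaSInference wraps load_model but discards the returned transform, so the input is expected to be normalized already (see the NOTE in forward). A minimal sketch of how it could be paired with load_midas_transform, assuming the dpt_hybrid checkpoint listed in ISL_PATHS is present and a CUDA device is available (the random image is a stand-in):

# Sketch only; requires the dpt_hybrid checkpoint from ISL_PATHS.
import numpy as np
import torch

midas = MiDaSInference("dpt_hybrid").cuda()
transform = load_midas_transform("dpt_hybrid")

img = np.random.rand(384, 512, 3)            # stand-in for an RGB image in [0, 1]
x = transform({"image": img})["image"]        # resized + normalized CHW float32 array
x = torch.from_numpy(x).unsqueeze(0).cuda()
depth = midas(x)                              # [1, 1, H, W] relative depth, upsampled back to the input size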
ldm/modules/midas/midas/__init__.py 0 → 100644
ldm/modules/midas/midas/base_model.py 0 → 100644
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))

        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)
ldm/modules/midas/midas/blocks.py 0 → 100644
import torch
import torch.nn as nn

from .vit import (
    _make_pretrained_vitb_rn50_384,
    _make_pretrained_vitl16_384,
    _make_pretrained_vitb16_384,
    forward_vit,
)


def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False,
                  exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
    if backbone == "vitl16_384":
        pretrained = _make_pretrained_vitl16_384(use_pretrained, hooks=hooks, use_readout=use_readout)
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # ViT-L/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb_rn50_384":
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )  # ViT-H/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb16_384":
        pretrained = _make_pretrained_vitb16_384(use_pretrained, hooks=hooks, use_readout=use_readout)
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 - 84.6% Top1 (backbone)
    elif backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)
        # efficientnet_lite3
    elif backbone == "efficientnet_lite3":
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)
        # efficientnet_lite3
    else:
        print(f"Backbone '{backbone}' not implemented")
        assert False

    return pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand == True:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)

    return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    efficientnet = torch.hub.load(
        "rwightman/gen-efficientnet-pytorch",
        "tf_efficientnet_lite3",
        pretrained=use_pretrained,
        exportable=exportable
    )
    return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
    pretrained = nn.Module()

    pretrained.layer1 = nn.Sequential(
        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
    )
    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

    return pretrained


def _make_resnet_backbone(resnet):
    pretrained = nn.Module()
    pretrained.layer1 = nn.Sequential(
        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
    )

    pretrained.layer2 = resnet.layer2
    pretrained.layer3 = resnet.layer3
    pretrained.layer4 = resnet.layer4

    return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)


class Interpolate(nn.Module):
    """Interpolation module."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        x = self.interp(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
        )
        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module."""

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block."""

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn
        self.groups = 1
        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn == True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn == True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
        # return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block."""

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1
        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
ldm/modules/midas/midas/dpt_depth.py 0 → 100644
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_model import BaseModel
from .blocks import (
    FeatureFusionBlock,
    FeatureFusionBlock_custom,
    Interpolate,
    _make_encoder,
    forward_vit,
)


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class DPT(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):
        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to true of you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        if self.channels_last == True:
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out


class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256

        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        return super().forward(x).squeeze(dim=1)
ldm/modules/midas/midas/midas_net.py 0 → 100644
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import
torch
import
torch.nn
as
nn
from
.base_model
import
BaseModel
from
.blocks
import
FeatureFusionBlock
,
Interpolate
,
_make_encoder
class
MidasNet
(
BaseModel
):
"""Network for monocular depth estimation.
"""
def
__init__
(
self
,
path
=
None
,
features
=
256
,
non_negative
=
True
):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print
(
"Loading weights: "
,
path
)
super
(
MidasNet
,
self
).
__init__
()
use_pretrained
=
False
if
path
is
None
else
True
self
.
pretrained
,
self
.
scratch
=
_make_encoder
(
backbone
=
"resnext101_wsl"
,
features
=
features
,
use_pretrained
=
use_pretrained
)
self
.
scratch
.
refinenet4
=
FeatureFusionBlock
(
features
)
self
.
scratch
.
refinenet3
=
FeatureFusionBlock
(
features
)
self
.
scratch
.
refinenet2
=
FeatureFusionBlock
(
features
)
self
.
scratch
.
refinenet1
=
FeatureFusionBlock
(
features
)
self
.
scratch
.
output_conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
features
,
128
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
Interpolate
(
scale_factor
=
2
,
mode
=
"bilinear"
),
nn
.
Conv2d
(
128
,
32
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
ReLU
(
True
),
nn
.
Conv2d
(
32
,
1
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
),
nn
.
ReLU
(
True
)
if
non_negative
else
nn
.
Identity
(),
)
if
path
:
self
.
load
(
path
)
def
forward
(
self
,
x
):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
layer_1
=
self
.
pretrained
.
layer1
(
x
)
layer_2
=
self
.
pretrained
.
layer2
(
layer_1
)
layer_3
=
self
.
pretrained
.
layer3
(
layer_2
)
layer_4
=
self
.
pretrained
.
layer4
(
layer_3
)
layer_1_rn
=
self
.
scratch
.
layer1_rn
(
layer_1
)
layer_2_rn
=
self
.
scratch
.
layer2_rn
(
layer_2
)
layer_3_rn
=
self
.
scratch
.
layer3_rn
(
layer_3
)
layer_4_rn
=
self
.
scratch
.
layer4_rn
(
layer_4
)
path_4
=
self
.
scratch
.
refinenet4
(
layer_4_rn
)
path_3
=
self
.
scratch
.
refinenet3
(
path_4
,
layer_3_rn
)
path_2
=
self
.
scratch
.
refinenet2
(
path_3
,
layer_2_rn
)
path_1
=
self
.
scratch
.
refinenet1
(
path_2
,
layer_1_rn
)
out
=
self
.
scratch
.
output_conv
(
path_1
)
return
torch
.
squeeze
(
out
,
dim
=
1
)
ldm/modules/midas/midas/midas_net_custom.py 0 → 100644
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import
torch
import
torch.nn
as
nn
from
.base_model
import
BaseModel
from
.blocks
import
FeatureFusionBlock
,
FeatureFusionBlock_custom
,
Interpolate
,
_make_encoder
class
MidasNet_small
(
BaseModel
):
"""Network for monocular depth estimation.
"""
def
__init__
(
self
,
path
=
None
,
features
=
64
,
backbone
=
"efficientnet_lite3"
,
non_negative
=
True
,
exportable
=
True
,
channels_last
=
False
,
align_corners
=
True
,
blocks
=
{
'expand'
:
True
}):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print
(
"Loading weights: "
,
path
)
super
(
MidasNet_small
,
self
).
__init__
()
use_pretrained
=
False
if
path
else
True
self
.
channels_last
=
channels_last
self
.
blocks
=
blocks
self
.
backbone
=
backbone
self
.
groups
=
1
features1
=
features
features2
=
features
features3
=
features
features4
=
features
self
.
expand
=
False
if
"expand"
in
self
.
blocks
and
self
.
blocks
[
'expand'
]
==
True
:
self
.
expand
=
True
features1
=
features
features2
=
features
*
2
features3
=
features
*
4
features4
=
features
*
8
self
.
pretrained
,
self
.
scratch
=
_make_encoder
(
self
.
backbone
,
features
,
use_pretrained
,
groups
=
self
.
groups
,
expand
=
self
.
expand
,
exportable
=
exportable
)
self
.
scratch
.
activation
=
nn
.
ReLU
(
False
)
self
.
scratch
.
refinenet4
=
FeatureFusionBlock_custom
(
features4
,
self
.
scratch
.
activation
,
deconv
=
False
,
bn
=
False
,
expand
=
self
.
expand
,
align_corners
=
align_corners
)
self
.
scratch
.
refinenet3
=
FeatureFusionBlock_custom
(
features3
,
self
.
scratch
.
activation
,
deconv
=
False
,
bn
=
False
,
expand
=
self
.
expand
,
align_corners
=
align_corners
)
self
.
scratch
.
refinenet2
=
FeatureFusionBlock_custom
(
features2
,
self
.
scratch
.
activation
,
deconv
=
False
,
bn
=
False
,
expand
=
self
.
expand
,
align_corners
=
align_corners
)
self
.
scratch
.
refinenet1
=
FeatureFusionBlock_custom
(
features1
,
self
.
scratch
.
activation
,
deconv
=
False
,
bn
=
False
,
align_corners
=
align_corners
)
self
.
scratch
.
output_conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
features
,
features
//
2
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
groups
=
self
.
groups
),
Interpolate
(
scale_factor
=
2
,
mode
=
"bilinear"
),
nn
.
Conv2d
(
features
//
2
,
32
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
self
.
scratch
.
activation
,
nn
.
Conv2d
(
32
,
1
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
),
nn
.
ReLU
(
True
)
if
non_negative
else
nn
.
Identity
(),
nn
.
Identity
(),
)
if
path
:
self
.
load
(
path
)
def
forward
(
self
,
x
):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
if
self
.
channels_last
==
True
:
print
(
"self.channels_last = "
,
self
.
channels_last
)
x
.
contiguous
(
memory_format
=
torch
.
channels_last
)
layer_1
=
self
.
pretrained
.
layer1
(
x
)
layer_2
=
self
.
pretrained
.
layer2
(
layer_1
)
layer_3
=
self
.
pretrained
.
layer3
(
layer_2
)
layer_4
=
self
.
pretrained
.
layer4
(
layer_3
)
layer_1_rn
=
self
.
scratch
.
layer1_rn
(
layer_1
)
layer_2_rn
=
self
.
scratch
.
layer2_rn
(
layer_2
)
layer_3_rn
=
self
.
scratch
.
layer3_rn
(
layer_3
)
layer_4_rn
=
self
.
scratch
.
layer4_rn
(
layer_4
)
path_4
=
self
.
scratch
.
refinenet4
(
layer_4_rn
)
path_3
=
self
.
scratch
.
refinenet3
(
path_4
,
layer_3_rn
)
path_2
=
self
.
scratch
.
refinenet2
(
path_3
,
layer_2_rn
)
path_1
=
self
.
scratch
.
refinenet1
(
path_2
,
layer_1_rn
)
out
=
self
.
scratch
.
output_conv
(
path_1
)
return
torch
.
squeeze
(
out
,
dim
=
1
)
def
fuse_model
(
m
):
prev_previous_type
=
nn
.
Identity
()
prev_previous_name
=
''
previous_type
=
nn
.
Identity
()
previous_name
=
''
for
name
,
module
in
m
.
named_modules
():
if
prev_previous_type
==
nn
.
Conv2d
and
previous_type
==
nn
.
BatchNorm2d
and
type
(
module
)
==
nn
.
ReLU
:
# print("FUSED ", prev_previous_name, previous_name, name)
torch
.
quantization
.
fuse_modules
(
m
,
[
prev_previous_name
,
previous_name
,
name
],
inplace
=
True
)
elif
prev_previous_type
==
nn
.
Conv2d
and
previous_type
==
nn
.
BatchNorm2d
:
# print("FUSED ", prev_previous_name, previous_name)
torch
.
quantization
.
fuse_modules
(
m
,
[
prev_previous_name
,
previous_name
],
inplace
=
True
)
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
# print("FUSED ", previous_name, name)
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
prev_previous_type
=
previous_type
prev_previous_name
=
previous_name
previous_type
=
type
(
module
)
previous_name
=
name
\ No newline at end of file
ldm/modules/midas/midas/transforms.py 0 → 100644
import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Rezise the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(
        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as least as possbile
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = sample["mask"].astype(bool)

        return sample


class NormalizeImage(object):
    """Normlize image by given mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input."""

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "disparity" in sample:
            disparity = sample["disparity"].astype(np.float32)
            sample["disparity"] = np.ascontiguousarray(disparity)

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        return sample
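Resize.get_size only decides the output resolution; the actual resizing happens in __call__. A small runnable check of the sizing rules, with input dimensions chosen purely for illustration:

r = Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="upper_bound")
print(r.get_size(640, 480))  # -> (384, 288): the width scale 0.6 is the binding constraint,
                             #    and 0.6 * 480 = 288 is already a multiple of 32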