OpenDAS / ColossalAI · Commits

Commit 33f3023e, authored Jan 06, 2023 by 1SAA
[hotfix] fix implement error in diffusers
Parent: 48d33b1b

Showing 2 changed files with 41 additions and 21 deletions (+41, -21):
  colossalai/tensor/param_op_hook.py (+18, -0)
  examples/images/diffusion/ldm/modules/diffusionmodules/util.py (+23, -21)
colossalai/tensor/param_op_hook.py

@@ -141,7 +141,25 @@ def _is_grad_tensor(obj) -> bool:
     return False


+def _has_grad_tensor(obj) -> bool:
+    if isinstance(obj, tuple) or isinstance(obj, list):
+        for x in obj:
+            if _has_grad_tensor(x):
+                return True
+        return False
+    elif isinstance(obj, dict):
+        for x in obj.values():
+            if _has_grad_tensor(x):
+                return True
+        return False
+    else:
+        return _is_grad_tensor(obj)
+
+
 def _get_grad_args(*args):
+    # if there is no grad tensors, do nothing
+    if not _has_grad_tensor(args):
+        return args, None
     # returns the identical args if there is a grad tensor
     for obj in args:
         if _is_grad_tensor(obj):
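For reference, a minimal sketch of how the new helper behaves. The tensors and values below are illustrative only; _has_grad_tensor and _get_grad_args are internal helpers of colossalai.tensor.param_op_hook, and the sketch assumes _is_grad_tensor is true exactly for tensors that require grad.

import torch

from colossalai.tensor.param_op_hook import _get_grad_args, _has_grad_tensor

t = torch.randn(2, 2, requires_grad=True)

# Containers are searched recursively: tuples/lists by element, dicts by value.
print(_has_grad_tensor((1, [t, "x"])))        # True - a grad-requiring tensor is nested inside
print(_has_grad_tensor({"a": 1, "b": None}))  # False - nothing requires grad

# With the hotfix, _get_grad_args returns its args untouched (with None as the second value)
# when no grad tensor is present, instead of falling through to the grad-tensor handling path.
passthrough, extra = _get_grad_args(3, "label", {"cfg": True})
assert passthrough == (3, "label", {"cfg": True}) and extra is None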
examples/images/diffusion/ldm/modules/diffusionmodules/util.py

@@ -7,27 +7,22 @@
 #
 # thanks!

-import os
 import math
+import os
+
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 from einops import repeat
-
 from ldm.util import instantiate_from_config


 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
     if schedule == "linear":
-        betas = (
-                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
-        )
+        betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2)

     elif schedule == "cosine":
-        timesteps = (
-                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
-        )
+        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s)
         alphas = timesteps / (1 + cosine_s) * np.pi / 2
         alphas = torch.cos(alphas).pow(2)
         alphas = alphas / alphas[0]
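As a side note, the cosine branch can be traced numerically in isolation. This standalone sketch only repeats the lines shown in the hunk with illustrative values; the rest of the branch lies outside the hunk.

import numpy as np
import torch

n_timestep, cosine_s = 1000, 8e-3

timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s)
alphas = timesteps / (1 + cosine_s) * np.pi / 2   # angles spanning roughly [0, pi/2]
alphas = torch.cos(alphas).pow(2)                 # squared-cosine decay
alphas = alphas / alphas[0]                       # normalize so alphas[0] == 1.0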
@@ -37,7 +32,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
     elif schedule == "sqrt_linear":
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
     elif schedule == "sqrt":
-        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
     else:
         raise ValueError(f"schedule '{schedule}' unknown.")
     return betas.numpy()
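A quick usage sketch of make_beta_schedule as it stands after this change, assuming the example package ldm.modules.diffusionmodules.util is importable; the numbers are just the function's own defaults.

import numpy as np

from ldm.modules.diffusionmodules.util import make_beta_schedule

# "linear": betas interpolate (in sqrt space) between linear_start and linear_end (defaults 1e-4 and 2e-2).
betas = make_beta_schedule("linear", n_timestep=1000)
assert betas.shape == (1000,)
assert np.isclose(betas[0], 1e-4) and np.isclose(betas[-1], 2e-2)

# "sqrt" is the element-wise square root of the "sqrt_linear" schedule.
assert np.allclose(make_beta_schedule("sqrt", n_timestep=1000),
                   make_beta_schedule("sqrt_linear", n_timestep=1000)**0.5)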
@@ -48,7 +43,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
         c = num_ddpm_timesteps // num_ddim_timesteps
         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
     elif ddim_discr_method == 'quad':
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)
     else:
         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
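For orientation, here is a standalone numpy sketch of the two discretization branches touched above; the step counts are illustrative.

import numpy as np

num_ddpm_timesteps, num_ddim_timesteps = 1000, 50

# 'uniform': keep every c-th DDPM step.
c = num_ddpm_timesteps // num_ddim_timesteps                   # 20
uniform = np.asarray(list(range(0, num_ddpm_timesteps, c)))    # 0, 20, 40, ..., 980

# 'quad': quadratic spacing, denser near t = 0 and sparser near t = T.
quad = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)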
@@ -110,21 +105,26 @@ def checkpoint(func, inputs, params, flag):
     :param flag: if False, disable gradient checkpointing.
     """
     if flag:
-        args = tuple(inputs) + tuple(params)
-        return CheckpointFunction.apply(func, len(inputs), *args)
+        from torch.utils.checkpoint import checkpoint as torch_checkpoint
+        return torch_checkpoint(func, *inputs)
+        # args = tuple(inputs) + tuple(params)
+        # return CheckpointFunction.apply(func, len(inputs), *args)
     else:
         return func(*inputs)


 class CheckpointFunction(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, run_function, length, *args):
         ctx.run_function = run_function
         ctx.input_tensors = list(args[:length])
         ctx.input_params = list(args[length:])
-        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
-                                   "dtype": torch.get_autocast_gpu_dtype(),
-                                   "cache_enabled": torch.is_autocast_cache_enabled()}
+        ctx.gpu_autocast_kwargs = {
+            "enabled": torch.is_autocast_enabled(),
+            "dtype": torch.get_autocast_gpu_dtype(),
+            "cache_enabled": torch.is_autocast_cache_enabled()
+        }
         with torch.no_grad():
             output_tensors = ctx.run_function(*ctx.input_tensors)
         return output_tensors
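In short, the hotfix makes checkpoint(func, inputs, params, flag) delegate to PyTorch's built-in gradient checkpointing instead of the local CheckpointFunction. A minimal sketch of that built-in path; the module and tensor shapes below are made up for illustration.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint

layer = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
x = torch.randn(4, 16, requires_grad=True)

# Activations inside `layer` are recomputed during backward instead of being stored,
# trading compute for memory - the behavior the patched checkpoint(...) now relies on when flag is True.
y = torch_checkpoint(layer, x)
y.sum().backward()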
@@ -162,9 +162,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     if not repeat_only:
         half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        ).to(device=timesteps.device)
+        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) /
+                          half).to(device=timesteps.device)
         args = timesteps[:, None].float() * freqs[None]
         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         if dim % 2:
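For readers unfamiliar with it, the lines above compute the standard sinusoidal timestep embedding. This standalone sketch mirrors them with illustrative values.

import math

import torch

timesteps = torch.tensor([0, 10, 500])
dim, max_period = 8, 10000

half = dim // 2
# Geometrically spaced frequencies, from 1 down toward 1/max_period.
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # shape (3, 8)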
@@ -211,14 +210,17 @@ def normalization(channels):
 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
 class SiLU(nn.Module):
+
     def forward(self, x):
         return x * torch.sigmoid(x)


 class GroupNorm32(nn.GroupNorm):
+
     def forward(self, x):
         return super().forward(x.float()).type(x.dtype)


 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
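As context for the two small classes above: GroupNorm32 runs the normalization in float32 even when the surrounding model uses half precision, then casts back. A minimal sketch with illustrative shapes, assuming the class is imported from the example module.

import torch

from ldm.modules.diffusionmodules.util import GroupNorm32

gn = GroupNorm32(num_groups=32, num_channels=64)
x = torch.randn(2, 64, 8, 8, dtype=torch.float16)

y = gn(x)                         # statistics computed in float32 internally...
assert y.dtype == torch.float16   # ...but the result is cast back to the input dtype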