OpenDAS / ColossalAI

Commit 33f3023e, authored Jan 06, 2023 by 1SAA
Parent: 48d33b1b

    [hotfix] fix implement error in diffusers

2 changed files, with 41 additions and 21 deletions:

    colossalai/tensor/param_op_hook.py                                 +18   -0
    examples/images/diffusion/ldm/modules/diffusionmodules/util.py    +23  -21
colossalai/tensor/param_op_hook.py

@@ -141,7 +141,25 @@ def _is_grad_tensor(obj) -> bool:
     return False
 
 
+def _has_grad_tensor(obj) -> bool:
+    if isinstance(obj, tuple) or isinstance(obj, list):
+        for x in obj:
+            if _has_grad_tensor(x):
+                return True
+        return False
+    elif isinstance(obj, dict):
+        for x in obj.values():
+            if _has_grad_tensor(x):
+                return True
+        return False
+    else:
+        return _is_grad_tensor(obj)
+
+
 def _get_grad_args(*args):
+    # if there is no grad tensors, do nothing
+    if not _has_grad_tensor(args):
+        return args, None
     # returns the identical args if there is a grad tensor
     for obj in args:
         if _is_grad_tensor(obj):
...
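The new _has_grad_tensor helper lets _get_grad_args return early when none of the (possibly nested) arguments carries gradients. A minimal standalone sketch of the same idea; the _is_grad_tensor stand-in below is hypothetical and only approximates ColossalAI's internal check:

import torch


def _is_grad_tensor(obj) -> bool:
    # Hypothetical stand-in: a tensor that participates in autograd.
    return isinstance(obj, torch.Tensor) and obj.requires_grad


def _has_grad_tensor(obj) -> bool:
    # Recurse into tuples/lists and dict values; everything else is a leaf.
    if isinstance(obj, (tuple, list)):
        return any(_has_grad_tensor(x) for x in obj)
    elif isinstance(obj, dict):
        return any(_has_grad_tensor(x) for x in obj.values())
    else:
        return _is_grad_tensor(obj)


x = torch.randn(2, requires_grad=True)
print(_has_grad_tensor(({"w": x}, 3, "tag")))      # True: x is nested two levels deep
print(_has_grad_tensor([1, 2.0, torch.ones(3)]))   # False: no tensor requires grad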
examples/images/diffusion/ldm/modules/diffusionmodules/util.py

@@ -7,27 +7,22 @@
 #
 # thanks!
 
-import os
 import math
+import os
+
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 from einops import repeat
 
 from ldm.util import instantiate_from_config
 
 
 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
     if schedule == "linear":
-        betas = (
-                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
-        )
+        betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2)
     elif schedule == "cosine":
-        timesteps = (
-                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
-        )
+        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s)
         alphas = timesteps / (1 + cosine_s) * np.pi / 2
         alphas = torch.cos(alphas).pow(2)
         alphas = alphas / alphas[0]
...
@@ -37,7 +32,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
     elif schedule == "sqrt_linear":
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
     elif schedule == "sqrt":
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
     else:
         raise ValueError(f"schedule '{schedule}' unknown.")
     return betas.numpy()
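Not part of the diff, but for orientation: make_beta_schedule returns the per-step noise variances used by DDPM-style samplers. A small sketch of the "linear" branch shown above (the helper name here is illustrative):

import torch


def linear_beta_schedule(n_timestep, linear_start=1e-4, linear_end=2e-2):
    # Square a linspace between sqrt(linear_start) and sqrt(linear_end),
    # mirroring the "linear" branch of make_beta_schedule.
    betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2
    return betas.numpy()


betas = linear_beta_schedule(1000)
print(betas.shape, float(betas[0]), float(betas[-1]))   # (1000,) 0.0001 0.02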
...
@@ -48,7 +43,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
         c = num_ddpm_timesteps // num_ddim_timesteps
         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
     elif ddim_discr_method == 'quad':
         ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
     else:
         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
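For context (not part of the diff): the two branches above pick which of the num_ddpm_timesteps DDPM steps a DDIM sampler visits, either uniformly or quadratically spaced. A quick sketch:

import numpy as np

num_ddpm_timesteps, num_ddim_timesteps = 1000, 10

# 'uniform': every c-th DDPM step.
c = num_ddpm_timesteps // num_ddim_timesteps
uniform = np.asarray(list(range(0, num_ddpm_timesteps, c)))

# 'quad': quadratically spaced steps, denser near step 0.
quad = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)

print(uniform)   # [  0 100 200 300 400 500 600 700 800 900]
print(quad)      # quadratically spaced: dense near 0, sparse near 800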
...
@@ -110,21 +105,26 @@ def checkpoint(func, inputs, params, flag):
     :param flag: if False, disable gradient checkpointing.
     """
     if flag:
-        args = tuple(inputs) + tuple(params)
-        return CheckpointFunction.apply(func, len(inputs), *args)
+        from torch.utils.checkpoint import checkpoint as torch_checkpoint
+        return torch_checkpoint(func, *inputs)
+        # args = tuple(inputs) + tuple(params)
+        # return CheckpointFunction.apply(func, len(inputs), *args)
     else:
         return func(*inputs)
 
 
 class CheckpointFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, run_function, length, *args):
         ctx.run_function = run_function
         ctx.input_tensors = list(args[:length])
         ctx.input_params = list(args[length:])
-        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
-                                   "dtype": torch.get_autocast_gpu_dtype(),
-                                   "cache_enabled": torch.is_autocast_cache_enabled()}
+        ctx.gpu_autocast_kwargs = {
+            "enabled": torch.is_autocast_enabled(),
+            "dtype": torch.get_autocast_gpu_dtype(),
+            "cache_enabled": torch.is_autocast_cache_enabled()
+        }
         with torch.no_grad():
             output_tensors = ctx.run_function(*ctx.input_tensors)
         return output_tensors
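This is the core of the hotfix: when flag is True, checkpoint() now delegates to PyTorch's built-in torch.utils.checkpoint.checkpoint instead of the custom CheckpointFunction.apply path (the old call is kept above only as comments). A hedged usage sketch of the new code path, with a toy module standing in for a real diffusion block:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint

# Toy stand-in for a diffusion block; any callable over tensors works.
block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))

x = torch.randn(4, 16, requires_grad=True)

# flag=True path: activations inside `block` are recomputed during backward
# instead of being stored, trading extra compute for lower memory.
y = torch_checkpoint(block, x)
y.sum().backward()
print(x.grad.shape)   # torch.Size([4, 16])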
...
@@ -162,9 +162,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     if not repeat_only:
         half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        ).to(device=timesteps.device)
+        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) /
+                          half).to(device=timesteps.device)
         args = timesteps[:, None].float() * freqs[None]
         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         if dim % 2:
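For orientation (the logic is unchanged here, only re-wrapped): timestep_embedding builds transformer-style sinusoidal embeddings of the diffusion timestep. A minimal sketch of the computation for dim=8:

import math

import torch

timesteps = torch.tensor([0, 10, 100])
dim, max_period = 8, 10000
half = dim // 2

# Geometric frequency ladder, then interleave cos/sin of timestep * frequency.
freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
print(embedding.shape)   # torch.Size([3, 8])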
...
@@ -211,14 +210,17 @@ def normalization(channels):
 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
 class SiLU(nn.Module):
 
     def forward(self, x):
         return x * torch.sigmoid(x)
 
 
 class GroupNorm32(nn.GroupNorm):
 
     def forward(self, x):
         return super().forward(x.float()).type(x.dtype)
 
 
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
...
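A small note on GroupNorm32, shown unchanged above: it normalizes in float32 and casts back to the input dtype, which keeps group-norm statistics stable under fp16/autocast. A quick sketch of that behaviour:

import torch
import torch.nn as nn


class GroupNorm32(nn.GroupNorm):

    def forward(self, x):
        # Compute the normalization in fp32, return in the caller's dtype.
        return super().forward(x.float()).type(x.dtype)


gn = GroupNorm32(num_groups=8, num_channels=32)
x = torch.randn(2, 32, 4, 4, dtype=torch.float16)
print(gn(x).dtype)   # torch.float16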