Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
08c85229
Commit
08c85229
authored
Jun 20, 2022
by
Patrick von Platen
Browse files
add license disclaimers to schedulers
parent
2b8bc91c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
123 additions
and
84 deletions
+123
-84
src/diffusers/models/unet_rl.py
src/diffusers/models/unet_rl.py
+69
-48
src/diffusers/pipelines/grad_tts_utils.py
src/diffusers/pipelines/grad_tts_utils.py
+1
-0
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+6
-1
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+5
-1
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+7
-2
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+35
-32
No files found.
src/diffusers/models/unet_rl.py
View file @
08c85229
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import
math
import
torch
import
torch.nn
as
nn
import
einops
from
einops.layers.torch
import
Rearrange
import
math
class
SinusoidalPosEmb
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
...
...
@@ -20,6 +23,7 @@ class SinusoidalPosEmb(nn.Module):
emb
=
torch
.
cat
((
emb
.
sin
(),
emb
.
cos
()),
dim
=-
1
)
return
emb
class
Downsample1d
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
...
...
@@ -28,6 +32,7 @@ class Downsample1d(nn.Module):
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
Upsample1d
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
...
...
@@ -36,57 +41,61 @@ class Upsample1d(nn.Module):
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
Conv1dBlock
(
nn
.
Module
):
'''
Conv1d --> GroupNorm --> Mish
'''
"""
Conv1d --> GroupNorm --> Mish
"""
def
__init__
(
self
,
inp_channels
,
out_channels
,
kernel_size
,
n_groups
=
8
):
super
().
__init__
()
self
.
block
=
nn
.
Sequential
(
nn
.
Conv1d
(
inp_channels
,
out_channels
,
kernel_size
,
padding
=
kernel_size
//
2
),
Rearrange
(
'
batch channels horizon -> batch channels 1 horizon
'
),
Rearrange
(
"
batch channels horizon -> batch channels 1 horizon
"
),
nn
.
GroupNorm
(
n_groups
,
out_channels
),
Rearrange
(
'
batch channels 1 horizon -> batch channels horizon
'
),
Rearrange
(
"
batch channels 1 horizon -> batch channels horizon
"
),
nn
.
Mish
(),
)
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
class
ResidualTemporalBlock
(
nn
.
Module
):
class
ResidualTemporalBlock
(
nn
.
Module
):
def
__init__
(
self
,
inp_channels
,
out_channels
,
embed_dim
,
horizon
,
kernel_size
=
5
):
super
().
__init__
()
self
.
blocks
=
nn
.
ModuleList
([
Conv1dBlock
(
inp_channels
,
out_channels
,
kernel_size
),
Conv1dBlock
(
out_channels
,
out_channels
,
kernel_size
),
])
self
.
blocks
=
nn
.
ModuleList
(
[
Conv1dBlock
(
inp_channels
,
out_channels
,
kernel_size
),
Conv1dBlock
(
out_channels
,
out_channels
,
kernel_size
),
]
)
self
.
time_mlp
=
nn
.
Sequential
(
nn
.
Mish
(),
nn
.
Linear
(
embed_dim
,
out_channels
),
Rearrange
(
'
batch t -> batch t 1
'
),
Rearrange
(
"
batch t -> batch t 1
"
),
)
self
.
residual_conv
=
nn
.
Conv1d
(
inp_channels
,
out_channels
,
1
)
\
if
inp_channels
!=
out_channels
else
nn
.
Identity
()
self
.
residual_conv
=
(
nn
.
Conv1d
(
inp_channels
,
out_channels
,
1
)
if
inp_channels
!=
out_channels
else
nn
.
Identity
()
)
def
forward
(
self
,
x
,
t
):
'''
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
'''
"""
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
"""
out
=
self
.
blocks
[
0
](
x
)
+
self
.
time_mlp
(
t
)
out
=
self
.
blocks
[
1
](
out
)
return
out
+
self
.
residual_conv
(
x
)
class
TemporalUnet
(
nn
.
Module
):
class
TemporalUnet
(
nn
.
Module
):
def
__init__
(
self
,
horizon
,
...
...
@@ -99,7 +108,7 @@ class TemporalUnet(nn.Module):
dims
=
[
transition_dim
,
*
map
(
lambda
m
:
dim
*
m
,
dim_mults
)]
in_out
=
list
(
zip
(
dims
[:
-
1
],
dims
[
1
:]))
print
(
f
'
[ models/temporal ] Channel dimensions:
{
in_out
}
'
)
print
(
f
"
[ models/temporal ] Channel dimensions:
{
in_out
}
"
)
time_dim
=
dim
self
.
time_mlp
=
nn
.
Sequential
(
...
...
@@ -117,11 +126,15 @@ class TemporalUnet(nn.Module):
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
in_out
):
is_last
=
ind
>=
(
num_resolutions
-
1
)
self
.
downs
.
append
(
nn
.
ModuleList
([
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
()
]))
self
.
downs
.
append
(
nn
.
ModuleList
(
[
ResidualTemporalBlock
(
dim_in
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Downsample1d
(
dim_out
)
if
not
is_last
else
nn
.
Identity
(),
]
)
)
if
not
is_last
:
horizon
=
horizon
//
2
...
...
@@ -133,11 +146,15 @@ class TemporalUnet(nn.Module):
for
ind
,
(
dim_in
,
dim_out
)
in
enumerate
(
reversed
(
in_out
[
1
:])):
is_last
=
ind
>=
(
num_resolutions
-
1
)
self
.
ups
.
append
(
nn
.
ModuleList
([
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
()
]))
self
.
ups
.
append
(
nn
.
ModuleList
(
[
ResidualTemporalBlock
(
dim_out
*
2
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_in
,
dim_in
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Upsample1d
(
dim_in
)
if
not
is_last
else
nn
.
Identity
(),
]
)
)
if
not
is_last
:
horizon
=
horizon
*
2
...
...
@@ -148,11 +165,11 @@ class TemporalUnet(nn.Module):
)
def
forward
(
self
,
x
,
cond
,
time
):
'''
x : [ batch x horizon x transition ]
'''
"""
x : [ batch x horizon x transition ]
"""
x
=
einops
.
rearrange
(
x
,
'
b h t -> b t h
'
)
x
=
einops
.
rearrange
(
x
,
"
b h t -> b t h
"
)
t
=
self
.
time_mlp
(
time
)
h
=
[]
...
...
@@ -174,11 +191,11 @@ class TemporalUnet(nn.Module):
x
=
self
.
final_conv
(
x
)
x
=
einops
.
rearrange
(
x
,
'
b t h -> b h t
'
)
x
=
einops
.
rearrange
(
x
,
"
b t h -> b h t
"
)
return
x
class
TemporalValue
(
nn
.
Module
):
class
TemporalValue
(
nn
.
Module
):
def
__init__
(
self
,
horizon
,
...
...
@@ -207,11 +224,15 @@ class TemporalValue(nn.Module):
print
(
in_out
)
for
dim_in
,
dim_out
in
in_out
:
self
.
blocks
.
append
(
nn
.
ModuleList
([
ResidualTemporalBlock
(
dim_in
,
dim_out
,
kernel_size
=
5
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
kernel_size
=
5
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Downsample1d
(
dim_out
)
]))
self
.
blocks
.
append
(
nn
.
ModuleList
(
[
ResidualTemporalBlock
(
dim_in
,
dim_out
,
kernel_size
=
5
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
ResidualTemporalBlock
(
dim_out
,
dim_out
,
kernel_size
=
5
,
embed_dim
=
time_dim
,
horizon
=
horizon
),
Downsample1d
(
dim_out
),
]
)
)
horizon
=
horizon
//
2
...
...
@@ -224,11 +245,11 @@ class TemporalValue(nn.Module):
)
def
forward
(
self
,
x
,
cond
,
time
,
*
args
):
'''
x : [ batch x horizon x transition ]
'''
"""
x : [ batch x horizon x transition ]
"""
x
=
einops
.
rearrange
(
x
,
'
b h t -> b t h
'
)
x
=
einops
.
rearrange
(
x
,
"
b h t -> b t h
"
)
t
=
self
.
time_mlp
(
time
)
...
...
@@ -239,4 +260,4 @@ class TemporalValue(nn.Module):
x
=
x
.
view
(
len
(
x
),
-
1
)
out
=
self
.
final_block
(
torch
.
cat
([
x
,
t
],
dim
=-
1
))
return
out
\ No newline at end of file
return
out
src/diffusers/pipelines/grad_tts_utils.py
View file @
08c85229
...
...
@@ -233,6 +233,7 @@ def english_cleaners(text):
text
=
collapse_whitespace
(
text
)
return
text
try
:
_inflect
=
inflect
.
engine
()
except
:
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
08c85229
# Copyright 2022 The HuggingFace Team. All rights reserved.
# Copyright 2022
Stanford University Team and
The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import
math
import
numpy
as
np
...
...
@@ -31,6 +35,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def
alpha_bar
(
time_step
):
return
math
.
cos
((
time_step
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
08c85229
# Copyright 2022 The HuggingFace Team. All rights reserved.
# Copyright 2022
UC Berkely Team and
The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
import
numpy
as
np
...
...
@@ -31,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def
alpha_bar
(
time_step
):
return
math
.
cos
((
time_step
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
08c85229
# Copyright 2022 The HuggingFace Team. All rights reserved.
# Copyright 2022
Zhejiang University Team and
The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,9 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
import
numpy
as
np
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -30,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def
alpha_bar
(
time_step
):
return
math
.
cos
((
time_step
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
...
...
tests/test_modeling_utils.py
View file @
08c85229
...
...
@@ -17,11 +17,11 @@
import
inspect
import
tempfile
import
unittest
import
numpy
as
np
import
pytest
import
numpy
as
np
import
torch
import
pytest
from
diffusers
import
(
BDDM
,
DDIM
,
...
...
@@ -30,10 +30,10 @@ from diffusers import (
PNDM
,
DDIMScheduler
,
DDPMScheduler
,
GLIDESuperResUNetModel
,
LatentDiffusion
,
PNDMScheduler
,
UNetModel
,
GLIDESuperResUNetModel
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
...
@@ -105,7 +105,7 @@ class ModelTesterMixin:
max_diff
=
(
image
-
new_image
).
abs
().
sum
().
item
()
self
.
assertLessEqual
(
max_diff
,
1e-5
,
"Models give different forward passes"
)
def
test_determinism
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
...
...
@@ -121,7 +121,7 @@ class ModelTesterMixin:
out_2
=
out_2
[
~
np
.
isnan
(
out_2
)]
max_diff
=
np
.
amax
(
np
.
abs
(
out_1
-
out_2
))
self
.
assertLessEqual
(
max_diff
,
1e-5
)
def
test_output
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
...
...
@@ -130,11 +130,11 @@ class ModelTesterMixin:
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_forward_signature
(
self
):
init_dict
,
_
=
self
.
prepare_init_args_and_inputs_for_common
()
...
...
@@ -145,14 +145,14 @@ class ModelTesterMixin:
expected_arg_names
=
[
"x"
,
"timesteps"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_config
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
# test if the model can be loaded from the config
# and has all the expected shape
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
...
...
@@ -160,17 +160,17 @@ class ModelTesterMixin:
new_model
=
self
.
model_class
.
from_config
(
tmpdirname
)
new_model
.
to
(
torch_device
)
new_model
.
eval
()
# check if all paramters shape are the same
for
param_name
in
model
.
state_dict
().
keys
():
param_1
=
model
.
state_dict
()[
param_name
]
param_2
=
new_model
.
state_dict
()[
param_name
]
self
.
assertEqual
(
param_1
.
shape
,
param_2
.
shape
)
with
torch
.
no_grad
():
output_1
=
model
(
**
inputs_dict
)
output_2
=
new_model
(
**
inputs_dict
)
self
.
assertEqual
(
output_1
.
shape
,
output_2
.
shape
)
def
test_training
(
self
):
...
...
@@ -180,7 +180,7 @@ class ModelTesterMixin:
model
.
to
(
torch_device
)
model
.
train
()
output
=
model
(
**
inputs_dict
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],
)
+
self
.
get_output_shape
).
to
(
torch_device
)
noise
=
torch
.
randn
((
inputs_dict
[
"x"
].
shape
[
0
],)
+
self
.
get_output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
...
...
@@ -198,11 +198,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
3
,
32
,
32
)
...
...
@@ -217,7 +217,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
...
...
@@ -227,7 +227,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
.
eval
()
...
...
@@ -235,13 +235,13 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
...
...
@@ -249,6 +249,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
print
(
output_slice
)
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
GLIDESuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
GLIDESuperResUNetModel
...
...
@@ -266,19 +267,19 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"low_res"
:
low_res
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"attention_resolutions"
:
(
2
,),
"channel_mult"
:
(
1
,
2
),
"channel_mult"
:
(
1
,
2
),
"in_channels"
:
6
,
"out_channels"
:
6
,
"model_channels"
:
32
,
...
...
@@ -287,7 +288,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
"num_res_blocks"
:
2
,
"resblock_updown"
:
True
,
"resolution"
:
32
,
"use_scale_shift_norm"
:
True
"use_scale_shift_norm"
:
True
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
...
...
@@ -302,13 +303,15 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
output
=
model
(
**
inputs_dict
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
GLIDESuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
)
model
,
loading_info
=
GLIDESuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
...
...
@@ -316,7 +319,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
# TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero
@
unittest
.
skip
(
"GLIDESuperResUNetModel always outputs zero"
)
def
test_output_pretrained
(
self
):
...
...
@@ -326,14 +329,14 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
3
,
32
,
32
)
low_res
=
torch
.
randn
(
1
,
3
,
4
,
4
)
time_step
=
torch
.
tensor
([
42
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
low_res
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment