Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
77aadfee
Commit
77aadfee
authored
Jun 20, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
parents
452339e2
80898b52
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
147 additions
and
300 deletions
+147
-300
1
1
+0
-289
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+2
-2
src/diffusers/pipelines/pipeline_grad_tts.py
src/diffusers/pipelines/pipeline_grad_tts.py
+1
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+144
-8
No files found.
1
deleted
100644 → 0
View file @
452339e2
#
coding
=
utf
-
8
#
Copyright
2022
The
HuggingFace
Inc
.
team
.
#
Copyright
(
c
)
2022
,
NVIDIA
CORPORATION
.
All
rights
reserved
.
#
#
Licensed
under
the
Apache
License
,
Version
2.0
(
the
"License"
);
#
you
may
not
use
this
file
except
in
compliance
with
the
License
.
#
You
may
obtain
a
copy
of
the
License
at
#
#
http
://
www
.
apache
.
org
/
licenses
/
LICENSE
-
2.0
#
#
Unless
required
by
applicable
law
or
agreed
to
in
writing
,
software
#
distributed
under
the
License
is
distributed
on
an
"AS IS"
BASIS
,
#
WITHOUT
WARRANTIES
OR
CONDITIONS
OF
ANY
KIND
,
either
express
or
implied
.
#
See
the
License
for
the
specific
language
governing
permissions
and
#
limitations
under
the
License
.
""" ConfigMixinuration base class and utilities."""
import
inspect
import
json
import
os
import
re
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
huggingface_hub
import
hf_hub_download
from
requests
import
HTTPError
from
.
import
__version__
from
.
utils
import
(
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
logging
,
)
logger
=
logging
.
get_logger
(
__name__
)
_re_configuration_file
=
re
.
compile
(
r
"config\.(.*)\.json"
)
class
ConfigMixin
:
r
"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
methods for loading/downloading/saving configurations.
"""
config_name
=
None
def
register_to_config
(
self
,
**
kwargs
):
if
self
.
config_name
is
None
:
raise
NotImplementedError
(
f
"Make sure that {self.__class__} has defined a class name `config_name`"
)
kwargs
[
"_class_name"
]
=
self
.
__class__
.
__name__
kwargs
[
"_diffusers_version"
]
=
__version__
for
key
,
value
in
kwargs
.
items
():
try
:
setattr
(
self
,
key
,
value
)
except
AttributeError
as
err
:
logger
.
error
(
f
"Can't set {key} with value {value} for {self}"
)
raise
err
if
not
hasattr
(
self
,
"_internal_dict"
):
internal_dict
=
kwargs
else
:
previous_dict
=
dict
(
self
.
_internal_dict
)
internal_dict
=
{**
self
.
_internal_dict
,
**
kwargs
}
logger
.
debug
(
f
"Updating config from {previous_dict} to {internal_dict}"
)
self
.
_internal_dict
=
FrozenDict
(
internal_dict
)
def
save_config
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
push_to_hub
:
bool
=
False
,
**
kwargs
):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~ConfigMixin.from_config`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if
os
.
path
.
isfile
(
save_directory
):
raise
AssertionError
(
f
"Provided path ({save_directory}) should be a directory, not a file"
)
os
.
makedirs
(
save_directory
,
exist_ok
=
True
)
#
If
we
save
using
the
predefined
names
,
we
can
load
using
`
from_config
`
output_config_file
=
os
.
path
.
join
(
save_directory
,
self
.
config_name
)
self
.
to_json_file
(
output_config_file
)
logger
.
info
(
f
"ConfigMixinuration saved in {output_config_file}"
)
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
model
=
cls
(**
init_dict
)
if
return_unused_kwargs
:
return
model
,
unused_kwargs
else
:
return
model
@
classmethod
def
get_config_dict
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
)
->
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]:
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
user_agent
=
{
"file_type"
:
"config"
}
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
if
cls
.
config_name
is
None
:
raise
ValueError
(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)):
#
Load
from
a
PyTorch
checkpoint
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)
else
:
raise
EnvironmentError
(
f
"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
)
else
:
try
:
#
Load
from
URL
or
cache
if
already
cached
config_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
filename
=
cls
.
config_name
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
" this model name. Check the model page at"
f
" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
"There was a specific connection error when trying to load"
f
" {pretrained_model_name_or_path}:
\n
{err}"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f
" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f
" directory containing a {cls.config_name} file.
\n
Checkout your internet connection or see how to"
" run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f
"containing a {cls.config_name} file"
)
try
:
#
Load
config
dict
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
f
"It looks like the config file at '{config_file}' is not a valid JSON file."
)
return
config_dict
@
classmethod
def
extract_init_dict
(
cls
,
config_dict
,
**
kwargs
):
expected_keys
=
set
(
dict
(
inspect
.
signature
(
cls
.
__init__
).
parameters
).
keys
())
expected_keys
.
remove
(
"self"
)
init_dict
=
{}
for
key
in
expected_keys
:
if
key
in
kwargs
:
#
overwrite
key
init_dict
[
key
]
=
kwargs
.
pop
(
key
)
elif
key
in
config_dict
:
#
use
value
from
config
dict
init_dict
[
key
]
=
config_dict
.
pop
(
key
)
unused_kwargs
=
config_dict
.
update
(
kwargs
)
passed_keys
=
set
(
init_dict
.
keys
())
if
len
(
expected_keys
-
passed_keys
)
>
0
:
logger
.
warning
(
f
"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
)
return
init_dict
,
unused_kwargs
@
classmethod
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
with
open
(
json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
text
=
reader
.
read
()
return
json
.
loads
(
text
)
def
__repr__
(
self
):
return
f
"{self.__class__.__name__} {self.to_json_string()}"
@
property
def
config
(
self
)
->
Dict
[
str
,
Any
]:
return
self
.
_internal_dict
def
to_json_string
(
self
)
->
str
:
"""
Serializes this instance to a JSON string.
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
import
ipdb
;
ipdb
.
set_trace
()
config_dict
=
self
.
_internal_dict
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with
open
(
json_file_path
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
write
(
self
.
to_json_string
())
class
FrozenDict
(
OrderedDict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(*
args
,
**
kwargs
)
for
key
,
value
in
self
.
items
():
setattr
(
self
,
key
,
value
)
self
.
__frozen
=
True
def
__delitem__
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
)
def
setdefault
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
)
def
pop
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``pop`` on a {self.__class__.__name__} instance."
)
def
update
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``update`` on a {self.__class__.__name__} instance."
)
def
__setattr__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
)
super
().
__setattr__
(
name
,
value
)
def
__setitem__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
)
super
().
__setitem__
(
name
,
value
)
src/diffusers/models/unet_grad_tts.py
View file @
77aadfee
...
...
@@ -190,7 +190,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self
.
final_block
=
Block
(
dim
,
dim
)
self
.
final_conv
=
torch
.
nn
.
Conv2d
(
dim
,
1
,
1
)
def
forward
(
self
,
x
,
mask
,
mu
,
t
,
spk
=
None
):
def
forward
(
self
,
x
,
timesteps
,
mu
,
mask
,
spk
=
None
):
if
self
.
n_spks
>
1
:
# Get speaker embedding
spk
=
self
.
spk_emb
(
spk
)
...
...
@@ -198,7 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
if
not
isinstance
(
spk
,
type
(
None
)):
s
=
self
.
spk_mlp
(
spk
)
t
=
self
.
time_pos_emb
(
t
,
scale
=
self
.
pe_scale
)
t
=
self
.
time_pos_emb
(
t
imesteps
,
scale
=
self
.
pe_scale
)
t
=
self
.
mlp
(
t
)
if
self
.
n_spks
<
2
:
...
...
src/diffusers/pipelines/pipeline_grad_tts.py
View file @
77aadfee
...
...
@@ -472,7 +472,7 @@ class GradTTS(DiffusionPipeline):
t
=
(
1.0
-
(
t
+
0.5
)
*
h
)
*
torch
.
ones
(
z
.
shape
[
0
],
dtype
=
z
.
dtype
,
device
=
z
.
device
)
time
=
t
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
residual
=
self
.
unet
(
xt
,
y_mask
,
mu_y
,
t
,
speaker_id
)
residual
=
self
.
unet
(
xt
,
t
,
mu_y
,
y_mask
,
speaker_id
)
xt
=
self
.
noise_scheduler
.
step
(
xt
,
residual
,
mu_y
,
h
,
time
)
xt
=
xt
*
y_mask
...
...
tests/test_modeling_utils.py
View file @
77aadfee
...
...
@@ -34,6 +34,8 @@ from diffusers import (
LatentDiffusion
,
PNDMScheduler
,
UNetModel
,
UNetLDMModel
,
UNetGradTTSModel
,
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
...
@@ -246,7 +248,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
print
(
output_slice
)
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
...
...
@@ -320,17 +321,14 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
assert
image
is
not
None
,
"Make sure output is not None"
# TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero
@
unittest
.
skip
(
"GLIDESuperResUNetModel always outputs zero"
)
def
test_output_pretrained
(
self
):
model
=
GLIDESuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
3
,
32
,
32
)
noise
=
torch
.
randn
(
1
,
3
,
64
,
64
)
low_res
=
torch
.
randn
(
1
,
3
,
4
,
4
)
time_step
=
torch
.
tensor
([
42
]
*
noise
.
shape
[
0
])
...
...
@@ -340,9 +338,148 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2
59
5
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
expected_output_slice
=
torch
.
tensor
([
-
22.8782
,
-
23.2652
,
-
15.3966
,
-
22.8034
,
-
23.31
59
,
-
15.5640
,
-
15.3970
,
-
15.4614
,
-
10.4370
])
# fmt: on
print
(
output_slice
)
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetLDMModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetLDMModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
4
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_input_shape
(
self
):
return
(
4
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
4
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"image_size"
:
32
,
"in_channels"
:
4
,
"out_channels"
:
4
,
"model_channels"
:
32
,
"num_res_blocks"
:
2
,
"attention_resolutions"
:
(
16
,),
"channel_mult"
:
(
1
,
2
),
"num_heads"
:
2
,
"conv_resample"
:
True
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetLDMModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetLDMModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
image_size
,
model
.
config
.
image_size
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
13.3258
,
-
20.1100
,
-
15.9873
,
-
17.6617
,
-
23.0596
,
-
17.9419
,
-
13.3675
,
-
16.1889
,
-
12.3800
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetGradTTSModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetGradTTSModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_features
=
32
seq_len
=
16
noise
=
floats_tensor
((
batch_size
,
num_features
,
seq_len
)).
to
(
torch_device
)
condition
=
floats_tensor
((
batch_size
,
num_features
,
seq_len
)).
to
(
torch_device
)
mask
=
floats_tensor
((
batch_size
,
1
,
seq_len
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
batch_size
).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"mu"
:
condition
,
"mask"
:
mask
}
@
property
def
get_input_shape
(
self
):
return
(
4
,
32
,
16
)
@
property
def
get_output_shape
(
self
):
return
(
4
,
32
,
16
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"dim"
:
64
,
"groups"
:
4
,
"dim_mults"
:
(
1
,
2
),
"n_feats"
:
32
,
"pe_scale"
:
1000
,
"n_spks"
:
1
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetGradTTSModel
.
from_pretrained
(
"fusing/unet-grad-tts-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetGradTTSModel
.
from_pretrained
(
"fusing/unet-grad-tts-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
num_features
=
model
.
config
.
n_feats
seq_len
=
16
noise
=
torch
.
randn
((
1
,
num_features
,
seq_len
))
condition
=
torch
.
randn
((
1
,
num_features
,
seq_len
))
mask
=
torch
.
randn
((
1
,
1
,
seq_len
))
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
condition
,
mask
)
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.0690
,
-
0.0531
,
0.0633
,
-
0.0660
,
-
0.0541
,
0.0650
,
-
0.0656
,
-
0.0555
,
0.0617
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
...
...
@@ -450,7 +587,6 @@ class PipelineTesterMixin(unittest.TestCase):
image
=
ldm
([
prompt
],
generator
=
generator
,
num_inference_steps
=
20
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
print
(
image_slice
.
shape
)
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment