Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
77aadfee
Commit
77aadfee
authored
Jun 20, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
parents
452339e2
80898b52
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
147 additions
and
300 deletions
+147
-300
1
1
+0
-289
src/diffusers/models/unet_grad_tts.py
src/diffusers/models/unet_grad_tts.py
+2
-2
src/diffusers/pipelines/pipeline_grad_tts.py
src/diffusers/pipelines/pipeline_grad_tts.py
+1
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+144
-8
No files found.
1
deleted
100644 → 0
View file @
452339e2
#
coding
=
utf
-
8
#
Copyright
2022
The
HuggingFace
Inc
.
team
.
#
Copyright
(
c
)
2022
,
NVIDIA
CORPORATION
.
All
rights
reserved
.
#
#
Licensed
under
the
Apache
License
,
Version
2.0
(
the
"License"
);
#
you
may
not
use
this
file
except
in
compliance
with
the
License
.
#
You
may
obtain
a
copy
of
the
License
at
#
#
http
://
www
.
apache
.
org
/
licenses
/
LICENSE
-
2.0
#
#
Unless
required
by
applicable
law
or
agreed
to
in
writing
,
software
#
distributed
under
the
License
is
distributed
on
an
"AS IS"
BASIS
,
#
WITHOUT
WARRANTIES
OR
CONDITIONS
OF
ANY
KIND
,
either
express
or
implied
.
#
See
the
License
for
the
specific
language
governing
permissions
and
#
limitations
under
the
License
.
""" ConfigMixinuration base class and utilities."""
import
inspect
import
json
import
os
import
re
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
huggingface_hub
import
hf_hub_download
from
requests
import
HTTPError
from
.
import
__version__
from
.
utils
import
(
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
logging
,
)
logger
=
logging
.
get_logger
(
__name__
)
_re_configuration_file
=
re
.
compile
(
r
"config\.(.*)\.json"
)
class
ConfigMixin
:
r
"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
methods for loading/downloading/saving configurations.
"""
config_name
=
None
def
register_to_config
(
self
,
**
kwargs
):
if
self
.
config_name
is
None
:
raise
NotImplementedError
(
f
"Make sure that {self.__class__} has defined a class name `config_name`"
)
kwargs
[
"_class_name"
]
=
self
.
__class__
.
__name__
kwargs
[
"_diffusers_version"
]
=
__version__
for
key
,
value
in
kwargs
.
items
():
try
:
setattr
(
self
,
key
,
value
)
except
AttributeError
as
err
:
logger
.
error
(
f
"Can't set {key} with value {value} for {self}"
)
raise
err
if
not
hasattr
(
self
,
"_internal_dict"
):
internal_dict
=
kwargs
else
:
previous_dict
=
dict
(
self
.
_internal_dict
)
internal_dict
=
{**
self
.
_internal_dict
,
**
kwargs
}
logger
.
debug
(
f
"Updating config from {previous_dict} to {internal_dict}"
)
self
.
_internal_dict
=
FrozenDict
(
internal_dict
)
def
save_config
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
push_to_hub
:
bool
=
False
,
**
kwargs
):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~ConfigMixin.from_config`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if
os
.
path
.
isfile
(
save_directory
):
raise
AssertionError
(
f
"Provided path ({save_directory}) should be a directory, not a file"
)
os
.
makedirs
(
save_directory
,
exist_ok
=
True
)
#
If
we
save
using
the
predefined
names
,
we
can
load
using
`
from_config
`
output_config_file
=
os
.
path
.
join
(
save_directory
,
self
.
config_name
)
self
.
to_json_file
(
output_config_file
)
logger
.
info
(
f
"ConfigMixinuration saved in {output_config_file}"
)
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
model
=
cls
(**
init_dict
)
if
return_unused_kwargs
:
return
model
,
unused_kwargs
else
:
return
model
@
classmethod
def
get_config_dict
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
)
->
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]:
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
user_agent
=
{
"file_type"
:
"config"
}
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
if
cls
.
config_name
is
None
:
raise
ValueError
(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)):
#
Load
from
a
PyTorch
checkpoint
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)
else
:
raise
EnvironmentError
(
f
"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
)
else
:
try
:
#
Load
from
URL
or
cache
if
already
cached
config_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
filename
=
cls
.
config_name
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
" this model name. Check the model page at"
f
" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
"There was a specific connection error when trying to load"
f
" {pretrained_model_name_or_path}:
\n
{err}"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f
" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f
" directory containing a {cls.config_name} file.
\n
Checkout your internet connection or see how to"
" run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f
"containing a {cls.config_name} file"
)
try
:
#
Load
config
dict
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
f
"It looks like the config file at '{config_file}' is not a valid JSON file."
)
return
config_dict
@
classmethod
def
extract_init_dict
(
cls
,
config_dict
,
**
kwargs
):
expected_keys
=
set
(
dict
(
inspect
.
signature
(
cls
.
__init__
).
parameters
).
keys
())
expected_keys
.
remove
(
"self"
)
init_dict
=
{}
for
key
in
expected_keys
:
if
key
in
kwargs
:
#
overwrite
key
init_dict
[
key
]
=
kwargs
.
pop
(
key
)
elif
key
in
config_dict
:
#
use
value
from
config
dict
init_dict
[
key
]
=
config_dict
.
pop
(
key
)
unused_kwargs
=
config_dict
.
update
(
kwargs
)
passed_keys
=
set
(
init_dict
.
keys
())
if
len
(
expected_keys
-
passed_keys
)
>
0
:
logger
.
warning
(
f
"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
)
return
init_dict
,
unused_kwargs
@
classmethod
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
with
open
(
json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
text
=
reader
.
read
()
return
json
.
loads
(
text
)
def
__repr__
(
self
):
return
f
"{self.__class__.__name__} {self.to_json_string()}"
@
property
def
config
(
self
)
->
Dict
[
str
,
Any
]:
return
self
.
_internal_dict
def
to_json_string
(
self
)
->
str
:
"""
Serializes this instance to a JSON string.
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
import
ipdb
;
ipdb
.
set_trace
()
config_dict
=
self
.
_internal_dict
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with
open
(
json_file_path
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
write
(
self
.
to_json_string
())
class
FrozenDict
(
OrderedDict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(*
args
,
**
kwargs
)
for
key
,
value
in
self
.
items
():
setattr
(
self
,
key
,
value
)
self
.
__frozen
=
True
def
__delitem__
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
)
def
setdefault
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
)
def
pop
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``pop`` on a {self.__class__.__name__} instance."
)
def
update
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``update`` on a {self.__class__.__name__} instance."
)
def
__setattr__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
)
super
().
__setattr__
(
name
,
value
)
def
__setitem__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
)
super
().
__setitem__
(
name
,
value
)
src/diffusers/models/unet_grad_tts.py
View file @
77aadfee
...
...
@@ -190,7 +190,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self
.
final_block
=
Block
(
dim
,
dim
)
self
.
final_conv
=
torch
.
nn
.
Conv2d
(
dim
,
1
,
1
)
def
forward
(
self
,
x
,
mask
,
mu
,
t
,
spk
=
None
):
def
forward
(
self
,
x
,
timesteps
,
mu
,
mask
,
spk
=
None
):
if
self
.
n_spks
>
1
:
# Get speaker embedding
spk
=
self
.
spk_emb
(
spk
)
...
...
@@ -198,7 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
if
not
isinstance
(
spk
,
type
(
None
)):
s
=
self
.
spk_mlp
(
spk
)
t
=
self
.
time_pos_emb
(
t
,
scale
=
self
.
pe_scale
)
t
=
self
.
time_pos_emb
(
t
imesteps
,
scale
=
self
.
pe_scale
)
t
=
self
.
mlp
(
t
)
if
self
.
n_spks
<
2
:
...
...
src/diffusers/pipelines/pipeline_grad_tts.py
View file @
77aadfee
...
...
@@ -472,7 +472,7 @@ class GradTTS(DiffusionPipeline):
t
=
(
1.0
-
(
t
+
0.5
)
*
h
)
*
torch
.
ones
(
z
.
shape
[
0
],
dtype
=
z
.
dtype
,
device
=
z
.
device
)
time
=
t
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
residual
=
self
.
unet
(
xt
,
y_mask
,
mu_y
,
t
,
speaker_id
)
residual
=
self
.
unet
(
xt
,
t
,
mu_y
,
y_mask
,
speaker_id
)
xt
=
self
.
noise_scheduler
.
step
(
xt
,
residual
,
mu_y
,
h
,
time
)
xt
=
xt
*
y_mask
...
...
tests/test_modeling_utils.py
View file @
77aadfee
...
...
@@ -34,6 +34,8 @@ from diffusers import (
LatentDiffusion
,
PNDMScheduler
,
UNetModel
,
UNetLDMModel
,
UNetGradTTSModel
,
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
...
@@ -246,7 +248,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
print
(
output_slice
)
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
...
...
@@ -320,17 +321,14 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
assert
image
is
not
None
,
"Make sure output is not None"
# TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero
@
unittest
.
skip
(
"GLIDESuperResUNetModel always outputs zero"
)
def
test_output_pretrained
(
self
):
model
=
GLIDESuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
3
,
32
,
32
)
noise
=
torch
.
randn
(
1
,
3
,
64
,
64
)
low_res
=
torch
.
randn
(
1
,
3
,
4
,
4
)
time_step
=
torch
.
tensor
([
42
]
*
noise
.
shape
[
0
])
...
...
@@ -340,9 +338,148 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
expected_output_slice
=
torch
.
tensor
([
-
22.8782
,
-
23.2652
,
-
15.3966
,
-
22.8034
,
-
23.3159
,
-
15.5640
,
-
15.3970
,
-
15.4614
,
-
10.4370
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetLDMModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetLDMModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
4
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_input_shape
(
self
):
return
(
4
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
4
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"image_size"
:
32
,
"in_channels"
:
4
,
"out_channels"
:
4
,
"model_channels"
:
32
,
"num_res_blocks"
:
2
,
"attention_resolutions"
:
(
16
,),
"channel_mult"
:
(
1
,
2
),
"num_heads"
:
2
,
"conv_resample"
:
True
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetLDMModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetLDMModel
.
from_pretrained
(
"fusing/unet-ldm-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
image_size
,
model
.
config
.
image_size
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
13.3258
,
-
20.1100
,
-
15.9873
,
-
17.6617
,
-
23.0596
,
-
17.9419
,
-
13.3675
,
-
16.1889
,
-
12.3800
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
UNetGradTTSModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetGradTTSModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_features
=
32
seq_len
=
16
noise
=
floats_tensor
((
batch_size
,
num_features
,
seq_len
)).
to
(
torch_device
)
condition
=
floats_tensor
((
batch_size
,
num_features
,
seq_len
)).
to
(
torch_device
)
mask
=
floats_tensor
((
batch_size
,
1
,
seq_len
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
batch_size
).
to
(
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"mu"
:
condition
,
"mask"
:
mask
}
@
property
def
get_input_shape
(
self
):
return
(
4
,
32
,
16
)
@
property
def
get_output_shape
(
self
):
return
(
4
,
32
,
16
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"dim"
:
64
,
"groups"
:
4
,
"dim_mults"
:
(
1
,
2
),
"n_feats"
:
32
,
"pe_scale"
:
1000
,
"n_spks"
:
1
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
UNetGradTTSModel
.
from_pretrained
(
"fusing/unet-grad-tts-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
UNetGradTTSModel
.
from_pretrained
(
"fusing/unet-grad-tts-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
num_features
=
model
.
config
.
n_feats
seq_len
=
16
noise
=
torch
.
randn
((
1
,
num_features
,
seq_len
))
condition
=
torch
.
randn
((
1
,
num_features
,
seq_len
))
mask
=
torch
.
randn
((
1
,
1
,
seq_len
))
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
condition
,
mask
)
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.0690
,
-
0.0531
,
0.0633
,
-
0.0660
,
-
0.0541
,
0.0650
,
-
0.0656
,
-
0.0555
,
0.0617
])
# fmt: on
print
(
output_slice
)
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
...
...
@@ -450,7 +587,6 @@ class PipelineTesterMixin(unittest.TestCase):
image
=
ldm
([
prompt
],
generator
=
generator
,
num_inference_steps
=
20
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
print
(
image_slice
.
shape
)
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment