renzhc / diffusers_dcu · Commits

Commit c7a39d38
Authored Jun 27, 2022 by Patrick von Platen
Parent: 02a76c2c

    refactor all sinus embeddings

Showing 7 changed files with 117 additions and 204 deletions (+117 / -204):

    src/diffusers/models/embeddings.py                   +25  -94
    src/diffusers/models/unet.py                          +0  -21
    src/diffusers/models/unet_glide.py                    +9  -24
    src/diffusers/models/unet_grad_tts.py                 +3  -20
    src/diffusers/models/unet_ldm.py                      +3  -22
    src/diffusers/models/unet_sde_score_estimation.py     +0  -17
    tests/test_layers_utils.py                           +77   -6
src/diffusers/models/embeddings.py

@@ -11,15 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
 import math
-import numpy as np
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F


-def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000):
+def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
...
...
@@ -31,16 +32,20 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
assert
len
(
timesteps
.
shape
)
==
1
assert
len
(
timesteps
.
shape
)
==
1
,
"Timesteps should be a 1d-array"
half_dim
=
embedding_dim
//
2
emb
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
/
(
embedding_dim
//
2
-
downscale_freq_shift
))
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb_coeff
=
-
math
.
log
(
max_period
)
/
(
half_dim
-
downscale_freq_shift
)
emb
=
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
,
device
=
timesteps
.
device
)
emb
=
torch
.
exp
(
emb
*
emb_coeff
)
emb
=
timesteps
[:,
None
].
float
()
*
emb
[
None
,
:]
# scale embeddings
emb
=
scale
*
emb
# concat sine and cosine embeddings
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
-
1
)
# flip sine and cosine embeddings
if
flip_sin_to_cos
:
@@ -52,81 +57,20 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
     return emb


-#def get_timestep_embedding(timesteps, embedding_dim):
-# """
-# This matches the implementation in Denoising Diffusion Probabilistic Models:
-# From Fairseq.
-# Build sinusoidal embeddings.
-# This matches the implementation in tensor2tensor, but differs slightly
-# from the description in Section 3.5 of "Attention Is All You Need".
-# """
-# assert len(timesteps.shape) == 1
-#
-# half_dim = embedding_dim // 2
-# emb = math.log(10000) / (half_dim - 1)
-# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-# emb = emb.to(device=timesteps.device)
-# emb = timesteps.float()[:, None] * emb[None, :]
-# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-# if embedding_dim % 2 == 1:  # zero pad
-# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))


-#def timestep_embedding(timesteps, dim, max_period=10000):
-# """
-# Create sinusoidal timestep embeddings.
-#
-# :param timesteps: a 1-D Tensor of N indices, one per batch element.
-# These may be fractional.
-# :param dim: the dimension of the output.
-# :param max_period: controls the minimum frequency of the embeddings.
-# :return: an [N x dim] Tensor of positional embeddings.
-# """
-# half = dim // 2
-# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-#     device=timesteps.device
-# )
-# args = timesteps[:, None].float() * freqs[None, :]
-# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-# if dim % 2:
-# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-# return embedding


-#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-# assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-# half_dim = embedding_dim // 2
-# # magic number 10000 is from transformers
-# emb = math.log(max_positions) / (half_dim - 1)
-# # emb = math.log(2.) / (half_dim - 1)
-# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-# # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-# # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-# emb = timesteps.float()[:, None] * emb[None, :]
-# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-# if embedding_dim % 2 == 1:  # zero pad
-# emb = F.pad(emb, (0, 1), mode="constant")
-# assert emb.shape == (timesteps.shape[0], embedding_dim)
-# return emb


-# unet_grad_tts.py
-class SinusoidalPosEmb(torch.nn.Module):
-    def __init__(self, dim):
-        super(SinusoidalPosEmb, self).__init__()
-        self.dim = dim
-
-    def forward(self, x, scale=1000):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
-        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
+# unet_sde_score_estimation.py
+class GaussianFourierProjection(nn.Module):
+    """Gaussian Fourier embeddings for noise levels."""
+
+    def __init__(self, embedding_size=256, scale=1.0):
+        super().__init__()
+        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+    def forward(self, x):
+        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


-# unet_rl.py
+# unet_rl.py - TODO(need test)
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -140,16 +84,3 @@ class SinusoidalPosEmb(nn.Module):
         emb = x[:, None] * emb[None, :]
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
-
-
-# unet_sde_score_estimation.py
-class GaussianFourierProjection(nn.Module):
-    """Gaussian Fourier embeddings for noise levels."""
-
-    def __init__(self, embedding_size=256, scale=1.0):
-        super().__init__()
-        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
-    def forward(self, x):
-        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
-        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
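
For orientation (not part of the diff): after this refactor the single helper covers the sinusoidal variants that previously lived in the individual UNet files. A rough usage sketch, using only the keyword arguments introduced above and the same three configurations exercised by the new tests in tests/test_layers_utils.py below:

    # Rough usage sketch (not part of the commit); the mapping of models to
    # keyword arguments follows the call sites and tests changed in this diff.
    import torch
    from diffusers.models.embeddings import get_timestep_embedding

    timesteps = torch.arange(16)

    # DDPM-style UNet / score SDE: sin-first ordering, downscale_freq_shift=1 (the defaults)
    emb_ddpm = get_timestep_embedding(timesteps, 256)

    # GLIDE / LDM: cos-first ordering, no frequency shift
    emb_glide = get_timestep_embedding(timesteps, 256, flip_sin_to_cos=True, downscale_freq_shift=0)

    # Grad-TTS: timesteps multiplied by pe_scale (1000) before taking sin/cos
    emb_grad_tts = get_timestep_embedding(timesteps, 256, scale=1000)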
src/diffusers/models/unet.py

@@ -33,27 +33,6 @@ from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding


-#def get_timestep_embedding(timesteps, embedding_dim):
-# """
-# This matches the implementation in Denoising Diffusion Probabilistic Models:
-# From Fairseq.
-# Build sinusoidal embeddings.
-# This matches the implementation in tensor2tensor, but differs slightly
-# from the description in Section 3.5 of "Attention Is All You Need".
-# """
-# assert len(timesteps.shape) == 1
-#
-# half_dim = embedding_dim // 2
-# emb = math.log(10000) / (half_dim - 1)
-# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-# emb = emb.to(device=timesteps.device)
-# emb = timesteps.float()[:, None] * emb[None, :]
-# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-# if embedding_dim % 2 == 1:  # zero pad
-# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-# return emb


 def nonlinearity(x):
     # swish
     return x * torch.sigmoid(x)
src/diffusers/models/unet_glide.py

@@ -87,27 +87,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


-# def timestep_embedding(timesteps, dim, max_period=10000):
-# """
-# Create sinusoidal timestep embeddings.
-#
-# :param timesteps: a 1-D Tensor of N indices, one per batch element.
-# These may be fractional.
-# :param dim: the dimension of the output.
-# :param max_period: controls the minimum frequency of the embeddings.
-# :return: an [N x dim] Tensor of positional embeddings.
-# """
-# half = dim // 2
-# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-#     device=timesteps.device
-# )
-# args = timesteps[:, None].float() * freqs[None]
-# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-# if dim % 2:
-# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-# return embedding


 def zero_module(module):
     """
     Zero out the parameters of a module and return it.

@@ -628,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         """
         hs = []
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         h = x.type(self.dtype)
         for module in self.input_blocks:

@@ -715,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel):
     def forward(self, x, timesteps, transformer_out=None):
         hs = []
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         # project the last token
         transformer_proj = self.transformer_proj(transformer_out[:, -1])

@@ -807,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel):
         x = torch.cat([x, upsampled], dim=1)

         hs = []
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         h = x
         for module in self.input_blocks:
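
The GLIDE forward passes above rely on flip_sin_to_cos=True and downscale_freq_shift=0, which should reproduce the cos-then-sin layout and the "/ half" frequency divisor of the commented-out timestep_embedding helper deleted from this file. A hedged equivalence sketch (the old function below is reproduced from the removed comment block; it is not part of the new code):

    import math
    import torch
    from diffusers.models.embeddings import get_timestep_embedding

    def old_timestep_embedding(timesteps, dim, max_period=10000):
        # removed GLIDE/LDM helper, reproduced from the deleted comment block above
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    timesteps = torch.arange(10)
    new = get_timestep_embedding(timesteps, 32, flip_sin_to_cos=True, downscale_freq_shift=0)
    assert torch.allclose(new, old_timestep_embedding(timesteps, 32), atol=1e-5)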
src/diffusers/models/unet_grad_tts.py

-import math
 import torch

@@ -11,6 +9,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding


 class Mish(torch.nn.Module):

@@ -107,21 +106,6 @@ class Residual(torch.nn.Module):
         return output


-class SinusoidalPosEmb(torch.nn.Module):
-    def __init__(self, dim):
-        super(SinusoidalPosEmb, self).__init__()
-        self.dim = dim
-
-    def forward(self, x, scale=1000):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
-        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb


 class UNetGradTTSModel(ModelMixin, ConfigMixin):
     def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
         super(UNetGradTTSModel, self).__init__()

@@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
         )
-        self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

         dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]

@@ -198,8 +181,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         if not isinstance(spk, type(None)):
             s = self.spk_mlp(spk)

-        t = self.time_pos_emb(timesteps, scale=self.pe_scale)
+        t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
         t = self.mlp(t)

         if self.n_spks < 2:
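
Grad-TTS now builds its time embedding with get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale) instead of its own SinusoidalPosEmb module. With the default downscale_freq_shift=1 the exponent denominator is half_dim - 1, i.e. the same math.log(10000) / (half_dim - 1) spacing the removed class used, and scale multiplies the positions exactly as pe_scale did. A hedged sketch of that correspondence (the old forward is copied from the class removed in embeddings.py above):

    import math
    import torch
    from diffusers.models.embeddings import get_timestep_embedding

    def old_sinusoidal_pos_emb(x, dim, scale=1000):
        # forward of the removed SinusoidalPosEmb, reproduced from the diff above
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        return torch.cat((emb.sin(), emb.cos()), dim=-1)

    t = torch.rand(4)  # Grad-TTS uses continuous diffusion times in [0, 1]
    new = get_timestep_embedding(t, 64, scale=1000)
    assert torch.allclose(new, old_sinusoidal_pos_emb(t, 64, scale=1000), atol=1e-5)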
src/diffusers/models/unet_ldm.py

@@ -317,27 +317,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


-#def timestep_embedding(timesteps, dim, max_period=10000):
-# """
-# Create sinusoidal timestep embeddings.
-#
-# :param timesteps: a 1-D Tensor of N indices, one per batch element.
-# These may be fractional.
-# :param dim: the dimension of the output.
-# :param max_period: controls the minimum frequency of the embeddings.
-# :return: an [N x dim] Tensor of positional embeddings.
-# """
-# half = dim // 2
-# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-#     device=timesteps.device
-# )
-# args = timesteps[:, None].float() * freqs[None]
-# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-# if dim % 2:
-# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-# return embedding


 ## go
 class AttentionPool2d(nn.Module):
     """

@@ -1232,7 +1211,9 @@ class EncoderUNetModel(nn.Module):
         :param timesteps: a 1-D batch of timesteps.
         :return: an [N x K] Tensor of outputs.
         """
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         results = []
         h = x.type(self.dtype)
src/diffusers/models/unet_sde_score_estimation.py

@@ -382,23 +382,6 @@ def get_act(nonlinearity):
         raise NotImplementedError("activation function does not exist!")


-#def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-# assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-# half_dim = embedding_dim // 2
-# # magic number 10000 is from transformers
-# emb = math.log(max_positions) / (half_dim - 1)
-# # emb = math.log(2.) / (half_dim - 1)
-# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-# # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-# # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-# emb = timesteps.float()[:, None] * emb[None, :]
-# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-# if embedding_dim % 2 == 1:  # zero pad
-# emb = F.pad(emb, (0, 1), mode="constant")
-# assert emb.shape == (timesteps.shape[0], embedding_dim)
-# return emb


 def default_init(scale=1.0):
     """The same initialization used in DDPM."""
     scale = 1e-10 if scale == 0 else scale
tests/test_layers_utils.py

@@ -21,8 +21,7 @@ import unittest
 import numpy as np
 import torch

-#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding
-from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding
+from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -30,15 +29,87 @@ torch.backends.cuda.matmul.allow_tf32 = False
 class EmbeddingsTests(unittest.TestCase):
     def test_timestep_embeddings(self):
         embedding_dim = 256
         timesteps = torch.arange(16)

         t1 = get_timestep_embedding(timesteps, embedding_dim)

+        # first vector should always be composed only of 0's and 1's
+        assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5
+        assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5
+
+        # last element of each vector should be one
+        assert (t1[:, -1] - 1).abs().sum() < 1e-5
+
+        # For large embeddings (e.g. 128) the frequency of every vector is higher
+        # than the previous one which means that the gradients of later vectors are
+        # ALWAYS higher than the previous ones
+        grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1)
+
+        prev_grad = 0.0
+        for grad in grad_mean:
+            assert grad > prev_grad
+            prev_grad = grad
+
+    def test_timestep_defaults(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim)
-        t2 = timestep_embedding(timesteps, embedding_dim)
-        t3 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_factor=8)
+        t2 = get_timestep_embedding(
+            timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
+        )
+
+        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
+
+    def test_timestep_flip_sin_cos(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True)
+        t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1)
+
+        t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False)
+
+        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
+
+    def test_timestep_downscale_freq_shift(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0)
+        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1)
+
+        # get cosine half (vectors that are wrapped into cosine)
+        cosine_half = (t1 - t2)[:, embedding_dim // 2 :]
+
+        # cosine needs to be negative
+        assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5
+        import ipdb; ipdb.set_trace()
+
+    def test_sinoid_embeddings_hardcoded(self):
+        embedding_dim = 64
+        timesteps = torch.arange(128)
+
+        # standard unet, score_vde
+        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False)
+        # glide, ldm
+        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True)
+        # grad-tts
+        t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000)
+
+        assert torch.allclose(
+            t1[23:26, 47:50].flatten().cpu(),
+            torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]),
+            1e-3,
+        )
+        assert torch.allclose(
+            t2[23:26, 47:50].flatten().cpu(),
+            torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]),
+            1e-3,
+        )
+        assert torch.allclose(
+            t3[23:26, 47:50].flatten().cpu(),
+            torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
+            1e-3,
+        )
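
Assuming the repository's usual pytest setup, the new embedding tests can be run selectively with something like:

    python -m pytest tests/test_layers_utils.py::EmbeddingsTests -s

Note that test_timestep_downscale_freq_shift still contains an ipdb.set_trace() call, so as committed it will drop into a debugger prompt (or error out if ipdb is not installed).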