Commit 9d32a265 authored by Patrick von Platen

save intermediate

parent 4e3f4a9e
@@ -55,46 +55,58 @@ class DDIM(DiffusionPipeline):
         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
             # 1. predict noise residual
             with torch.no_grad():
-                pred_noise_t = self.unet(image, inference_step_times[t])
-
-            # 2. get actual t and t-1
-            train_step = inference_step_times[t]
-            prev_train_step = inference_step_times[t - 1] if t > 0 else -1
-
-            # 3. compute alphas, betas
-            alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
-            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
-            beta_prod_t = 1 - alpha_prod_t
-            beta_prod_t_prev = 1 - alpha_prod_t_prev
+                residual = self.unet(image, inference_step_times[t])
+
+            # 2. predict previous mean of image x_t-1
+            pred_prev_image = self.noise_scheduler.predict_prev_image_step(residual, image, t, num_inference_steps, eta)
+
+            # 3. optionally sample variance
+            variance = 0
+            if eta > 0:
+                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                variance = self.noise_scheduler.get_variance(t).sqrt() * eta * noise
+
+            # 4. set current image to prev_image: x_t -> x_t-1
+            image = pred_prev_image + variance
+
+            # 2. get actual t and t-1
+            # train_step = inference_step_times[t]
+            # prev_train_step = inference_step_times[t - 1] if t > 0 else -1
+            #
+            # 3. compute alphas, betas
+            # alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
+            # alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
+            # beta_prod_t = 1 - alpha_prod_t
+            # beta_prod_t_prev = 1 - alpha_prod_t_prev
+            #
             # 4. Compute predicted previous image from predicted noise
             # First: compute predicted original image from predicted noise also called
             # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
+            # pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
+            #
             # Second: Clip "predicted x_0"
-            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+            # pred_original_image = torch.clamp(pred_original_image, -1, 1)
+            #
             # Third: Compute variance: "sigma_t(η)" -> see formula (16)
             # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-            std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
-            std_dev_t = eta * std_dev_t
+            # std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
+            # std_dev_t = eta * std_dev_t
+            #
             # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
+            # pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
+            #
             # Fifth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
+            # pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
+            #
             # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
             # Note: eta = 1.0 essentially corresponds to DDPM
-            if eta > 0.0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
-                prev_image = pred_prev_image + std_dev_t * noise
-            else:
-                prev_image = pred_prev_image
+            # if eta > 0.0:
+            #     noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+            #     prev_image = pred_prev_image + std_dev_t * noise
+            # else:
+            #     prev_image = pred_prev_image
+            #
             # 6. Set current image to prev_image: x_t -> x_t-1
-            image = prev_image
+            # image = prev_image
 
         return image
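For reference, the arithmetic in the commented-out block is equation (12) of the DDIM paper (where the paper's $\alpha_t$ is this code's cumulative product `alpha_prod_t`), with $\sigma_t$ taken from equation (16):

$$x_{t-1} = \sqrt{\alpha_{t-1}}\left(\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\alpha_t}}\right) + \sqrt{1-\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t) + \sigma_t\,\epsilon_t$$

$$\sigma_t(\eta) = \eta\,\sqrt{\frac{1-\alpha_{t-1}}{1-\alpha_t}}\,\sqrt{1-\frac{\alpha_t}{\alpha_{t-1}}}$$

The three terms map to the code's `alpha_prod_t_prev.sqrt() * pred_original_image`, `pred_image_direction`, and `std_dev_t * noise`; η = 0 makes sampling deterministic, and η = 1 essentially corresponds to DDPM, as the "Note" comment says.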
@@ -38,50 +38,23 @@ class DDPM(DiffusionPipeline):
             generator=generator,
         )
 
-        for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
+        num_prediction_steps = len(self.noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             # 1. predict noise residual
             with torch.no_grad():
-                pred_noise_t = self.unet(image, t)
-
-            # 2. compute alphas, betas
-            alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
-            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1)
-            beta_prod_t = 1 - alpha_prod_t
-            beta_prod_t_prev = 1 - alpha_prod_t_prev
-
-            # 3. compute predicted image from residual
-            # First: compute predicted original image from predicted noise also called
-            # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-            pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
-
-            # Second: Clip "predicted x_0"
-            pred_original_image = torch.clamp(pred_original_image, -1, 1)
-
-            # Third: Compute coefficients for pred_original_image x_0 and current image x_t
-            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-            pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
-            current_image_coeff = self.noise_scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
-
-            # Fourth: Compute predicted previous image µ_t
-            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-            pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
+                residual = self.unet(image, t)
+
+            # 2. predict previous mean of image x_t-1
+            pred_prev_image = self.noise_scheduler.predict_prev_image_step(residual, image, t)
 
-            # 5. For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
-            # and sample from it to get previous image
-            # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
+            # 3. optionally sample variance
+            variance = 0
             if t > 0:
-                variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t).sqrt()
-                # TODO(PVP):
-                # This variance seems to be good enough for inference - check if those `fix_small`, `fix_large`
-                # are really only needed for training or also for inference
-                # Also note LDM only uses "fixed_small";
-                # glide seems to use a weird mix of the two: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
                 noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
-                sampled_variance = variance * noise
-                prev_image = pred_prev_image + sampled_variance
-            else:
-                prev_image = pred_prev_image
+                variance = self.noise_scheduler.get_variance(t) * noise
 
-            # 6. Set current image to prev_image: x_t -> x_t-1
-            image = prev_image
+            # 4. set current image to prev_image: x_t -> x_t-1
+            image = pred_prev_image + variance
 
         return image
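For reference, the scheduler call now encapsulates formulas (15) and (7) of the DDPM paper, i.e. the predicted original image and the posterior mean and variance:

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}$$

$$\tilde\mu_t = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat{x}_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t, \qquad \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t$$

`pred_prev_image` is $\tilde\mu_t$, and for $t > 0$ the pipeline samples $x_{t-1} \sim \mathcal{N}(\tilde\mu_t, \tilde\beta_t)$ by adding the scaled noise.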
@@ -11,4 +11,5 @@ from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
+from .schedulers.ddim import DDIMScheduler
 from .schedulers.glide_ddim import GlideDDIMScheduler
@@ -18,4 +18,5 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .gaussian_ddpm import GaussianDDPMScheduler
+from .ddim import DDIMScheduler
 from .glide_ddim import GlideDDIMScheduler
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
from torch import nn

from ..configuration_utils import ConfigMixin
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule


SAMPLING_CONFIG_NAME = "scheduler_config.json"


class DDIMScheduler(nn.Module, ConfigMixin):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        clip_predicted_image=True,
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
        )
        self.num_timesteps = int(timesteps)
        self.clip_image = clip_predicted_image

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        elif beta_schedule == "squaredcos_cap_v2":
            # GLIDE cosine schedule
            betas = betas_for_alpha_bar(
                timesteps,
                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))

        # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        # TODO(PVP) - check how much of these is actually necessary!
        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
        # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # if variance_type == "fixed_small":
        #     log_variance = torch.log(variance.clamp(min=1e-20))
        # elif variance_type == "fixed_large":
        #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
        #
        # self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        return self.alphas[time_step]

    def get_beta(self, time_step):
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def get_variance(self, t, num_inference_steps):
        orig_t = (self.num_timesteps // num_inference_steps) * t
        orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1

        alpha_prod_t = self.get_alpha_prod(orig_t)
        alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        return variance

    def predict_prev_image_step(self, residual, image, t, num_inference_steps, eta, output_pred_x_0=False):
        # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand this.
        # Notation (<variable name> -> <name in paper>):
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointing to x_t"
        # - pred_prev_image -> "x_t-1"

        # 1. get actual t and t-1
        orig_t = (self.num_timesteps // num_inference_steps) * t
        orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1
        # train_step = inference_step_times[t]
        # prev_train_step = inference_step_times[t - 1] if t > 0 else -1

        # 2. compute alphas, betas
        alpha_prod_t = self.get_alpha_prod(orig_t)
        alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original image from predicted noise, also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()

        # 4. Clip "predicted x_0"
        if self.clip_image:
            pred_original_image = torch.clamp(pred_original_image, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self.get_variance(t, num_inference_steps)
        std_dev_t = eta * variance.sqrt()

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * residual

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction

        return pred_prev_image

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
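A minimal sketch of how this new scheduler is driven, mirroring the DDIM pipeline loop in the first hunk; the random tensor standing in for the UNet's predicted noise residual is a placeholder assumption for illustration only:

import torch

from diffusers import DDIMScheduler  # exported by the __init__ changes above

scheduler = DDIMScheduler(timesteps=1000)
num_inference_steps = 50  # maps to training timesteps in strides of 1000 // 50 = 20
eta = 0.0  # deterministic DDIM sampling, no variance term

image = torch.randn(1, 3, 32, 32, generator=torch.manual_seed(0))
for t in reversed(range(num_inference_steps)):
    residual = torch.randn_like(image)  # placeholder for self.unet(image, ...)
    # returns the predicted x_{t-1} mean; with eta > 0 the pipeline adds
    # scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise on top
    image = scheduler.predict_prev_image_step(residual, image, t, num_inference_steps, eta)

With eta=0.0 the σ_t term vanishes entirely, which is the path the updated DDIM test at the bottom of this commit exercises.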
@@ -34,6 +34,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         beta_end=0.02,
         beta_schedule="linear",
         variance_type="fixed_small",
+        clip_predicted_image=True,
     ):
         super().__init__()
         self.register(
@@ -42,8 +43,10 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
             beta_end=beta_end,
             beta_schedule=beta_schedule,
             variance_type=variance_type,
+            clip_predicted_image=clip_predicted_image,
         )
         self.num_timesteps = int(timesteps)
+        self.clip_image = clip_predicted_image
 
         if beta_schedule == "linear":
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
@@ -58,23 +61,23 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         alphas = 1.0 - betas
         alphas_cumprod = torch.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
-
-        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-        # TODO(PVP) - check how much of these is actually necessary!
-        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
-        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
-        if variance_type == "fixed_small":
-            log_variance = torch.log(variance.clamp(min=1e-20))
-        elif variance_type == "fixed_large":
-            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
 
         self.register_buffer("betas", betas.to(torch.float32))
         self.register_buffer("alphas", alphas.to(torch.float32))
         self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
-        self.register_buffer("log_variance", log_variance.to(torch.float32))
+
+        # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
+        # TODO(PVP) - check how much of these is actually necessary!
+        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
+        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
+        # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+        # if variance_type == "fixed_small":
+        #     log_variance = torch.log(variance.clamp(min=1e-20))
+        # elif variance_type == "fixed_large":
+        #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
+        #
+        # self.register_buffer("log_variance", log_variance.to(torch.float32))
 
     def get_alpha(self, time_step):
         return self.alphas[time_step]
@@ -87,6 +90,43 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
             return torch.tensor(1.0)
         return self.alphas_cumprod[time_step]
 
+    def get_variance(self, t):
+        alpha_prod_t = self.get_alpha_prod(t)
+        alpha_prod_t_prev = self.get_alpha_prod(t - 1)
+
+        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+        # and sample from it to get previous image
+        # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t).sqrt()
+        return variance
+
+    def predict_prev_image_step(self, residual, image, t, output_pred_x_0=False):
+        # 1. compute alphas, betas
+        alpha_prod_t = self.get_alpha_prod(t)
+        alpha_prod_t_prev = self.get_alpha_prod(t - 1)
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        # 2. compute predicted original image from predicted noise, also called
+        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+        pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()
+
+        # 3. Clip "predicted x_0"
+        if self.clip_image:
+            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+
+        # 4. Compute coefficients for pred_original_image x_0 and current image x_t
+        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+        pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.get_beta(t)) / beta_prod_t
+        current_image_coeff = self.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
+
+        # 5. Compute predicted previous image µ_t
+        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+        pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
+
+        return pred_prev_image
+
     def sample_noise(self, shape, device, generator=None):
         # always sample on CPU to be deterministic
         return torch.randn(shape, generator=generator).to(device)
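The refactored DDPM pipeline above drives these two new methods the same way; a minimal sketch, assuming the constructor defaults shown (a `timesteps` default of 1000 is an assumption, since it is outside the hunk), again with a random tensor in place of the real UNet call:

import torch

from diffusers import GaussianDDPMScheduler

scheduler = GaussianDDPMScheduler()  # assumed default: timesteps=1000

image = torch.randn(1, 3, 32, 32, generator=torch.manual_seed(0))
for t in reversed(range(len(scheduler))):
    residual = torch.randn_like(image)  # placeholder for self.unet(image, t)
    # posterior mean of x_{t-1} per formula (7) of the DDPM paper
    pred_prev_image = scheduler.predict_prev_image_step(residual, image, t)
    variance = 0
    if t > 0:  # no noise is added at the final step
        noise = scheduler.sample_noise(image.shape, device=image.device)
        variance = scheduler.get_variance(t) * noise
    image = pred_prev_image + variance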
@@ -22,7 +22,7 @@ from distutils.util import strtobool
 import torch
 
-from diffusers import GaussianDDPMScheduler, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel, DDIMScheduler
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from models.vision.ddim.modeling_ddim import DDIM
@@ -304,7 +304,10 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
 
         model_id = "fusing/ddpm-cifar10"
-        ddpm = DDPM.from_pretrained(model_id)
+        unet = UNetModel.from_pretrained(model_id)
+        noise_scheduler = GaussianDDPMScheduler.from_config(model_id)
+        ddpm = DDPM(unet=unet, noise_scheduler=noise_scheduler)
 
         image = ddpm(generator=generator)
         image_slice = image[0, -1, -3:, -3:].cpu()
@@ -318,7 +321,10 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
 
         model_id = "fusing/ddpm-cifar10"
-        ddim = DDIM.from_pretrained(model_id)
+        unet = UNetModel.from_pretrained(model_id)
+        noise_scheduler = DDIMScheduler()
+        ddim = DDIM(unet=unet, noise_scheduler=noise_scheduler)
 
         image = ddim(generator=generator, eta=0.0)
         image_slice = image[0, -1, -3:, -3:].cpu()