"examples/production_monitoring/grafana.json" did not exist on "d0d93b92b190f420e2628350ec69921bede691d4"
Commit e2364931 authored by mashun1's avatar mashun1
Browse files

pixart-alpha

parents
Pipeline #861 canceled with stages
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: check https://arxiv.org/abs/2309.05019
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math
from typing import List, Optional, Tuple, Union, Callable
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
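# Hedged usage sketch (not part of the original file): for the default cosine
# transform the betas grow toward max_beta while alpha_bar decays monotonically
# from ~1 toward 0. `_demo_betas_for_alpha_bar` is a hypothetical helper added
# for illustration only.
def _demo_betas_for_alpha_bar():
    betas = betas_for_alpha_bar(1000, alpha_transform_type="cosine")
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    assert betas.shape == (1000,)
    assert float(betas.max()) <= 0.999
    # alpha_bar is strictly decreasing over the schedule
    assert torch.all(alphas_cumprod[1:] < alphas_cumprod[:-1])
    return betas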
class SASolverScheduler(SchedulerMixin, ConfigMixin):
"""
`SASolverScheduler` is a fast dedicated high-order solver for diffusion SDEs.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
predictor_order (`int`, defaults to 2):
            The predictor order, which can be `1`, `2`, `3`, or `4`. It is recommended to use `predictor_order=2` for
            guided sampling, and `predictor_order=3` for unconditional sampling.
corrector_order (`int`, defaults to 2):
            The corrector order, which can be `1`, `2`, `3`, or `4`. It is recommended to use `corrector_order=2` for
            guided sampling, and `corrector_order=3` for unconditional sampling.
predictor_corrector_mode (`str`, defaults to `PEC`):
            The predictor-corrector mode; can be `PEC` or `PECE`. It is recommended to use `PEC` mode for fast
            sampling, and `PECE` for high-quality sampling (`PECE` needs around twice as many model evaluations as `PEC`).
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        tau_func (`Callable`, *optional*):
            The stochasticity schedule `tau(t)` of SA-Solver. With `tau(t) = 0` no noise is injected at step `t`
            (ODE-like behavior); larger values add stochasticity. Defaults to `1` for `200 <= t <= 800` and `0`
            otherwise.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
algorithm_type (`str`, defaults to `data_prediction`):
            Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use
            `data_prediction` with `predictor_order=2` for guided sampling like in Stable Diffusion.
lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
variance_type (`str`, *optional*):
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
contains the predicted Gaussian variance.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
predictor_order: int = 2,
corrector_order: int = 2,
predictor_corrector_mode: str = 'PEC',
prediction_type: str = "epsilon",
        tau_func: Callable = lambda t: 1 if 200 <= t <= 800 else 0,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "data_prediction",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Currently we only support VP-type noise schedule
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
if algorithm_type not in ["data_prediction", "noise_prediction"]:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.timestep_list = [None] * max(predictor_order, corrector_order - 1)
self.model_outputs = [None] * max(predictor_order, corrector_order - 1)
self.tau_func = tau_func
self.predict_x0 = algorithm_type == "data_prediction"
self.lower_order_nums = 0
self.last_sample = None
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64)
self.sigmas = torch.from_numpy(sigmas)
# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
self.timesteps = torch.from_numpy(timesteps).to(device)
self.num_inference_steps = len(timesteps)
self.model_outputs = [
None,
] * max(self.config.predictor_order, self.config.corrector_order - 1)
self.lower_order_nums = 0
self.last_sample = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.to(dtype)
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
<Tip>
The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
prediction and data prediction models.
</Tip>
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.FloatTensor`:
The converted model output.
"""
# SA-Solver_data_prediction needs to solve an integral of the data prediction model.
if self.config.algorithm_type in ["data_prediction"]:
if self.config.prediction_type == "epsilon":
# SA-Solver only needs the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the SASolverScheduler."
)
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
# SA-Solver_noise_prediction needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type in ["noise_prediction"]:
if self.config.prediction_type == "epsilon":
# SA-Solver only needs the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]:
epsilon = model_output[:, :3]
else:
epsilon = model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the SASolverScheduler."
)
if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t
return epsilon
def get_coefficients_exponential_negative(self, order, interval_start, interval_end):
"""
Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end
"""
assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
if order == 0:
return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1)
elif order == 1:
return torch.exp(-interval_end) * (
(interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1))
elif order == 2:
return torch.exp(-interval_end) * (
(interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - (
interval_end ** 2 + 2 * interval_end + 2))
elif order == 3:
return torch.exp(-interval_end) * (
(interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6) * torch.exp(
interval_end - interval_start) - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6))
def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau):
"""
Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
"""
assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
        # after the change of variables (cov)
interval_end_cov = (1 + tau ** 2) * interval_end
interval_start_cov = (1 + tau ** 2) * interval_start
if order == 0:
return torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / (
(1 + tau ** 2))
elif order == 1:
return torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(
-(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2)
elif order == 2:
return torch.exp(interval_end_cov) * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - (
interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp(
-(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 3)
elif order == 3:
return torch.exp(interval_end_cov) * (
(interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - (
interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp(
-(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 4)
def lagrange_polynomial_coefficient(self, order, lambda_list):
"""
        Calculate the coefficients of the Lagrange polynomial
"""
assert order in [0, 1, 2, 3]
assert order == len(lambda_list) - 1
if order == 0:
return [[1]]
elif order == 1:
return [[1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
[1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
elif order == 2:
denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
return [[1 / denominator1,
(-lambda_list[1] - lambda_list[2]) / denominator1,
lambda_list[1] * lambda_list[2] / denominator1],
[1 / denominator2,
(-lambda_list[0] - lambda_list[2]) / denominator2,
lambda_list[0] * lambda_list[2] / denominator2],
[1 / denominator3,
(-lambda_list[0] - lambda_list[1]) / denominator3,
lambda_list[0] * lambda_list[1] / denominator3]
]
elif order == 3:
denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (
lambda_list[0] - lambda_list[3])
denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (
lambda_list[1] - lambda_list[3])
denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (
lambda_list[2] - lambda_list[3])
denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (
lambda_list[3] - lambda_list[2])
return [[1 / denominator1,
(-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
(lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[
3]) / denominator1,
(-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1],
[1 / denominator2,
(-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
(lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[
3]) / denominator2,
(-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],
[1 / denominator3,
(-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
(lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[
3]) / denominator3,
(-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],
[1 / denominator4,
(-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
(lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[
2]) / denominator4,
(-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
]
def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau):
assert order in [1, 2, 3, 4]
assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
coefficients = []
lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list)
for i in range(order):
coefficient = sum(
lagrange_coefficient[i][j]
* self.get_coefficients_exponential_positive(
order - 1 - j, interval_start, interval_end, tau
)
if self.predict_x0
else lagrange_coefficient[i][j]
* self.get_coefficients_exponential_negative(
order - 1 - j, interval_start, interval_end
)
for j in range(order)
)
coefficients.append(coefficient)
assert len(coefficients) == order, 'the length of coefficients does not match the order'
return coefficients
def stochastic_adams_bashforth_update(
self,
model_output: torch.FloatTensor,
prev_timestep: int,
sample: torch.FloatTensor,
noise: torch.FloatTensor,
order: int,
tau: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the SA-Predictor.
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
            noise (`torch.FloatTensor`):
                The Gaussian noise injected at this predictor step.
            order (`int`):
                The order of SA-Predictor at this timestep.
            tau (`torch.FloatTensor`):
                The stochasticity parameter `tau(t)` of SA-Solver at this timestep.
Returns:
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
assert noise is not None
timestep_list = self.timestep_list
model_output_list = self.model_outputs
s0, t = self.timestep_list[-1], prev_timestep
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
gradient_part = torch.zeros_like(sample)
h = lambda_t - lambda_s0
lambda_list = [self.lambda_t[timestep_list[-(i + 1)]] for i in range(order)]
gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)
x = sample
if self.predict_x0 and order == 2:
gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
(1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[
timestep_list[-2]])
gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
(1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[
timestep_list[-2]])
for i in range(order):
if self.predict_x0:
gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
i] * model_output_list[-(i + 1)]
else:
gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_output_list[-(i + 1)]
if self.predict_x0:
noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
else:
noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
if self.predict_x0:
x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
else:
x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part
x_t = x_t.to(x.dtype)
return x_t
def stochastic_adams_moulton_update(
self,
this_model_output: torch.FloatTensor,
this_timestep: int,
last_sample: torch.FloatTensor,
last_noise: torch.FloatTensor,
this_sample: torch.FloatTensor,
order: int,
tau: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the SA-Corrector.
Args:
this_model_output (`torch.FloatTensor`):
The model outputs at `x_t`.
this_timestep (`int`):
The current timestep `t`.
last_sample (`torch.FloatTensor`):
The generated sample before the last predictor `x_{t-1}`.
            last_noise (`torch.FloatTensor`):
                The noise injected at the last predictor step.
            this_sample (`torch.FloatTensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The order of SA-Corrector at this step.
            tau (`torch.FloatTensor`):
                The stochasticity parameter `tau(t)` of SA-Solver at this timestep.
Returns:
`torch.FloatTensor`:
The corrected sample tensor at the current timestep.
"""
assert last_noise is not None
timestep_list = self.timestep_list
model_output_list = self.model_outputs
s0, t = self.timestep_list[-1], this_timestep
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
gradient_part = torch.zeros_like(this_sample)
h = lambda_t - lambda_s0
t_list = timestep_list + [this_timestep]
lambda_list = [self.lambda_t[t_list[-(i + 1)]] for i in range(order)]
model_prev_list = model_output_list + [this_model_output]
gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)
x = last_sample
if self.predict_x0 and order == 2:
gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
(1 + tau ** 2) ** 2 * h))
gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
(1 + tau ** 2) ** 2 * h))
for i in range(order):
if self.predict_x0:
gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
i] * model_prev_list[-(i + 1)]
else:
gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
if self.predict_x0:
noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * last_noise
else:
noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * last_noise
if self.predict_x0:
x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
else:
x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part
x_t = x_t.to(x.dtype)
return x_t
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the SA-Solver.
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
use_corrector = (
step_index > 0 and self.last_sample is not None
)
model_output_convert = self.convert_model_output(model_output, timestep, sample)
if use_corrector:
current_tau = self.tau_func(self.timestep_list[-1])
sample = self.stochastic_adams_moulton_update(
this_model_output=model_output_convert,
this_timestep=timestep,
last_sample=self.last_sample,
last_noise=self.last_noise,
this_sample=sample,
order=self.this_corrector_order,
tau=current_tau,
)
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
for i in range(max(self.config.predictor_order, self.config.corrector_order - 1) - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
if self.config.lower_order_final:
this_predictor_order = min(self.config.predictor_order, len(self.timesteps) - step_index)
this_corrector_order = min(self.config.corrector_order, len(self.timesteps) - step_index + 1)
else:
this_predictor_order = self.config.predictor_order
this_corrector_order = self.config.corrector_order
self.this_predictor_order = min(this_predictor_order, self.lower_order_nums + 1) # warmup for multistep
self.this_corrector_order = min(this_corrector_order, self.lower_order_nums + 2) # warmup for multistep
assert self.this_predictor_order > 0
assert self.this_corrector_order > 0
self.last_sample = sample
self.last_noise = noise
current_tau = self.tau_func(self.timestep_list[-1])
prev_sample = self.stochastic_adams_bashforth_update(
model_output=model_output_convert,
prev_timestep=prev_timestep,
sample=sample,
noise=noise,
order=self.this_predictor_order,
tau=current_tau,
)
if self.lower_order_nums < max(self.config.predictor_order, self.config.corrector_order - 1):
self.lower_order_nums += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
The input sample.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
def __len__(self):
return self.config.num_train_timesteps
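# Hedged usage sketch (assumptions: `model` is an epsilon-prediction network
# with signature model(sample, t) -> tensor; not part of the original file).
# Shows the standard set_timesteps/step sampling loop for this scheduler.
def _demo_sa_solver_sampling(model, shape, num_inference_steps=25, device="cpu"):
    scheduler = SASolverScheduler()
    scheduler.set_timesteps(num_inference_steps, device=device)
    sample = torch.randn(shape, device=device) * scheduler.init_noise_sigma
    for t in scheduler.timesteps:
        model_output = model(sample, t)
        # step() runs the SA-Corrector (once warmed up) and then the SA-Predictor
        sample = scheduler.step(model_output, t, sample).prev_sample
    return sample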
import os
import re
import torch
from diffusion.utils.logger import get_root_logger
def save_checkpoint(work_dir,
epoch,
model,
model_ema=None,
optimizer=None,
lr_scheduler=None,
keep_last=False,
step=None,
):
os.makedirs(work_dir, exist_ok=True)
state_dict = dict(state_dict=model.state_dict())
if model_ema is not None:
state_dict['state_dict_ema'] = model_ema.state_dict()
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None:
state_dict['scheduler'] = lr_scheduler.state_dict()
if epoch is not None:
state_dict['epoch'] = epoch
file_path = os.path.join(work_dir, f"epoch_{epoch}.pth")
if step is not None:
file_path = file_path.split('.pth')[0] + f"_step_{step}.pth"
logger = get_root_logger()
torch.save(state_dict, file_path)
    logger.info(f'Saved checkpoint of epoch {epoch} to {file_path}.')
    if keep_last:
        # remove checkpoints from earlier epochs (step-suffixed files are not matched here)
        for i in range(epoch):
            previous_ckpt = os.path.join(work_dir, f"epoch_{i}.pth")
            if os.path.exists(previous_ckpt):
                os.remove(previous_ckpt)
def load_checkpoint(checkpoint,
model,
model_ema=None,
optimizer=None,
lr_scheduler=None,
load_ema=False,
resume_optimizer=True,
resume_lr_scheduler=True
):
assert isinstance(checkpoint, str)
ckpt_file = checkpoint
checkpoint = torch.load(ckpt_file, map_location="cpu")
state_dict_keys = ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']
for key in state_dict_keys:
if key in checkpoint['state_dict']:
del checkpoint['state_dict'][key]
if 'state_dict_ema' in checkpoint and key in checkpoint['state_dict_ema']:
del checkpoint['state_dict_ema'][key]
break
if load_ema:
state_dict = checkpoint['state_dict_ema']
else:
state_dict = checkpoint.get('state_dict', checkpoint) # to be compatible with the official checkpoint
# model.load_state_dict(state_dict)
missing, unexpect = model.load_state_dict(state_dict, strict=False)
if model_ema is not None:
model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
if optimizer is not None and resume_optimizer:
optimizer.load_state_dict(checkpoint['optimizer'])
if lr_scheduler is not None and resume_lr_scheduler:
lr_scheduler.load_state_dict(checkpoint['scheduler'])
logger = get_root_logger()
if optimizer is not None:
        epoch = checkpoint.get('epoch', re.match(r'.*epoch_(\d+).*\.pth', ckpt_file).group(1))
logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
return epoch, missing, unexpect
logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
return missing, unexpect
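# Hedged usage sketch (hypothetical paths, for illustration only): saving at the
# end of an epoch and resuming later. When an optimizer is passed,
# load_checkpoint returns (epoch, missing_keys, unexpected_keys).
def _demo_checkpoint_roundtrip(model, optimizer, lr_scheduler):
    save_checkpoint('output/ckpts', epoch=3, model=model,
                    optimizer=optimizer, lr_scheduler=lr_scheduler)
    epoch, missing, unexpected = load_checkpoint(
        'output/ckpts/epoch_3.pth', model,
        optimizer=optimizer, lr_scheduler=lr_scheduler)
    return epoch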
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Sequence
from torch.utils.data import BatchSampler, Sampler, Dataset
from random import shuffle, choice
from copy import deepcopy
from diffusion.utils.logger import get_root_logger
class AspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
Args:
sampler (Sampler): Base sampler.
dataset (Dataset): Dataset providing data information.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
aspect_ratios (dict): The predefined aspect ratios.
"""
def __init__(self,
sampler: Sampler,
dataset: Dataset,
batch_size: int,
aspect_ratios: dict,
drop_last: bool = False,
config=None,
valid_num=0, # take as valid aspect-ratio when sample number >= valid_num
**kwargs) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.dataset = dataset
self.batch_size = batch_size
self.aspect_ratios = aspect_ratios
self.drop_last = drop_last
self.ratio_nums_gt = kwargs.get('ratio_nums', None)
self.config = config
assert self.ratio_nums_gt
# buckets for each aspect ratio
self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num]
logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
logger.warning(f"Using valid_num={valid_num} in config file. Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.dataset.get_data_info(idx)
height, width = data_info['height'], data_info['width']
ratio = height / width
# find the closest aspect ratio
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
if closest_ratio not in self.current_available_bucket_keys:
continue
bucket = self._aspect_ratio_buckets[closest_ratio]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the buckets
for bucket in self._aspect_ratio_buckets.values():
while len(bucket) > 0:
if len(bucket) <= self.batch_size:
if not self.drop_last:
yield bucket[:]
                    del bucket[:]
else:
yield bucket[:self.batch_size]
                    del bucket[:self.batch_size]
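# Hedged usage sketch (not part of the original file): wiring the sampler into
# a DataLoader. Assumes `dataset` implements get_data_info(idx) returning a
# dict with 'height' and 'width', and that `aspect_ratios`/`ratio_nums` come
# from the training config, as elsewhere in this codebase.
def _demo_aspect_ratio_loader(dataset, aspect_ratios, ratio_nums, batch_size=16):
    from torch.utils.data import DataLoader, RandomSampler
    batch_sampler = AspectRatioBatchSampler(
        sampler=RandomSampler(dataset), dataset=dataset, batch_size=batch_size,
        aspect_ratios=aspect_ratios, ratio_nums=ratio_nums)
    return DataLoader(dataset, batch_sampler=batch_sampler)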
class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Assign samples to each bucket
self.ratio_nums_gt = kwargs.get('ratio_nums', None)
assert self.ratio_nums_gt
self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()}
self.original_buckets = {}
self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000]
self.all_available_keys = deepcopy(self.current_available_bucket_keys)
self.exhausted_bucket_keys = []
self.total_batches = len(self.sampler) // self.batch_size
self._aspect_ratio_count = {}
for k in self.all_available_keys:
self._aspect_ratio_count[float(k)] = 0
self.original_buckets[float(k)] = []
logger = get_root_logger(os.path.join(self.config.work_dir, 'train_log.log'))
logger.warning(f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")
def __iter__(self) -> Sequence[int]:
i = 0
for idx in self.sampler:
data_info = self.dataset.get_data_info(idx)
height, width = data_info['height'], data_info['width']
ratio = height / width
closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)))
if closest_ratio not in self.all_available_keys:
continue
if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]:
self._aspect_ratio_count[closest_ratio] += 1
self._aspect_ratio_buckets[closest_ratio].append(idx)
self.original_buckets[closest_ratio].append(idx) # Save the original samples for each bucket
if not self.current_available_bucket_keys:
self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, []
if closest_ratio not in self.current_available_bucket_keys:
continue
key = closest_ratio
bucket = self._aspect_ratio_buckets[key]
if len(bucket) == self.batch_size:
yield bucket[:self.batch_size]
del bucket[:self.batch_size]
i += 1
self.exhausted_bucket_keys.append(key)
self.current_available_bucket_keys.remove(key)
for _ in range(self.total_batches - i):
key = choice(self.all_available_keys)
bucket = self._aspect_ratio_buckets[key]
if len(bucket) >= self.batch_size:
yield bucket[:self.batch_size]
del bucket[:self.batch_size]
# If a bucket is exhausted
if not bucket:
self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
shuffle(self._aspect_ratio_buckets[key])
else:
self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
shuffle(self._aspect_ratio_buckets[key])
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import os
import pickle
import shutil
import gc
import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
def is_distributed():
return get_world_size() > 1
def get_world_size():
if not dist.is_available():
return 1
return dist.get_world_size() if dist.is_initialized() else 1
def get_rank():
if not dist.is_available():
return 0
return dist.get_rank() if dist.is_initialized() else 0
def get_local_rank():
if not dist.is_available():
return 0
return int(os.getenv('LOCAL_RANK', 0)) if dist.is_initialized() else 0
def is_master():
return get_rank() == 0
def is_local_master():
return get_local_rank() == 0
def get_local_proc_group(group_size=8):
world_size = get_world_size()
if world_size <= group_size or group_size == 1:
return None
assert world_size % group_size == 0, f'world size ({world_size}) should be evenly divided by group size ({group_size}).'
process_groups = getattr(get_local_proc_group, 'process_groups', {})
if group_size not in process_groups:
num_groups = dist.get_world_size() // group_size
groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)]
process_groups.update({group_size: [torch.distributed.new_group(group) for group in groups]})
get_local_proc_group.process_groups = process_groups
group_idx = get_rank() // group_size
return get_local_proc_group.process_groups.get(group_size)[group_idx]
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
to_device = torch.device("cuda")
# to_device = torch.device("cpu")
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to(to_device)
# obtain Tensor size of each rank
local_size = torch.LongTensor([tensor.numel()]).to(to_device)
size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
tensor_list = [
torch.ByteTensor(size=(max_size,)).to(to_device) for _ in size_list
]
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to(to_device)
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
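# Hedged usage sketch: all_gather works on arbitrary picklable objects (they are
# pickled into padded byte tensors), so each rank can exchange, e.g., a dict of
# per-rank metrics; every rank receives the same world_size-long list.
def _demo_all_gather_metrics(local_metrics):
    gathered = all_gather(local_metrics)  # list with one entry per rank
    merged = {}
    for part in gathered:  # assumes each entry is a dict
        merged.update(part)
    return merged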
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that process with rank
0 has the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
        reduced_dict = _reduce_dict_impl(input_dict, average, world_size)
return reduced_dict
def _reduce_dict_impl(input_dict, average, world_size):
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
# only main process gets accumulated, so only divide by
# world_size in this case
values /= world_size
return dict(zip(names, values))
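# Hedged usage sketch: averaging a dict of scalar loss tensors across ranks.
# With average=True only rank 0 divides by world_size, so log from rank 0.
def _demo_log_reduced_losses(loss, mse, logger):
    reduced = reduce_dict({'loss': loss.detach(), 'mse': mse.detach()})
    if is_master():
        logger.info(f"loss: {reduced['loss'].item():.4f}, "
                    f"mse: {reduced['mse'].item():.4f}")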
def broadcast(data, **kwargs):
if get_world_size() == 1:
return data
data = [data]
dist.broadcast_object_list(data, **kwargs)
return data[0]
def all_gather_cpu(result_part, tmpdir=None, collect_by_master=True):
rank, world_size = get_dist_info()
if tmpdir is None:
tmpdir = './tmp'
if rank == 0:
mmcv.mkdir_or_exist(tmpdir)
synchronize()
# dump the part result to the dir
mmcv.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl'))
synchronize()
if collect_by_master and rank != 0:
return None
# load results of all parts from tmp dir
results = []
for i in range(world_size):
part_file = os.path.join(tmpdir, f'part_{i}.pkl')
results.append(mmcv.load(part_file))
if not collect_by_master:
synchronize()
# remove tmp dir
if rank == 0:
shutil.rmtree(tmpdir)
return results
def all_gather_tensor(tensor, group_size=None, group=None):
if group_size is None:
group_size = get_world_size()
if group_size == 1:
output = [tensor]
else:
output = [torch.zeros_like(tensor) for _ in range(group_size)]
dist.all_gather(output, tensor, group=group)
return output
def gather_difflen_tensor(feat, num_samples_list, concat=True, group=None, group_size=None):
world_size = get_world_size()
if world_size == 1:
return feat if concat else [feat]
num_samples, *feat_dim = feat.size()
# padding to max number of samples
feat_padding = feat.new_zeros((max(num_samples_list), *feat_dim))
feat_padding[:num_samples] = feat
# gather
feat_gather = all_gather_tensor(feat_padding, group=group, group_size=group_size)
for r, num in enumerate(num_samples_list):
feat_gather[r] = feat_gather[r][:num]
if concat:
feat_gather = torch.cat(feat_gather)
return feat_gather
class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all processes, supporting backward propagation.
'''
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
num_samples = torch.tensor(input.size(0), dtype=torch.long, device=input.device)
ctx.num_samples_list = all_gather_tensor(num_samples)
output = gather_difflen_tensor(input, ctx.num_samples_list, concat=False)
return tuple(output)
@staticmethod
def backward(ctx, *grads): # tuple(output)'s grad
input, = ctx.saved_tensors
num_samples_list = ctx.num_samples_list
rank = get_rank()
start, end = sum(num_samples_list[:rank]), sum(num_samples_list[:rank + 1])
grads = torch.cat(grads)
if is_distributed():
dist.all_reduce(grads)
grad_out = torch.zeros_like(input)
grad_out[:] = grads[start:end]
        # forward() takes a single tensor input, so exactly one gradient is returned
        return grad_out
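# Hedged usage sketch: GatherLayer is a differentiable all-gather, useful for
# losses that need features from every rank (e.g. a contrastive loss) while
# still backpropagating into the local shard.
def _demo_differentiable_gather(local_features):
    gathered = GatherLayer.apply(local_features)  # tuple, one tensor per rank
    return torch.cat(gathered, dim=0)             # (total_samples, feat_dim)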
class GatherLayerWithGroup(torch.autograd.Function):
    '''Gather tensors from all processes, supporting backward propagation.
'''
@staticmethod
def forward(ctx, input, group, group_size):
ctx.save_for_backward(input)
ctx.group_size = group_size
output = all_gather_tensor(input, group=group, group_size=group_size)
return tuple(output)
@staticmethod
def backward(ctx, *grads): # tuple(output)'s grad
input, = ctx.saved_tensors
grads = torch.stack(grads)
if is_distributed():
dist.all_reduce(grads)
grad_out = torch.zeros_like(input)
grad_out[:] = grads[get_rank() % ctx.group_size]
return grad_out, None, None
def gather_layer_with_group(data, group=None, group_size=None):
if group_size is None:
group_size = get_world_size()
    return GatherLayerWithGroup.apply(data, group, group_size)
from typing import Union
import math
# from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_, _calc_grad_norm
@torch.no_grad()
def clip_grad_norm_(
self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
) -> torch.Tensor:
self._lazy_init()
self._wait_for_previous_optim_step()
assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
self._assert_state(TrainingState_.IDLE)
max_norm = float(max_norm)
norm_type = float(norm_type)
    # Computes the max norm for this shard's gradients and syncs it across workers
local_norm = _calc_grad_norm(self.params_with_grad, norm_type).cuda() # type: ignore[arg-type]
if norm_type == math.inf:
total_norm = local_norm
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
else:
total_norm = local_norm ** norm_type
dist.all_reduce(total_norm, group=self.process_group)
total_norm = total_norm ** (1.0 / norm_type)
clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
if clip_coef < 1:
# multiply by clip_coef, aka, (max_norm/total_norm).
for p in self.params_with_grad:
assert p.grad is not None
p.grad.detach().mul_(clip_coef.to(p.grad.device))
return total_norm
def flush():
gc.collect()
torch.cuda.empty_cache()
import logging
import os
import torch.distributed as dist
from datetime import datetime
from .dist_utils import is_local_master
from mmcv.utils.logging import logger_initialized
def get_root_logger(log_file=None, log_level=logging.INFO, name='PixArt'):
"""Get root logger.
Args:
log_file (str, optional): File path of log. Defaults to None.
log_level (int, optional): The level of logger.
Defaults to logging.INFO.
name (str): logger name
Returns:
:obj:`logging.Logger`: The obtained logger
"""
if log_file is None:
log_file = '/dev/null'
return get_logger(name=name, log_file=log_file, log_level=log_level)
def get_logger(name, log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified and the process rank is 0, a FileHandler
will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "ERROR" and thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
logger.propagate = False # disable root logger to avoid duplicate logging
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
# only rank 0 will add a FileHandler
if rank == 0 and log_file is not None:
file_handler = logging.FileHandler(log_file, 'w')
handlers.append(file_handler)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
# only rank0 for each node will print logs
log_level = log_level if is_local_master() else logging.ERROR
logger.setLevel(log_level)
logger_initialized[name] = True
return logger
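# Hedged usage sketch: typical setup at training start; only rank 0 writes the
# file handler, and non-local-master ranks log at ERROR level.
def _demo_logger_setup(work_dir):
    logger = get_root_logger(os.path.join(work_dir, 'train_log.log'))
    logger.info('Training started.')
    return logger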
def rename_file_with_creation_time(file_path):
    # get the file's creation time
creation_time = os.path.getctime(file_path)
creation_time_str = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d_%H-%M-%S')
    # build the new file name
dir_name, file_name = os.path.split(file_path)
name, ext = os.path.splitext(file_name)
new_file_name = f"{name}_{creation_time_str}{ext}"
new_file_path = os.path.join(dir_name, new_file_name)
    # rename the file
os.rename(file_path, new_file_path)
print(f"File renamed to: {new_file_path}")
from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
import math
from diffusion.utils.logger import get_root_logger
def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
if not config.get('lr_schedule_args', None):
config.lr_schedule_args = {}
if config.get('lr_warmup_steps', None):
config['num_warmup_steps'] = config.get('lr_warmup_steps') # for compatibility with old version
logger = get_root_logger()
logger.info(
f'Lr schedule: {config.lr_schedule}, ' + ",".join(
[f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
if config.lr_schedule == 'cosine':
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
**config.lr_schedule_args,
num_training_steps=(len(train_dataloader) * config.num_epochs),
)
elif config.lr_schedule == 'constant':
lr_scheduler = get_constant_schedule_with_warmup(
optimizer=optimizer,
**config.lr_schedule_args,
)
elif config.lr_schedule == 'cosine_decay_to_constant':
assert lr_scale_ratio >= 1
lr_scheduler = get_cosine_decay_to_constant_with_warmup(
optimizer=optimizer,
**config.lr_schedule_args,
final_lr=1 / lr_scale_ratio,
num_training_steps=(len(train_dataloader) * config.num_epochs),
)
else:
raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.')
return lr_scheduler
def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
final_lr: float = 0.0,
num_decay: float = 0.667,
num_cycles: float = 0.5,
last_epoch: int = -1
):
"""
Create a schedule with a cosine annealing lr followed by a constant lr.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The number of total training steps.
        final_lr (`float`):
            The final constant lr after the cosine decay, expressed as a fraction of the base lr.
        num_decay (`float`):
            The fraction of `num_training_steps` spent in the cosine-decay phase before the lr is held constant.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of cosine cycles within the decay phase.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
num_decay_steps = int(num_training_steps * num_decay)
if current_step > num_decay_steps:
return final_lr
progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps))
return (
max(
0.0,
0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)),
)
* (1 - final_lr)
) + final_lr
return LambdaLR(optimizer, lr_lambda, last_epoch)
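# Hedged usage sketch (hypothetical hyper-parameters): decay the lr with a half
# cosine over the first ~2/3 of training, then hold it at 25% of the base lr.
def _demo_cosine_decay_to_constant():
    import torch
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    return get_cosine_decay_to_constant_with_warmup(
        optimizer, num_warmup_steps=500, num_training_steps=30000,
        final_lr=0.25, num_decay=0.667)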
import collections
import datetime
import os
import random
import subprocess
import time
from multiprocessing import JoinableQueue, Process
import numpy as np
import torch
import torch.distributed as dist
from mmcv import Config
from mmcv.runner import get_dist_info
from diffusion.utils.logger import get_root_logger
os.environ["MOX_SILENT_MODE"] = "1" # mute moxing log
def read_config(file):
    # resolve config-loading conflicts when multiple processes read the same file
    while True:
        config = Config.fromfile(file)
        if len(config) == 0:
            time.sleep(0.1)
            continue
        break
    return config
def init_random_seed(seed=None, device='cuda'):
"""Initialize random seed.
If the seed is not set, the seed will be automatically randomized,
and then broadcast to all processes to prevent some potential bugs.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is not None:
return seed
# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339
rank, world_size = get_dist_info()
seed = np.random.randint(2 ** 31)
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
def set_random_seed(seed, deterministic=False):
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
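# Hedged usage sketch: the usual startup sequence, so that every rank ends up
# with the same seed before building dataloaders and models.
def _demo_seed_everything(config_seed=None):
    seed = init_random_seed(config_seed)
    set_random_seed(seed, deterministic=False)
    return seed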
class SimpleTimer:
def __init__(self, num_tasks, log_interval=1, desc="Process"):
self.num_tasks = num_tasks
self.desc = desc
self.count = 0
self.log_interval = log_interval
self.start_time = time.time()
self.logger = get_root_logger()
def log(self):
self.count += 1
if (self.count % self.log_interval) == 0 or self.count == self.num_tasks:
time_elapsed = time.time() - self.start_time
avg_time = time_elapsed / self.count
eta_sec = avg_time * (self.num_tasks - self.count)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
elapsed_str = str(datetime.timedelta(seconds=int(time_elapsed)))
log_info = f"{self.desc} [{self.count}/{self.num_tasks}], elapsed_time:{elapsed_str}," \
f" avg_time: {avg_time}, eta: {eta_str}."
self.logger.info(log_info)
class DebugUnderflowOverflow:
"""
This debug class helps detect and understand where the model starts getting very large or very small, and more
importantly `nan` or `inf` weight and activation elements.
There are 2 working modes:
1. Underflow/overflow detection (default)
2. Specific batch absolute min/max tracing without detection
Mode 1: Underflow/overflow detection
To activate the underflow/overflow detection, initialize the object with the model :
```python
debug_overflow = DebugUnderflowOverflow(model)
```
then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or
output elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this
event, each frame reporting
1. the fully qualified module name plus the class name whose `forward` was run
2. the absolute min and max value of all elements for each module weights, and the inputs and output
For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16 mixed precision :
```
Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min abs max metadata
[...]
encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00 inf output
```
    You can see here that `T5DenseGatedGeluDense.forward` produced output activations whose absolute max value
    was around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout`, which
    renormalizes the weights after zeroing some of the elements, pushing the absolute max value above 64K, and we
    get an overflow.
    As you can see, it's the preceding frames we need to inspect when the numbers start getting very large for
    fp16.
The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.
By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :
```python
debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
```
    To validate that you have set up this debugging feature correctly, and before using it in a training run that
    may take hours to complete, first run it with normal tracing enabled for one or a few batches, as explained in
    the next section.
Mode 2. Specific batch absolute min/max tracing without detection
The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.
Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
given batch, and only do that for batches 1 and 3. Then you instantiate this class as :
```python
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3])
```
And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.
This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
fast-forward right to that area.
Early stopping:
You can also specify the batch number after which to stop the training, with :
```python
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3], abort_after_batch_num=3)
```
This feature is mainly useful in the tracing mode, but you can use it for any mode.
**Performance**:
    As this module measures the absolute `min`/`max` of every weight of the model on each forward pass, it slows
    training down. Therefore remember to turn it off once the debugging needs have been met.
Args:
model (`nn.Module`):
The model to debug.
        max_frames_to_save (`int`, *optional*, defaults to 21):
            How many frames back to record
        trace_batch_nums (`List[int]`, *optional*, defaults to `[]`):
            Which batch numbers to trace (turns detection off)
        abort_after_batch_num (`int`, *optional*):
            Abort training after this batch number has finished
"""
def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None):
if trace_batch_nums is None:
trace_batch_nums = []
self.model = model
self.trace_batch_nums = trace_batch_nums
self.abort_after_batch_num = abort_after_batch_num
# keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
self.frames = collections.deque([], max_frames_to_save)
self.frame = []
self.batch_number = 0
self.total_calls = 0
self.detected_overflow = False
self.prefix = " "
self.analyse_model()
self.register_forward_hook()
def save_frame(self, frame=None):
if frame is not None:
self.expand_frame(frame)
self.frames.append("\n".join(self.frame))
self.frame = [] # start a new frame
def expand_frame(self, line):
self.frame.append(line)
def trace_frames(self):
print("\n".join(self.frames))
self.frames = []
def reset_saved_frames(self):
self.frames = []
    def dump_saved_frames(self):
        # the original f-strings never interpolated the frames; join them explicitly
        print(f"\nDetected inf/nan during batch_number={self.batch_number}\n"
              f"Last {len(self.frames)} forward frames:\n"
              f"{'abs min':8} {'abs max':8} metadata\n"
              + "\n".join(self.frames) + "\n\n")
self.frames = []
def analyse_model(self):
# extract the fully qualified module names, to be able to report at run time. e.g.:
# encoder.block.2.layer.0.SelfAttention.o
#
# for shared weights only the first shared module name will be registered
self.module_names = {m: name for name, m in self.model.named_modules()}
# self.longest_module_name = max(len(v) for v in self.module_names.values())
def analyse_variable(self, var, ctx):
if torch.is_tensor(var):
self.expand_frame(self.get_abs_min_max(var, ctx))
if self.detect_overflow(var, ctx):
self.detected_overflow = True
elif var is None:
self.expand_frame(f"{'None':>17} {ctx}")
else:
self.expand_frame(f"{'not a tensor':>17} {ctx}")
def batch_start_frame(self):
self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
def batch_end_frame(self):
self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number - 1} ***\n\n")
def create_frame(self, module, input, output):
self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")
# params
for name, p in module.named_parameters(recurse=False):
self.analyse_variable(p, name)
# inputs
if isinstance(input, tuple):
for i, x in enumerate(input):
self.analyse_variable(x, f"input[{i}]")
else:
self.analyse_variable(input, "input")
# outputs
if isinstance(output, tuple):
for i, x in enumerate(output):
# possibly a tuple of tuples
if isinstance(x, tuple):
for j, y in enumerate(x):
self.analyse_variable(y, f"output[{i}][{j}]")
else:
self.analyse_variable(x, f"output[{i}]")
else:
self.analyse_variable(output, "output")
self.save_frame()
def register_forward_hook(self):
self.model.apply(self._register_forward_hook)
def _register_forward_hook(self, module):
module.register_forward_hook(self.forward_hook)
def forward_hook(self, module, input, output):
# - input is a tuple of packed inputs (could be non-Tensors)
# - output could be a Tensor or a tuple of Tensors and non-Tensors
last_frame_of_batch = False
trace_mode = self.batch_number in self.trace_batch_nums
if trace_mode:
self.reset_saved_frames()
if self.total_calls == 0:
self.batch_start_frame()
self.total_calls += 1
        # count batch numbers - the hook registered on the top-level model fires after the whole batch's
        # forward has completed, i.e. it is called last, so when it fires we know this batch has finished
if module == self.model:
self.batch_number += 1
last_frame_of_batch = True
self.create_frame(module, input, output)
# if last_frame_of_batch:
# self.batch_end_frame()
if trace_mode:
self.trace_frames()
if last_frame_of_batch:
self.batch_start_frame()
if self.detected_overflow and not trace_mode:
self.dump_saved_frames()
# now we can abort, as it's pointless to continue running
raise ValueError(
"DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
"Please scroll up above this traceback to see the activation values prior to this event."
)
# abort after certain batch if requested to do so
if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
raise ValueError(
f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to `abort_after_batch_num={self.abort_after_batch_num}` arg"
)
@staticmethod
def get_abs_min_max(var, ctx):
abs_var = var.abs()
return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"
@staticmethod
def detect_overflow(var, ctx):
"""
Report whether the tensor contains any `nan` or `inf` entries.
        This is useful for detecting overflows/underflows and is best called right after the operation that
        modified the tensor in question.
This function contains a few other helper features that you can enable and tweak directly if you want to track
various other things.
Args:
var: the tensor variable to check
ctx: the message to print as a context
Return:
`True` if `inf` or `nan` was detected, `False` otherwise
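        Example (a minimal sketch):
        ```python
        x = torch.tensor([1.0, 7e4])
        DebugUnderflowOverflow.detect_overflow(x, "x after matmul")
        # prints "x after matmul has overflow values 70000.0." and returns True
        ```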
"""
detected = False
if torch.isnan(var).any().item():
detected = True
print(f"{ctx} has nans")
if torch.isinf(var).any().item():
detected = True
print(f"{ctx} has infs")
if var.dtype == torch.float32 and torch.ge(var.abs(), 65535).any().item():
detected = True
print(f"{ctx} has overflow values {var.abs().max().item()}.")
return detected
import math
from mmcv import Config
from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \
OPTIMIZERS
from mmcv.utils import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
from .logger import get_root_logger
from typing import Tuple, Optional, Callable
import torch
from torch.optim.optimizer import Optimizer
def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
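    """Scale `optimizer_cfg['lr']` in place by the ratio of the effective batch size
    to `base_batch_size`, and return the scale ratio that was applied.
    For example, with `effective_bs=1024` and `base_batch_size=256` the ratio is 4:
    the 'linear' rule multiplies the lr by 4, the 'sqrt' rule by sqrt(4) = 2.
    """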
assert rule in ['linear', 'sqrt']
logger = get_root_logger()
    # scale lr according to the ratio of the effective batch size to the base batch size
if rule == 'sqrt':
scale_ratio = math.sqrt(effective_bs / base_batch_size)
elif rule == 'linear':
scale_ratio = effective_bs / base_batch_size
optimizer_cfg['lr'] *= scale_ratio
logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.7f} (using {rule} scaling rule).')
return scale_ratio
@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
def add_params(self, params, module, prefix='', is_dcn_module=None):
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
prefix (str): The prefix of the module
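        Example (a minimal sketch of the `paramwise_cfg` options read below; the key
        names are illustrative only):
        ```python
        paramwise_cfg = dict(
            custom_keys={
                'pos_embed': dict(lr_mult=0.1, decay_mult=0.0),  # plain string key
                ('blocks.0', 'bias'): dict(lr_mult=2.0),         # (scope, key) tuple
            },
            bias_lr_mult=2.0,
            norm_decay_mult=0.0,
            bypass_duplicate=True,
        )
        ```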
"""
# get param-wise options
custom_keys = self.paramwise_cfg.get('custom_keys', {})
# first sort with alphabet order and then sort with reversed len of str
# sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
# special rules for norm layers and depth-wise conv layers
is_norm = isinstance(module,
(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
for name, param in module.named_parameters(recurse=False):
base_lr = self.base_lr
if name == 'bias' and not is_norm and not is_dcn_module:
base_lr *= bias_lr_mult
# apply weight decay policies
base_wd = self.base_wd
# norm decay
if is_norm:
if self.base_wd is not None:
base_wd *= norm_decay_mult
elif name == 'bias' and not is_dcn_module:
if self.base_wd is not None:
                    # TODO: current bias_decay_mult will have an effect on DCN
base_wd *= bias_decay_mult
param_group = {'params': [param]}
if not param.requires_grad:
param_group['requires_grad'] = False
params.append(param_group)
continue
if bypass_duplicate and self._is_in(param_group, params):
logger = get_root_logger()
                logger.warning(f'{prefix} is duplicate. It is skipped since '
                               f'bypass_duplicate={bypass_duplicate}')
continue
# if the parameter match one of the custom keys, ignore other rules
is_custom = False
for key in custom_keys:
scope, key_name = key if isinstance(key, tuple) else (None, key)
if scope is not None and scope not in f'{prefix}':
continue
if key_name in f'{prefix}.{name}':
is_custom = True
if 'lr_mult' in custom_keys[key]:
# if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
# param_group['lr'] = self.base_lr
# else:
param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult']
elif 'lr' not in param_group:
param_group['lr'] = base_lr
if self.base_wd is not None:
if 'decay_mult' in custom_keys[key]:
param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult']
elif 'weight_decay' not in param_group:
param_group['weight_decay'] = base_wd
if not is_custom:
# bias_lr_mult affects all bias parameters
# except for norm.bias dcn.conv_offset.bias
if base_lr != self.base_lr:
param_group['lr'] = base_lr
if base_wd != self.base_wd:
param_group['weight_decay'] = base_wd
params.append(param_group)
for child_name, child_mod in module.named_children():
child_prefix = f'{prefix}.{child_name}' if prefix else child_name
self.add_params(
params,
child_mod,
prefix=child_prefix,
is_dcn_module=is_dcn_module)
def build_optimizer(model, optimizer_cfg):
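    """Build an optimizer through mmcv's `build_optimizer`, cancelling weight decay for any
    parameter key listed in a module's `zero_weight_decay` attribute, then log the resulting
    lr / weight-decay groups.
    Example (a minimal sketch of the convention this function reads; the module and
    parameter names are illustrative only):
    ```python
    class PatchEmbed(torch.nn.Module):
        zero_weight_decay = {'pos_embed'}  # params matching this key get decay_mult=0
        def __init__(self):
            super().__init__()
            self.pos_embed = torch.nn.Parameter(torch.zeros(1, 196, 768))
    ```
    """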
# default parameter-wise config
logger = get_root_logger()
if hasattr(model, 'module'):
model = model.module
# set optimizer constructor
optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')
# parameter-wise setting: cancel weight decay for some specific modules
custom_keys = dict()
for name, module in model.named_modules():
if hasattr(module, 'zero_weight_decay'):
            custom_keys.update({
                (name, key): dict(decay_mult=0)
                for key in module.zero_weight_decay
            })  # dict.update instead of `|=`, which requires Python >= 3.9
paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
if given_cfg := optimizer_cfg.get('paramwise_cfg'):
paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg
# build optimizer
optimizer = mm_build_optimizer(model, optimizer_cfg)
weight_decay_groups = dict()
lr_groups = dict()
for group in optimizer.param_groups:
if not group.get('requires_grad', True): continue
lr_groups.setdefault(group['lr'], []).append(group)
weight_decay_groups.setdefault(group['weight_decay'], []).append(group)
learnable_count, fix_count = 0, 0
for p in model.parameters():
if p.requires_grad:
learnable_count += 1
else:
fix_count += 1
fix_info = f"{learnable_count} are learnable, {fix_count} are fix"
lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()])
wd_info = "Weight decay group: " + ", ".join(
[f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()])
opt_info = f"Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
logger.info(opt_info)
return optimizer
@OPTIMIZERS.register_module()
class Lion(Optimizer):
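    """
    Lion optimizer (Chen et al., "Symbolic Discovery of Optimization Algorithms",
    https://arxiv.org/abs/2302.06675).
    Per step, with momentum m, gradient g, betas (b1, b2), lr and weight decay wd:
        p <- p * (1 - lr * wd)
        p <- p - lr * sign(b1 * m + (1 - b1) * g)
        m <- b2 * m + (1 - b2) * g
    Example (a minimal sketch; the hyper-parameters are illustrative only):
    ```python
    optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.01)
    ```
    """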
def __init__(
self,
params,
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
):
assert lr > 0.
assert all(0. <= beta <= 1. for beta in betas)
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
super().__init__(params, defaults)
@staticmethod
def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
        # decoupled weight decay, applied directly to the weights
p.data.mul_(1 - lr * wd)
# weight update
update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
p.add_(update, alpha=-lr)
# decay the momentum running average coefficient
exp_avg.lerp_(grad, 1 - beta2)
@staticmethod
def exists(val):
return val is not None
@torch.no_grad()
def step(
self,
closure: Optional[Callable] = None
):
loss = None
if self.exists(closure):
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in filter(lambda p: self.exists(p.grad), group['params']):
grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \
self.state[p]
# init state - exponential moving average of gradient values
if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p)
exp_avg = state['exp_avg']
self.update_fn(
p,
grad,
exp_avg,
lr,
wd,
beta1,
beta2
)
return loss
name: PixArt
channels:
- pytorch
- nvidia
dependencies:
- python >= 3.8
- pytorch >= 1.13
- torchvision
- pytorch-cuda=11.7
- pip:
- timm==0.6.12
    - diffusers
    - accelerate==0.15.0
    - mmcv==1.7.0
- tensorboard
- transformers==4.26.1
- sentencepiece~=0.1.97
- ftfy~=6.1.1
- beautifulsoup4~=4.11.1
- opencv-python
- einops
- xformers
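# To create the environment (assuming conda is available): conda env create -f environment.yml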