Commit 4a457188 authored by raojy's avatar raojy
Browse files

fix

parent a570aeea
import numpy as np
import torch as th
def expand_t_like_x(t, x):
"""Function to reshape time t to broadcastable dimension of x
Args:
t: [batch_dim,], time vector
x: [batch_dim,...], data point
"""
dims = [1] * len(x[0].size())
t = t.view(t.size(0), *dims)
return t
#################### Coupling Plans ####################
class ICPlan:
"""Linear Coupling Plan"""
def __init__(self, sigma=0.0):
self.sigma = sigma
def compute_alpha_t(self, t):
"""Compute the data coefficient along the path"""
return t, 1
def compute_sigma_t(self, t):
"""Compute the noise coefficient along the path"""
return 1 - t, -1
def compute_d_alpha_alpha_ratio_t(self, t):
"""Compute the ratio between d_alpha and alpha"""
return 1 / t
def compute_drift(self, x, t):
"""We always output sde according to score parametrization;"""
t = expand_t_like_x(t, x)
alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
sigma_t, d_sigma_t = self.compute_sigma_t(t)
drift = alpha_ratio * x
diffusion = alpha_ratio * (sigma_t**2) - sigma_t * d_sigma_t
return -drift, diffusion
def compute_diffusion(self, x, t, form="constant", norm=1.0):
"""Compute the diffusion term of the SDE
Args:
x: [batch_dim, ...], data point
t: [batch_dim,], time vector
form: str, form of the diffusion term
norm: float, norm of the diffusion term
"""
t = expand_t_like_x(t, x)
choices = {
"constant": norm,
"SBDM": norm * self.compute_drift(x, t)[1],
"sigma": norm * self.compute_sigma_t(t)[0],
"linear": norm * (1 - t),
"decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
"inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
}
try:
diffusion = choices[form]
except KeyError:
raise NotImplementedError(f"Diffusion form {form} not implemented")
return diffusion
def get_score_from_velocity(self, velocity, x, t):
"""Wrapper function: transfrom velocity prediction model to score
Args:
velocity: [batch_dim, ...] shaped tensor; velocity model output
x: [batch_dim, ...] shaped tensor; x_t data point
t: [batch_dim,] time tensor
"""
t = expand_t_like_x(t, x)
alpha_t, d_alpha_t = self.compute_alpha_t(t)
sigma_t, d_sigma_t = self.compute_sigma_t(t)
mean = x
reverse_alpha_ratio = alpha_t / d_alpha_t
var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
score = (reverse_alpha_ratio * velocity - mean) / var
return score
def get_noise_from_velocity(self, velocity, x, t):
"""Wrapper function: transfrom velocity prediction model to denoiser
Args:
velocity: [batch_dim, ...] shaped tensor; velocity model output
x: [batch_dim, ...] shaped tensor; x_t data point
t: [batch_dim,] time tensor
"""
t = expand_t_like_x(t, x)
alpha_t, d_alpha_t = self.compute_alpha_t(t)
sigma_t, d_sigma_t = self.compute_sigma_t(t)
mean = x
reverse_alpha_ratio = alpha_t / d_alpha_t
var = reverse_alpha_ratio * d_sigma_t - sigma_t
noise = (reverse_alpha_ratio * velocity - mean) / var
return noise
def get_velocity_from_score(self, score, x, t):
"""Wrapper function: transfrom score prediction model to velocity
Args:
score: [batch_dim, ...] shaped tensor; score model output
x: [batch_dim, ...] shaped tensor; x_t data point
t: [batch_dim,] time tensor
"""
t = expand_t_like_x(t, x)
drift, var = self.compute_drift(x, t)
velocity = var * score - drift
return velocity
def compute_mu_t(self, t, x0, x1):
"""Compute the mean of time-dependent density p_t"""
t = expand_t_like_x(t, x1)
alpha_t, _ = self.compute_alpha_t(t)
sigma_t, _ = self.compute_sigma_t(t)
if isinstance(x1, (list, tuple)):
return [alpha_t[i] * x1[i] + sigma_t[i] * x0[i] for i in range(len(x1))]
else:
return alpha_t * x1 + sigma_t * x0
def compute_xt(self, t, x0, x1):
"""Sample xt from time-dependent density p_t; rng is required"""
xt = self.compute_mu_t(t, x0, x1)
return xt
def compute_ut(self, t, x0, x1, xt):
"""Compute the vector field corresponding to p_t"""
t = expand_t_like_x(t, x1)
_, d_alpha_t = self.compute_alpha_t(t)
_, d_sigma_t = self.compute_sigma_t(t)
if isinstance(x1, (list, tuple)):
return [d_alpha_t * x1[i] + d_sigma_t * x0[i] for i in range(len(x1))]
else:
return d_alpha_t * x1 + d_sigma_t * x0
def plan(self, t, x0, x1):
xt = self.compute_xt(t, x0, x1)
ut = self.compute_ut(t, x0, x1, xt)
return t, xt, ut
class VPCPlan(ICPlan):
"""class for VP path flow matching"""
def __init__(self, sigma_min=0.1, sigma_max=20.0):
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.log_mean_coeff = (
lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
)
self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
def compute_alpha_t(self, t):
"""Compute coefficient of x1"""
alpha_t = self.log_mean_coeff(t)
alpha_t = th.exp(alpha_t)
d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
return alpha_t, d_alpha_t
def compute_sigma_t(self, t):
"""Compute coefficient of x0"""
p_sigma_t = 2 * self.log_mean_coeff(t)
sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
return sigma_t, d_sigma_t
def compute_d_alpha_alpha_ratio_t(self, t):
"""Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
return self.d_log_mean_coeff(t)
def compute_drift(self, x, t):
"""Compute the drift term of the SDE"""
t = expand_t_like_x(t, x)
beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
return -0.5 * beta_t * x, beta_t / 2
class GVPCPlan(ICPlan):
def __init__(self, sigma=0.0):
super().__init__(sigma)
def compute_alpha_t(self, t):
"""Compute coefficient of x1"""
alpha_t = th.sin(t * np.pi / 2)
d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
return alpha_t, d_alpha_t
def compute_sigma_t(self, t):
"""Compute coefficient of x0"""
sigma_t = th.cos(t * np.pi / 2)
d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
return sigma_t, d_sigma_t
def compute_d_alpha_alpha_ratio_t(self, t):
"""Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
return np.pi / (2 * th.tan(t * np.pi / 2))
import enum
import math
from typing import Callable
import numpy as np
import torch as th
from . import path
from .integrators import ode, sde
from .utils import mean_flat, expand_dims
from .dpm_solver import NoiseScheduleFlow, model_wrapper, DPM_Solver
class ModelType(enum.Enum):
"""
Which type of output the model predicts.
"""
NOISE = enum.auto() # the model predicts epsilon
SCORE = enum.auto() # the model predicts \nabla \log p(x)
VELOCITY = enum.auto() # the model predicts v(x)
class PathType(enum.Enum):
"""
Which type of path to use.
"""
LINEAR = enum.auto()
GVP = enum.auto()
VP = enum.auto()
class WeightType(enum.Enum):
"""
Which type of weighting to use.
"""
NONE = enum.auto()
VELOCITY = enum.auto()
LIKELIHOOD = enum.auto()
class Transport:
def __init__(self, *, model_type, path_type, loss_type, train_eps, sample_eps, snr_type, do_shift, seq_len):
path_options = {
PathType.LINEAR: path.ICPlan,
PathType.GVP: path.GVPCPlan,
PathType.VP: path.VPCPlan,
}
self.loss_type = loss_type
self.model_type = model_type
self.path_sampler = path_options[path_type]()
self.train_eps = train_eps
self.sample_eps = sample_eps
self.snr_type = snr_type
self.do_shift = do_shift
self.seq_len = seq_len
def prior_logp(self, z):
"""
Standard multivariate normal prior
Assume z is batched
"""
shape = th.tensor(z.size())
N = th.prod(shape[1:])
_fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
return th.vmap(_fn)(z)
def check_interval(
self,
train_eps,
sample_eps,
*,
diffusion_form="SBDM",
sde=False,
reverse=False,
eval=False,
last_step_size=0.0,
):
t0 = 0
t1 = 1
eps = train_eps if not eval else sample_eps
if type(self.path_sampler) in [path.VPCPlan]:
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
self.model_type != ModelType.VELOCITY or sde
): # avoid numerical issue by taking a first semi-implicit step
t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
if reverse:
t0, t1 = 1 - t0, 1 - t1
return t0, t1
def sample(self, x1):
"""Sampling x0 & t based on shape of x1 (if needed)
Args:
x1 - data point; [batch, *dim]
"""
if isinstance(x1, (list, tuple)):
x0 = [th.randn_like(img_start) for img_start in x1]
else:
x0 = th.randn_like(x1)
t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
if self.snr_type.startswith("uniform"):
assert t0 == 0.0 and t1 == 1.0, "not implemented."
if "_" in self.snr_type:
_, t0, t1 = self.snr_type.split("_")
t0, t1 = float(t0), float(t1)
t = th.rand((len(x1),)) * (t1 - t0) + t0
elif self.snr_type == "lognorm":
u = th.normal(mean=0.0, std=1.0, size=(len(x1),))
t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
else:
raise NotImplementedError("Not implemented snr_type %s" % self.snr_type)
if self.do_shift:
base_shift: float = 0.5
max_shift: float = 1.15
mu = self.get_lin_function(y1=base_shift, y2=max_shift)(self.seq_len)
t = self.time_shift(mu, 1.0, t)
t = t.to(x1[0])
return t, x0, x1
def time_shift(self, mu: float, sigma: float, t: th.Tensor):
# the following implementation was original for t=0: clean / t=1: noise
# Since we adopt the reverse, the 1-t operations are needed
t = 1 - t
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
t = 1 - t
return t
def get_lin_function(
self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def training_losses(self, model, x1, model_kwargs=None):
"""Loss for training the score model
Args:
- model: backbone model; could be score, noise, or velocity
- x1: datapoint
- model_kwargs: additional arguments for the model
"""
if model_kwargs == None:
model_kwargs = {}
t, x0, x1 = self.sample(x1)
t, xt, ut = self.path_sampler.plan(t, x0, x1)
if "cond" in model_kwargs:
conds = model_kwargs.pop("cond")
xt = [th.cat([x, cond], dim=0) if cond is not None else x for x, cond in zip(xt, conds)]
model_output = model(xt, t, **model_kwargs)
# Unwrap model output
if hasattr(model_output, 'sample'):
model_output = model_output.sample
elif isinstance(model_output, tuple):
model_output = model_output[0]
B = len(x0)
terms = {}
# terms['pred'] = model_output
if self.model_type == ModelType.VELOCITY:
if isinstance(x1, (list, tuple)):
assert len(model_output) == len(ut) == len(x1)
for i in range(B):
assert (
model_output[i].shape == ut[i].shape == x1[i].shape
), f"{model_output[i].shape} {ut[i].shape} {x1[i].shape}"
terms["task_loss"] = th.stack(
[((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
dim=0,
)
else:
terms["task_loss"] = mean_flat(((model_output - ut) ** 2))
else:
raise NotImplementedError
terms["loss"] = terms["task_loss"]
terms["task_loss"] = terms["task_loss"].clone().detach()
terms["t"] = t
return terms
def get_drift(self):
"""member function for obtaining the drift of the probability flow ODE"""
def score_ode(x, t, model, **model_kwargs):
drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
model_output = model(x, t, **model_kwargs)
return -drift_mean + drift_var * model_output # by change of variable
def noise_ode(x, t, model, **model_kwargs):
drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
model_output = model(x, t, **model_kwargs)
score = model_output / -sigma_t
return -drift_mean + drift_var * score
def velocity_ode(x, t, model, **model_kwargs):
model_output = model(x, t, **model_kwargs)
return model_output
if self.model_type == ModelType.NOISE:
drift_fn = noise_ode
elif self.model_type == ModelType.SCORE:
drift_fn = score_ode
else:
drift_fn = velocity_ode
def body_fn(x, t, model, **model_kwargs):
model_output = drift_fn(x, t, model, **model_kwargs)
assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
return model_output
return body_fn
def get_score(
self,
):
"""member function for obtaining score of
x_t = alpha_t * x + sigma_t * eps"""
if self.model_type == ModelType.NOISE:
score_fn = (
lambda x, t, model, **kwargs: model(x, t, **kwargs)
/ -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
)
elif self.model_type == ModelType.SCORE:
score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs)
elif self.model_type == ModelType.VELOCITY:
score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
model(x, t, **kwargs), x, t
)
else:
raise NotImplementedError()
return score_fn
class Sampler:
"""Sampler class for the transport model"""
def __init__(
self,
transport,
):
"""Constructor for a general sampler; supporting different sampling methods
Args:
- transport: an tranport object specify model prediction & interpolant type
"""
self.transport = transport
self.drift = self.transport.get_drift()
self.score = self.transport.get_score()
def __get_sde_diffusion_and_drift(
self,
*,
diffusion_form="SBDM",
diffusion_norm=1.0,
):
def diffusion_fn(x, t):
diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
return diffusion
sde_drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(
x, t, model, **kwargs
)
sde_diffusion = diffusion_fn
return sde_drift, sde_diffusion
def __get_last_step(
self,
sde_drift,
*,
last_step,
last_step_size,
):
"""Get the last step function of the SDE solver"""
if last_step is None:
last_step_fn = lambda x, t, model, **model_kwargs: x
elif last_step == "Mean":
last_step_fn = (
lambda x, t, model, **model_kwargs: x + sde_drift(x, t, model, **model_kwargs) * last_step_size
)
elif last_step == "Tweedie":
alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
sigma = self.transport.path_sampler.compute_sigma_t
last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][
0
] * self.score(x, t, model, **model_kwargs)
elif last_step == "Euler":
last_step_fn = (
lambda x, t, model, **model_kwargs: x + self.drift(x, t, model, **model_kwargs) * last_step_size
)
else:
raise NotImplementedError()
return last_step_fn
def sample_sde(
self,
*,
sampling_method="Euler",
diffusion_form="SBDM",
diffusion_norm=1.0,
last_step="Mean",
last_step_size=0.04,
num_steps=250,
):
"""returns a sampling function with given SDE settings
Args:
- sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
- diffusion_form: function form of diffusion coefficient; default to be matching SBDM
- diffusion_norm: function magnitude of diffusion coefficient; default to 1
- last_step: type of the last step; default to identity
- last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
- num_steps: total integration step of SDE
"""
if last_step is None:
last_step_size = 0.0
sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
diffusion_form=diffusion_form,
diffusion_norm=diffusion_norm,
)
t0, t1 = self.transport.check_interval(
self.transport.train_eps,
self.transport.sample_eps,
diffusion_form=diffusion_form,
sde=True,
eval=True,
reverse=False,
last_step_size=last_step_size,
)
_sde = sde(
sde_drift,
sde_diffusion,
t0=t0,
t1=t1,
num_steps=num_steps,
sampler_type=sampling_method,
)
last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
def _sample(init, model, **model_kwargs):
xs = _sde.sample(init, model, **model_kwargs)
ts = th.ones(init.size(0), device=init.device) * t1
x = last_step_fn(xs[-1], ts, model, **model_kwargs)
xs.append(x)
assert len(xs) == num_steps, "Samples does not match the number of steps"
return xs
return _sample
def sample_dpm(
self,
model,
model_kwargs=None,
):
noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
def noise_pred_fn(x, t_continuous):
output = model(x, 1 - t_continuous, **model_kwargs)
_, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
try:
noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output
except:
noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0]
return noise
return DPM_Solver(noise_pred_fn, noise_schedule, algorithm_type="dpmsolver++").sample
def sample_ode(
self,
*,
sampling_method="dopri5",
num_steps=50,
atol=1e-6,
rtol=1e-3,
reverse=False,
do_shift=False,
time_shifting_factor=None,
stochast_ratio=0.0, # 新增参数,0.0=纯ODE,1.0=完全重加噪
):
if stochast_ratio == 0.0:
# 原有逻辑不变
drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs)
t0, t1 = self.transport.check_interval(
self.transport.train_eps,
self.transport.sample_eps,
sde=False,
eval=True,
reverse=reverse,
last_step_size=0.0,
)
_ode = ode(
drift=drift,
t0=t0,
t1=t1,
sampler_type=sampling_method,
num_steps=num_steps,
atol=atol,
rtol=rtol,
do_shift=do_shift,
time_shifting_factor=time_shifting_factor,
)
return _ode.sample
else:
# 新增:DDPM风格重加噪采样
t0, t1 = self.transport.check_interval(
self.transport.train_eps,
self.transport.sample_eps,
sde=False,
eval=True,
reverse=reverse,
last_step_size=0.0,
)
path_sampler = self.transport.path_sampler
def _sample(init, model, **model_kwargs):
# t0→t1: noise(t=0) → data(t=1)
t_steps = th.linspace(t0, t1, num_steps + 1, dtype=th.float64).to(init)
x_cur = init.to(th.float64)
for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
t_batch = th.ones(x_cur.size(0), device=x_cur.device, dtype=x_cur.dtype) * t_cur
# 1. 模型预测 velocity
v = model(x_cur, t_batch, **model_kwargs)
# 2. 直接从流匹配公式还原 x̂₁ 和 x̂₀,避免除以 alpha_t 的奇点
# 联立 x_t = alpha_t*x1 + sigma_t*x0 与 v = d_alpha_t*x1 + d_sigma_t*x0
t_exp = expand_dims(t_batch, x_cur.dim())
alpha_t, d_alpha_t = path_sampler.compute_alpha_t(t_exp)
sigma_t, d_sigma_t = path_sampler.compute_sigma_t(t_exp)
denom = sigma_t * d_alpha_t - d_sigma_t * alpha_t # =1 for ICPlan
x1_hat = (sigma_t * v - d_sigma_t * x_cur) / denom
x0_hat = (d_alpha_t * x_cur - alpha_t * v) / denom
# 3. 按 t_next 重加噪
t_next_batch = th.ones_like(t_batch) * t_next
t_next_exp = expand_dims(t_next_batch, x_cur.dim())
alpha_next, _ = path_sampler.compute_alpha_t(t_next_exp)
sigma_next, _ = path_sampler.compute_sigma_t(t_next_exp)
noi = th.randn_like(x_cur)
x_cur = alpha_next * x1_hat + sigma_next * (
x0_hat * ((1 - stochast_ratio) ** 0.5)
+ noi * (stochast_ratio ** 0.5)
)
return [x_cur]
return _sample
def sample_ode_likelihood(
self,
*,
sampling_method="dopri5",
num_steps=50,
atol=1e-6,
rtol=1e-3,
):
"""returns a sampling function for calculating likelihood with given ODE settings
Args:
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
- num_steps:
- fixed solver (Euler, Heun): the actual number of integration steps performed
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
- atol: absolute error tolerance for the solver
- rtol: relative error tolerance for the solver
"""
def _likelihood_drift(x, t, model, **model_kwargs):
x, _ = x
eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
t = th.ones_like(t) * (1 - t)
with th.enable_grad():
x.requires_grad = True
grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
drift = self.drift(x, t, model, **model_kwargs)
return (-drift, logp_grad)
t0, t1 = self.transport.check_interval(
self.transport.train_eps,
self.transport.sample_eps,
sde=False,
eval=True,
reverse=False,
last_step_size=0.0,
)
_ode = ode(
drift=_likelihood_drift,
t0=t0,
t1=t1,
sampler_type=sampling_method,
num_steps=num_steps,
atol=atol,
rtol=rtol,
)
def _sample_fn(x, model, **model_kwargs):
init_logp = th.zeros(x.size(0)).to(x)
input = (x, init_logp)
drift, delta_logp = _ode.sample(input, model, **model_kwargs)
drift, delta_logp = drift[-1], delta_logp[-1]
prior_logp = self.transport.prior_logp(drift)
logp = prior_logp - delta_logp
return logp, drift
return _sample_fn
import torch as th
import math
class EasyDict:
def __init__(self, sub_dict):
for k, v in sub_dict.items():
setattr(self, k, v)
def __getitem__(self, key):
return getattr(self, key)
def mean_flat(x):
"""
Take the mean over all non-batch dimensions.
"""
return th.mean(x, dim=list(range(1, len(x.size()))))
def log_state(state):
result = []
sorted_state = dict(sorted(state.items()))
for key, value in sorted_state.items():
# Check if the value is an instance of a class
if "<object" in str(value) or "object at" in str(value):
result.append(f"{key}: [{value.__class__.__name__}]")
else:
result.append(f"{key}: {value}")
return "\n".join(result)
def time_shift(mu: float, sigma: float, t: th.Tensor):
# the following implementation was original for t=0: clean / t=1: noise
# Since we adopt the reverse, the 1-t operations are needed
t = 1 - t
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
t = 1 - t
return t
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def expand_dims(v, dims):
"""
Expand the tensor `v` to the dim `dims`.
Args:
`v`: a PyTorch tensor with shape [N].
`dim`: a `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
"""Utility functions for image preprocessing."""
import random
from PIL import Image
def center_crop(pil_image, crop_size):
cw, ch = crop_size
w, h = pil_image.size
left = max(0, (w - cw) // 2)
top = max(0, (h - ch) // 2)
return pil_image.crop((left, top, left + cw, top + ch)).resize((cw, ch), Image.LANCZOS)
def var_center_crop(pil_image, crop_size_list, random_top_k=1):
w, h = pil_image.size
rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
crop_size = random.choice(
sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
)[1]
return center_crop(pil_image, crop_size)
def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
from .image_tokenizer import ImageTokenizer
# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image tokenizer for LLaDA-2.0-Uni.
Converts PIL images into discrete VQ token IDs via a vision encoder + VQVAE.
"""
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Optional
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.v2 import functional as tvF
# ============================================================
# Config loading
# ============================================================
def load_configs(model_dir: str | Path) -> dict:
with open(Path(model_dir) / "config.json", "r") as f:
return json.load(f)
def make_vision_config(raw: dict) -> SimpleNamespace:
vc = raw.get("vision_config", raw)
# Determine best attention implementation
attn_impl = "eager"
try:
from flash_attn import flash_attn_varlen_func
attn_impl = "flash_attention_2"
except ImportError:
try:
import torch
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
attn_impl = "sdpa"
except:
pass
return SimpleNamespace(
hidden_size=vc["hidden_size"], intermediate_size=vc["intermediate_size"],
num_heads=vc["num_heads"], depth=vc["depth"],
patch_size=vc["patch_size"], image_size=vc["image_size"],
in_channels=vc.get("in_channels", 3), hidden_act=vc.get("hidden_act", "gelu"),
attention_bias=vc.get("attention_bias", True), attention_dropout=vc.get("attention_dropout", 0.0),
layer_norm_eps=vc.get("layer_norm_eps", 1e-6),
spatial_merge_size=vc.get("spatial_merge_size", 1),
_attn_implementation=attn_impl,
)
def make_vq_config(raw: dict) -> SimpleNamespace:
vq = raw.get("vq_config", raw)
return SimpleNamespace(
num_embeddings=vq["num_embeddings"], embed_dim=vq["embed_dim"],
latent_channels=vq["latent_channels"], beta=vq.get("beta", 0.25),
)
# ============================================================
# Image preprocessing
# ============================================================
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
class ImagePreprocessor:
"""Image preprocessor: rescale + normalize. Resizing/cropping is handled externally."""
def __init__(self, config_path: str | Path):
config_path = Path(config_path)
if config_path.is_dir():
config_path = config_path / "preprocessor_config.json"
with open(config_path, "r") as f:
config = json.load(f)
self.do_rescale = config.get("do_rescale", True)
self.do_normalize = config.get("do_normalize", True)
self.rescale_factor = config.get("rescale_factor", 1.0 / 255.0)
self.image_mean = config.get("image_mean", OPENAI_CLIP_MEAN)
self.image_std = config.get("image_std", OPENAI_CLIP_STD)
self.patch_size = config.get("patch_size", 14)
self.temporal_patch_size = config.get("temporal_patch_size", 2)
self.merge_size = config.get("merge_size", 2)
self.factor = self.patch_size * self.merge_size
def _pil_to_tensor(self, image):
return tvF.to_dtype(tvF.to_image(image), dtype=torch.float32, scale=False)
def _rescale_and_normalize(self, images):
if self.do_rescale:
images = images * self.rescale_factor
if self.do_normalize:
mean = torch.tensor(self.image_mean, dtype=images.dtype, device=images.device).view(-1, 1, 1)
std = torch.tensor(self.image_std, dtype=images.dtype, device=images.device).view(-1, 1, 1)
images = (images - mean) / std
return images
def __call__(self, images):
if isinstance(images, PIL.Image.Image):
images = [images]
all_patches, all_grids = [], []
for img in images:
if img.mode != "RGB":
img = img.convert("RGB")
img_tensor = self._pil_to_tensor(img)
height, width = img_tensor.shape[-2:]
# Assume images are already cropped/resized by the caller
# (e.g. via decoder.utils.var_center_crop + generate_crop_size_list)
rh, rw = height, width
patches = self._rescale_and_normalize(img_tensor)
if patches.ndim == 3:
patches = patches.unsqueeze(0)
# Temporal padding
if patches.shape[0] % self.temporal_patch_size != 0:
repeats = patches[-1:].repeat(self.temporal_patch_size - 1, 1, 1, 1)
patches = torch.cat([patches, repeats], dim=0)
grid_t = patches.shape[0] // self.temporal_patch_size
grid_h = rh // self.patch_size
grid_w = rw // self.patch_size
channel = patches.shape[1]
# Reshape into patch tokens
patches = patches.unsqueeze(0).view(
1, grid_t, self.temporal_patch_size, channel,
grid_h // self.merge_size, self.merge_size, self.patch_size,
grid_w // self.merge_size, self.merge_size, self.patch_size,
)
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
flatten_patches = patches.reshape(
1, grid_t * grid_h * grid_w,
channel * self.temporal_patch_size * self.patch_size * self.patch_size,
)
all_patches.append(flatten_patches.squeeze(0))
all_grids.append([grid_t, grid_h, grid_w])
return {
"pixel_values": torch.cat(all_patches, dim=0),
"image_grid_thw": torch.tensor(all_grids, dtype=torch.long),
}
# ============================================================
# Vision model components
# ============================================================
def _get_act_fn(name):
mapping = {"gelu": nn.GELU(), "relu": nn.ReLU(), "silu": nn.SiLU(),
"quick_gelu": lambda x: x * torch.sigmoid(1.702 * x)}
if name in mapping:
return mapping[name]
from transformers.activations import ACT2FN
return ACT2FN[name]
class VisionMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.activation_fn = _get_act_fn(config.hidden_act)
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, x):
return self.fc2(self.activation_fn(self.fc1(x)))
class VisionAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
self.num_key_value_groups = 1
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
self.scaling = self.head_dim ** -0.5
self.config = config
self.attention_dropout = config.attention_dropout
self.is_causal = False
def forward(self, hidden_states, cu_seqlens, **kwargs):
seq_length = hidden_states.shape[0]
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
# Try to use the HF attention dispatch (flash_attention_2 / sdpa / eager)
try:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
attn_impl = getattr(self.config, '_attn_implementation', 'eager')
if attn_impl != 'eager' and attn_impl in ALL_ATTENTION_FUNCTIONS:
attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]
if 'flash' in attn_impl:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
attn_output, _ = attention_interface(
self, query_states, key_states, value_states,
attention_mask=None, scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
cu_seq_lens_q=cu_seqlens, cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen, max_length_k=max_seqlen,
is_causal=False, **kwargs,
)
else:
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [torch.split(t, lengths.tolist(), dim=2) for t in (query_states, key_states, value_states)]
attn_output = torch.cat([
attention_interface(
self, q, k, v, attention_mask=None, scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
is_causal=False, **kwargs,
)[0] for q, k, v in zip(*splits)
], dim=1)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
return self.proj(attn_output)
except (ImportError, KeyError, AttributeError):
pass
# Fallback: try flash_attn directly
try:
from flash_attn import flash_attn_varlen_func
q = query_states.squeeze(0).transpose(0, 1) # (seq, heads, dim)
k = key_states.squeeze(0).transpose(0, 1)
v = value_states.squeeze(0).transpose(0, 1)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
return self.proj(attn_output)
except ImportError:
pass
# Final fallback: manual eager attention (chunk by chunk)
q = query_states.squeeze(0) # (heads, seq, dim)
k = key_states.squeeze(0)
v = value_states.squeeze(0)
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
outputs = []
for qc, kc, vc in zip(torch.split(q, lengths, dim=1), torch.split(k, lengths, dim=1), torch.split(v, lengths, dim=1)):
attn = F.softmax(torch.matmul(qc, kc.transpose(-2, -1)) * self.scaling, dim=-1, dtype=torch.float32).to(qc.dtype)
outputs.append(torch.matmul(attn, vc))
attn_output = torch.cat(outputs, dim=1).transpose(0, 1).reshape(seq_length, -1).contiguous()
return self.proj(attn_output)
class VisionPatchEmbed(nn.Module):
def __init__(self, config):
super().__init__()
self.patch_size = config.patch_size
self.in_channels = config.in_channels
self.embed_dim = config.hidden_size
self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
def forward(self, x):
target_dtype = self.proj.weight.dtype
x = x.view(-1, self.in_channels, self.patch_size, self.patch_size)
return self.proj(x.to(dtype=target_dtype)).view(-1, self.embed_dim)
class VisionEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.embed_dim = config.hidden_size
num_patches = (config.image_size // config.patch_size) ** 2
self.position_embedding = nn.Embedding(num_patches, self.embed_dim)
def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords):
pos_w = self.position_embedding.weight
hidden_size = pos_w.shape[1]
device = pos_w.device
if isinstance(lengths, list):
lengths = torch.tensor(lengths, device=device, dtype=torch.long)
orig_size = int(pos_w.shape[0] ** 0.5)
pos_2d = pos_w.view(orig_size, orig_size, hidden_size).permute(2, 0, 1).unsqueeze(0).float()
target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(device=device, dtype=torch.float32)
target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(device=device, dtype=torch.float32)
norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
adapted = F.grid_sample(pos_2d, grid, mode="bilinear", align_corners=False, padding_mode="border")
adapted = adapted.squeeze(0).squeeze(-1).permute(1, 0).to(pos_w.dtype).to(embeddings.device)
return embeddings + adapted
class VisionBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = VisionAttention(config)
self.mlp = VisionMLP(config)
def forward(self, hidden_states, cu_seqlens, **kwargs):
hidden_states = hidden_states + self.attn(self.norm1(hidden_states), cu_seqlens=cu_seqlens)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
class VisionEncoder(nn.Module):
"""Vision transformer encoder that produces per-patch features."""
def __init__(self, config):
super().__init__()
self.spatial_merge_size = config.spatial_merge_size
self.patch_size = config.patch_size
self.embeddings = VisionEmbeddings(config)
self.patch_embed = VisionPatchEmbed(config)
self.blocks = nn.ModuleList([VisionBlock(config) for _ in range(config.depth)])
@property
def dtype(self):
return self.patch_embed.proj.weight.dtype
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos = hpos.reshape(h // self.spatial_merge_size, self.spatial_merge_size,
w // self.spatial_merge_size, self.spatial_merge_size)
hpos = hpos.permute(0, 2, 1, 3).flatten()
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos = wpos.reshape(h // self.spatial_merge_size, self.spatial_merge_size,
w // self.spatial_merge_size, self.spatial_merge_size)
wpos = wpos.permute(0, 2, 1, 3).flatten()
pos_ids.append(torch.stack([hpos, wpos], dim=-1).repeat(t, 1))
return torch.cat(pos_ids, dim=0)
def forward(self, pixel_values, grid_thw):
hidden_states = self.patch_embed(pixel_values)
image_type_ids = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(cu_seqlens.cumsum(0, dtype=torch.int32), (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
hidden_states = self.embeddings(
hidden_states, seqlens, grid_thw,
image_type_ids[:, 0].to(hidden_states.device),
image_type_ids[:, 1].to(hidden_states.device),
)
for blk in self.blocks:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens)
return hidden_states
# ============================================================
# VQVAE quantizer
# ============================================================
class VQVAEVectorQuantizer(nn.Module):
def __init__(self, config):
super().__init__()
self.num_embeddings = config.num_embeddings
self.embedding_dim = config.embed_dim
self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
def forward(self, hidden_state):
hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
flat = hidden_state.view(-1, self.embedding_dim)
flat = F.normalize(flat, p=2, dim=-1)
emb = F.normalize(self.embedding.weight, p=2, dim=-1)
distances = (torch.sum(flat ** 2, dim=1, keepdim=True)
+ torch.sum(emb ** 2, dim=1)
- 2 * torch.einsum("bd,dn->bn", flat, emb.t()))
return torch.argmin(distances, dim=1)
class VQVAE(nn.Module):
def __init__(self, config):
super().__init__()
self.quantize = VQVAEVectorQuantizer(config)
self.quant_conv = nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = nn.Conv2d(config.embed_dim, config.latent_channels, 1)
def encode(self, hidden_states):
return self.quantize(self.quant_conv(hidden_states))
# ============================================================
# Weight loading
# ============================================================
def _load_weights(model_dir, visual, vqmodel):
from safetensors.torch import load_file
model_path = Path(model_dir)
index_file = model_path / "model.safetensors.index.json"
if index_file.exists():
with open(index_file) as f:
weight_map = json.load(f)["weight_map"]
needed = {fn for k, fn in weight_map.items() if k.startswith(("model.visual.", "model.vqmodel."))}
else:
needed = {f.name for f in model_path.glob("*.safetensors")}
visual_sd, vq_sd = {}, {}
for filename in sorted(needed):
filepath = model_path / filename
if not filepath.exists():
continue
shard = load_file(str(filepath), device="cpu")
for key, value in shard.items():
if key.startswith("model.visual."):
visual_sd[key[len("model.visual."):]] = value
elif key.startswith("model.vqmodel."):
vq_sd[key[len("model.vqmodel."):]] = value
del shard
visual.load_state_dict(visual_sd, strict=False)
vqmodel.load_state_dict(vq_sd, strict=False)
del visual_sd, vq_sd
# ============================================================
# Main tokenizer class
# ============================================================
class ImageTokenizer:
"""
Standalone image tokenizer that converts PIL images to discrete VQ token IDs.
Expects the following layout under ``model_path``::
model_path/
└── image_tokenizer/
├── config.json # vision_config + vq_config
├── preprocessor_config.json
└── *.safetensors # visual + vqmodel weights
Args:
model_path: Root path of the model directory (parent of image_tokenizer/).
device: Torch device.
dtype: Model dtype (default: bfloat16).
"""
def __init__(self, model_path, device="cuda", dtype=torch.bfloat16):
self.device = torch.device(device)
self.dtype = dtype
tokenizer_dir = Path(model_path) / "image_tokenizer"
self.image_processor = ImagePreprocessor(tokenizer_dir)
raw_config = load_configs(tokenizer_dir)
vision_cfg = make_vision_config(raw_config)
vq_cfg = make_vq_config(raw_config)
self.visual = VisionEncoder(vision_cfg).to(self.device, self.dtype)
self.vqmodel = VQVAE(vq_cfg).to(self.device, self.dtype)
_load_weights(str(tokenizer_dir), self.visual, self.vqmodel)
self.visual.eval()
self.vqmodel.eval()
self.spatial_merge_size = vision_cfg.spatial_merge_size
@staticmethod
def _whiten_transparency(img):
if img.mode == "RGBA":
canvas = PIL.Image.new("RGBA", img.size, (255, 255, 255, 255))
canvas.alpha_composite(img)
return canvas.convert("RGB")
return img if img.mode == "RGB" else img.convert("RGB")
def _extract_features(self, pixel_values, image_grid_thw):
with torch.no_grad():
hidden = self.visual(pixel_values.to(self.device, self.dtype),
grid_thw=image_grid_thw.to(self.device))
split_sizes = (image_grid_thw.prod(-1) // self.spatial_merge_size ** 2).tolist()
return list(torch.split(hidden, split_sizes))
def _quantize(self, hidden_states, image_grid_thw):
hidden_size = hidden_states.shape[-1]
split_sizes = image_grid_thw.prod(dim=-1).tolist()
all_tokens = []
with torch.no_grad():
for i, hs in enumerate(torch.split(hidden_states, split_sizes)):
gt, gh, gw = image_grid_thw[i].tolist()
hs = hs.view(gt, gh, gw, hidden_size).permute(0, 3, 1, 2).contiguous()
all_tokens.append(self.vqmodel.encode(hs))
return torch.cat(all_tokens, dim=0)
@torch.no_grad()
def encode(self, image: PIL.Image.Image) -> list[int]:
"""Encode a single image to VQ token IDs."""
image = self._whiten_transparency(image)
inputs = self.image_processor([image])
embeds = self._extract_features(inputs["pixel_values"], inputs["image_grid_thw"])
tokens = self._quantize(torch.cat(embeds, dim=0), inputs["image_grid_thw"])
return tokens.flatten().tolist()
@torch.no_grad()
def encode_batch(self, images: list[PIL.Image.Image]) -> list[list[int]]:
"""Encode a batch of images to VQ token IDs."""
images = [self._whiten_transparency(img) for img in images]
inputs = self.image_processor(images)
pv, grid = inputs["pixel_values"], inputs["image_grid_thw"]
embeds = self._extract_features(pv, grid)
return [self._quantize(e, grid[i:i+1]).flatten().tolist() for i, e in enumerate(embeds)]
@torch.no_grad()
def encode_with_info(self, image: PIL.Image.Image) -> dict:
"""Encode image and return token IDs with metadata."""
image = self._whiten_transparency(image)
w, h = image.size
inputs = self.image_processor([image])
pv, grid = inputs["pixel_values"], inputs["image_grid_thw"]
embeds = self._extract_features(pv, grid)
tl = self._quantize(torch.cat(embeds, dim=0), grid).flatten().tolist()
return {"pixel_values": pv, "token_ids": tl, "grid_thw": tuple(grid[0].tolist()),
"num_tokens": len(tl), "image_size": (w, h)}
@property
def codebook_size(self):
return self.vqmodel.quantize.num_embeddings
@property
def embed_dim(self):
return self.vqmodel.quantize.embedding_dim
"""
LLaDA-2.0-Uni — Image Editing
Usage:
python image_edit.py --model_path /path/to/LLaDA-2.0-Uni --image input.jpg --instruction "Change the background to a beach."
python image_edit.py --model_path /path/to/LLaDA-2.0-Uni --image_token input.pt --instruction "Make it a watercolor painting."
"""
import os, sys, gc, argparse, torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from decoder import decode_vq_tokens
def parse_args():
p = argparse.ArgumentParser(description="LLaDA-2.0-Uni Image Editing")
p.add_argument("--model_path", type=str, required=True,
help="Root model dir containing LLM weights, image_tokenizer/, decoder/, vae/")
p.add_argument("--image", type=str, default=None)
p.add_argument("--image_token", type=str, default=None)
p.add_argument("--instruction", type=str, required=True)
p.add_argument("--steps", type=int, default=8)
p.add_argument("--block_length", type=int, default=32)
p.add_argument("--cfg_text_scale", type=float, default=4.0)
p.add_argument("--cfg_image_scale", type=float, default=0.0)
p.add_argument("--decoder_steps", type=int, default=50)
p.add_argument("--resolution_multiplier", type=int, default=2)
p.add_argument("--output", type=str, default="edited.png")
p.add_argument("--seed", type=int, default=42)
return p.parse_args()
def _get_image_token_offset(model_path):
"""Read image_token_offset from model config."""
import json
with open(os.path.join(model_path, "config.json")) as f:
return json.load(f).get("image_token_offset", 157184)
def encode_image_from_pt(pt_path, offset):
data = torch.load(pt_path, map_location="cpu", weights_only=False)
token_ids = (data["semantic_token_ids"] + offset).tolist()
w, h = data["metadata"]["processed_size"]
return token_ids, h // 16, w // 16
def encode_image_from_pil(image_path, model_path, device, offset):
from encoder.image_tokenizer import ImageTokenizer
from decoder.utils import generate_crop_size_list, var_center_crop
image_tokenizer = ImageTokenizer(
model_path=model_path, device=device, dtype=torch.bfloat16,
)
crop_size_list = generate_crop_size_list((512 // 32) ** 2, 32)
pil_image = var_center_crop(Image.open(image_path).convert("RGB"), crop_size_list=crop_size_list)
info = image_tokenizer.encode_with_info(pil_image)
_, h, w = info["grid_thw"]
token_ids = [x + offset for x in info["token_ids"]]
del image_tokenizer; torch.cuda.empty_cache()
return token_ids, h, w
def main():
args = parse_args()
torch.manual_seed(args.seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Encode source image
offset = _get_image_token_offset(args.model_path)
if args.image_token:
print(f"Loading pre-tokenized image: {args.image_token}")
image_tokens, image_h, image_w = encode_image_from_pt(args.image_token, offset)
elif args.image:
print(f"Encoding image: {args.image}")
image_tokens, image_h, image_w = encode_image_from_pil(args.image, args.model_path, device, offset)
else:
raise ValueError("Provide --image or --image_token")
print(f"Image grid: {image_h}x{image_w}, instruction: {args.instruction}")
# Phase 1: generate edited VQ tokens
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
args.model_path, device_map={"": device}, trust_remote_code=True
).to(torch.bfloat16).eval()
model.tokenizer = tokenizer
result = model.edit_image(
image_tokens, image_h, image_w, args.instruction,
steps=args.steps, block_length=args.block_length,
cfg_text_scale=args.cfg_text_scale, cfg_image_scale=args.cfg_image_scale,
)
del model; gc.collect(); torch.cuda.empty_cache()
print("Model unloaded.\n")
# Phase 2: decode to image
print("Decoding edited image...")
img = decode_vq_tokens(result["token_ids"], result["h"], result["w"],
args.model_path, device,
resolution_multiplier=args.resolution_multiplier, num_steps=args.decoder_steps)
img.save(args.output)
print(f"\n✅ Saved: {args.output}")
if __name__ == "__main__":
main()
"""
LLaDA-2.0-Uni — Image Understanding (Multimodal Understanding)
Usage:
python mmu_understand.py --model_path /path/to/LLaDA-2.0-Uni --image photo.jpg
python mmu_understand.py --model_path /path/to/LLaDA-2.0-Uni --image_token photo.pt
python mmu_understand.py --model_path /path/to/LLaDA-2.0-Uni --image photo.jpg --question "Describe this image."
"""
import os, sys, argparse, torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def parse_args():
p = argparse.ArgumentParser(description="LLaDA-2.0-Uni Image Understanding")
p.add_argument("--model_path", type=str, required=True,
help="Root model dir containing LLM weights and image_tokenizer/")
p.add_argument("--image", type=str, default=None, help="Path to input image (jpg/png)")
p.add_argument("--image_token", type=str, default=None, help="Path to pre-tokenized .pt file")
p.add_argument("--question", type=str, default="", help="Optional question/prefix")
p.add_argument("--steps", type=int, default=32)
p.add_argument("--block_length", type=int, default=32)
p.add_argument("--gen_length", type=int, default=2048)
return p.parse_args()
def _get_image_token_offset(model_path):
"""Read image_token_offset from model config."""
import json
with open(os.path.join(model_path, "config.json")) as f:
return json.load(f).get("image_token_offset", 157184)
def encode_image_from_pt(pt_path, offset):
data = torch.load(pt_path, map_location="cpu", weights_only=False)
token_ids = (data["semantic_token_ids"] + offset).tolist()
w, h = data["metadata"]["processed_size"]
return token_ids, h // 16, w // 16
def encode_image_from_pil(image_path, model_path, device, offset):
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from encoder.image_tokenizer import ImageTokenizer
from decoder.smart_img_process import smart_resize_images
image_tokenizer = ImageTokenizer(
model_path=model_path, device=device, dtype=torch.bfloat16,
)
pil_image = smart_resize_images([image_path])[0]
info = image_tokenizer.encode_with_info(pil_image)
_, h, w = info["grid_thw"]
token_ids = [x + offset for x in info["token_ids"]]
del image_tokenizer; torch.cuda.empty_cache()
return token_ids, h, w
def main():
args = parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
# Encode image
offset = _get_image_token_offset(args.model_path)
if args.image_token:
print(f"Loading pre-tokenized image: {args.image_token}")
image_tokens, image_h, image_w = encode_image_from_pt(args.image_token, offset)
elif args.image:
print(f"Encoding image: {args.image}")
image_tokens, image_h, image_w = encode_image_from_pil(args.image, args.model_path, device, offset)
else:
raise ValueError("Provide --image or --image_token")
print(f"Image grid: {image_h}x{image_w}, tokens: {len(image_tokens)}")
# Load model and use high-level API
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
args.model_path, device_map=device, trust_remote_code=True
).to(torch.bfloat16).eval()
model.tokenizer = tokenizer
print("Generating...")
response = model.understand_image(
image_tokens, image_h, image_w,
question=args.question, steps=args.steps,
block_length=args.block_length, gen_length=args.gen_length,
)
print(f"\n{'='*60}")
print(f"Question: {args.question or '(none)'}")
print(f"{'='*60}")
print(f"Response:\n{response}")
print(f"{'='*60}")
if __name__ == "__main__":
main()
"""
LLaDA-2.0-Uni — Text-to-Image Generation
Usage:
python t2i_generate.py --model_path /path/to/LLaDA-2.0-Uni --prompt "A cat on a table"
python t2i_generate.py --model_path /path/to/LLaDA-2.0-Uni --prompts_file prompts.txt
"""
import os, sys, gc, argparse, torch
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from decoder import decode_vq_tokens
def parse_args():
p = argparse.ArgumentParser(description="LLaDA-2.0-Uni Text-to-Image Generation")
p.add_argument("--model_path", type=str, required=True,
help="Root model dir containing LLM weights, image_tokenizer/, decoder/, vae/")
p.add_argument("--prompt", type=str, default=None)
p.add_argument("--prompts_file", type=str, default=None, help="One prompt per line")
p.add_argument("--steps", type=int, default=16)
p.add_argument("--cfg_scale", type=float, default=4.0)
p.add_argument("--image_h", type=int, default=512)
p.add_argument("--image_w", type=int, default=512)
p.add_argument("--decoder_steps", type=int, default=50)
p.add_argument("--resolution_multiplier", type=int, default=2)
p.add_argument("--output_dir", type=str, default="./t2i_output")
p.add_argument("--output", type=str, default=None, help="Output path for single prompt")
p.add_argument("--seed", type=int, default=42)
return p.parse_args()
def main():
args = parse_args()
torch.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prompts = []
if args.prompt: prompts = [args.prompt]
elif args.prompts_file:
with open(args.prompts_file) as f: prompts = [l.strip() for l in f if l.strip()]
else: raise ValueError("--prompt or --prompts_file required")
os.makedirs(args.output_dir, exist_ok=True)
# Phase 1: generate VQ tokens
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map={"": device}, trust_remote_code=True)
model = model.to(torch.bfloat16).eval()
model.tokenizer = tokenizer
results = []
for i, prompt in enumerate(prompts):
print(f"[{i+1}/{len(prompts)}] {prompt[:80]}")
res = model.generate_image(prompt, image_h=args.image_h, image_w=args.image_w,
steps=args.steps, cfg_scale=args.cfg_scale)
results.append({"prompt": prompt, **res})
del model; gc.collect(); torch.cuda.empty_cache()
print("Model unloaded.\n")
# Phase 2: decode to images
for i, res in enumerate(results):
if args.output and len(prompts) == 1:
out = args.output
else:
safe = res["prompt"][:40].replace(" ", "_").replace("/", "")
out = os.path.join(args.output_dir, f"{i:04d}_{safe}.png")
print(f"[{i+1}/{len(results)}] Decoding → {out}")
img = decode_vq_tokens(res["token_ids"], res["h"], res["w"], args.model_path, device,
resolution_multiplier=args.resolution_multiplier, num_steps=args.decoder_steps)
img.save(out)
print(f"\n🏁 Done! {len(results)} images generated.")
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment