Commit 2581b885 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
Pipeline #3320 canceled with stages
# Openpose
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
# 3rd Edited by ControlNet
# 4th Edited by ControlNet (added face and correct hands)
import os
import random
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import numpy as np
from . import util
from .wholebody import Wholebody
def draw_pose(pose, H, W, draw_body=True):
    """Render a pose dictionary onto a freshly allocated black canvas.

    Args:
        pose: Mapping with the keys ``"bodies"`` (a mapping that holds
            ``"candidate"`` and ``"subset"``), ``"faces"`` and ``"hands"``.
        H: Canvas height in pixels.
        W: Canvas width in pixels.
        draw_body: When false the body skeleton is skipped; hands and faces
            are drawn regardless.

    Returns:
        The canvas handed back by the last ``util`` drawing helper.
    """
    # All keys are read up front so a malformed ``pose`` fails before drawing.
    body_info = pose["bodies"]
    face_points = pose["faces"]
    hand_points = pose["hands"]
    body_points = body_info["candidate"]
    body_links = body_info["subset"]

    canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
    if draw_body:
        canvas = util.draw_bodypose(canvas, body_points, body_links)
    canvas = util.draw_handpose(canvas, hand_points)
    return util.draw_facepose(canvas, face_points)
def keypoint2bbox(keypoints):
    """Return the axis-aligned bounding box of the usable keypoints.

    Args:
        keypoints: Array of ``(x, y)`` rows. Rows whose x value is negative
            are ignored (the detector marks hidden points with negative
            coordinates).

    Returns:
        ``np.array([x_min, y_min, x_max, y_max])``, or four zeros when no row
        survives the filter.
    """
    kept = keypoints[keypoints[:, 0] >= 0]
    if len(kept) == 0:
        return np.zeros(4)
    (left, top), (right, bottom) = np.min(kept, axis=0), np.max(kept, axis=0)
    return np.array([left, top, right, bottom])
def expand_bboxes(bboxes, expansion_rate=0.5, image_shape=(0, 0)):
    """Grow each box around its centre and clip it to the image.

    Args:
        bboxes: Iterable of ``(x_min, y_min, x_max, y_max)``; values are
            truncated to int before use.
        expansion_rate: Fraction by which width and height are enlarged.
        image_shape: ``(height, width)`` used as the upper clipping bound.
            With the default ``(0, 0)`` every upper edge is clipped to 0.

    Returns:
        List of ``[x_min, y_min, x_max, y_max]`` lists.
    """
    max_y, max_x = image_shape[0], image_shape[1]
    grown = []
    for box in bboxes:
        left, top, right, bottom = map(int, box)
        box_w = right - left
        box_h = bottom - top
        # Enlarge width and height.
        wide = box_w * (1 + expansion_rate)
        tall = box_h * (1 + expansion_rate)
        # Split the extra size evenly between both sides, then clip.
        pad_x = (wide - box_w) / 2
        pad_y = (tall - box_h) / 2
        grown.append(
            [
                max(0, left - pad_x),
                max(0, top - pad_y),
                min(max_x, right + pad_x),
                min(max_y, bottom + pad_y),
            ]
        )
    return grown
def create_mask(image_width, image_height, bboxs):
    """Build a binary float mask with the given boxes filled with ones.

    Args:
        image_width: Mask width in pixels.
        image_height: Mask height in pixels.
        bboxs: Iterable of ``(x1, y1, x2, y2)`` boxes; both corners are
            inclusive and values are truncated to int.

    Returns:
        ``(image_height, image_width)`` float32 array, 1.0 inside any box and
        0.0 elsewhere.
    """
    mask = np.zeros((image_height, image_width), dtype=np.float32)
    for bbox in bboxs:
        x1, y1, x2, y2 = map(int, bbox)
        # Clamp to 0: a negative slice bound would otherwise be read as
        # "count from the end" and fill the wrong region (or nothing).
        # Bounds beyond the image are already truncated by slicing.
        left, top = max(x1, 0), max(y1, 0)
        right, bottom = max(x2 + 1, 0), max(y2 + 1, 0)
        mask[top:bottom, left:right] = 1.0
    return mask
# Keypoint confidence cut-off used by DWposeDetector: scores below it are
# treated as "not visible".
threshold = 0.4
class DWposeDetector:
    """Run ``Wholebody`` on an image and turn its keypoints into pose outputs.

    Calling the detector returns, depending on the flags, a raw keypoint
    array, a pose dictionary, a rendered pose image with a hand mask, or a
    rendered pose image with one hand bounding box.
    """

    def __init__(self):
        self.pose_estimation = Wholebody()

    def __call__(
        self, oriImg, return_index=False, return_yolo=False, return_mask=False
    ):
        """Detect whole-body keypoints in ``oriImg``.

        Args:
            oriImg: ``(H, W, C)`` image array; it is copied before use.
            return_index: Return the pose dictionary instead of a drawing.
            return_yolo: Return the body and hand keypoints as one array of
                ``(x, y, visible)`` rows. Takes precedence over the other
                flags.
            return_mask: Return ``(pose_image, mask)`` where the mask covers
                the boxes of all detected hands. Takes precedence over
                ``return_index``.

        Returns:
            * ``return_yolo``: array of normalised ``(x, y, visible)`` rows,
              with hidden points set to ``-0.1``.
            * ``return_mask``: ``(draw_pose(...), mask)``.
            * ``return_index``: dict with ``bodies``, ``hands`` and ``faces``.
            * otherwise: ``(draw_pose(...), bbox)`` where ``bbox`` is the
              pixel box of one randomly chosen hand, or zeros if none.
        """
        oriImg = oriImg.copy()
        H, W, _ = oriImg.shape
        with torch.no_grad():
            candidate, subset = self.pose_estimation(oriImg)
            # Fall back to one all-zero person when nothing was returned.
            if candidate is None:
                candidate = np.zeros((1, 134, 2), dtype=np.float32)
            if subset is None:
                subset = np.zeros((1, 134), dtype=np.float32)
            nums, _, locs = candidate.shape
            # Pixel coordinates -> fractions of the image width / height.
            candidate[..., 0] /= float(W)
            candidate[..., 1] /= float(H)

            if return_yolo:
                candidate[subset < threshold] = -0.1
                subset = np.expand_dims(subset >= threshold, axis=-1)
                keypoint = np.concatenate([candidate, subset], axis=-1)
                # Keep the 18 body points plus everything from index 92 on
                # (the hand points).
                return np.concatenate([keypoint[:, :18], keypoint[:, 92:]], axis=1)

            body = candidate[:, :18].copy().reshape(nums * 18, locs)

            # Per-person lookup table: the row index into ``body`` for each
            # confident body joint, -1 otherwise. Built on a fresh array so
            # ``subset`` keeps its real confidences for the mask below.
            body_conf = subset[:, :18]
            row_of_joint = (
                18 * np.arange(body_conf.shape[0])[:, None]
                + np.arange(body_conf.shape[1])[None, :]
            )
            score = np.where(body_conf > threshold, row_of_joint, -1).astype(
                subset.dtype
            )

            candidate[subset < threshold] = -1
            faces = candidate[:, 24:92]
            # Two hand groups per person (21 points each), stacked as
            # separate entries.
            hands = np.vstack([candidate[:, 92:113], candidate[:, 113:]])
            hands_ = hands[hands.max(axis=(1, 2)) > 0]

            # Boxes are (x_min, y_min, x_max, y_max) in normalised units, so
            # x scales with the width and y with the height.
            to_pixels = np.array([W, H, W, H])
            if len(hands_) == 0:
                bbox = np.array([0, 0, 0, 0]).astype(int)
            else:
                hand_random = random.choice(hands_)
                bbox = (keypoint2bbox(hand_random) * to_pixels).astype(int)

            bodies = dict(candidate=body, subset=score)
            pose = dict(bodies=bodies, hands=hands, faces=faces)

            if return_mask:
                bbox = [
                    (keypoint2bbox(hand) * to_pixels).astype(int) for hand in hands_
                ]
                mask = create_mask(W, H, bbox)
                return draw_pose(pose, H, W), mask
            if return_index:
                return pose
            return draw_pose(pose, H, W), bbox
import cv2
import numpy as np
import onnxruntime
def nms(boxes, scores, nms_thr):
    """Greedy single-class non-maximum suppression in NumPy.

    Args:
        boxes: ``(n, 4)`` array of ``(x1, y1, x2, y2)`` boxes.
        scores: ``(n,)`` array of box scores.
        nms_thr: Boxes whose IoU with an already kept box exceeds this value
            are dropped.

    Returns:
        List of kept indices into ``boxes``, highest score first.
    """
    left, top = boxes[:, 0], boxes[:, 1]
    right, bottom = boxes[:, 2], boxes[:, 3]
    # The +1 treats coordinates as inclusive pixel indices.
    box_area = (right - left + 1) * (bottom - top + 1)

    remaining = scores.argsort()[::-1]
    selected = []
    while remaining.size > 0:
        best, others = remaining[0], remaining[1:]
        selected.append(best)

        ix1 = np.maximum(left[best], left[others])
        iy1 = np.maximum(top[best], top[others])
        ix2 = np.minimum(right[best], right[others])
        iy2 = np.minimum(bottom[best], bottom[others])
        overlap = np.maximum(0.0, ix2 - ix1 + 1) * np.maximum(0.0, iy2 - iy1 + 1)
        iou = overlap / (box_area[best] + box_area[others] - overlap)

        remaining = others[iou <= nms_thr]
    return selected
def multiclass_nms(boxes, scores, nms_thr, score_thr):
    """Class-aware non-maximum suppression in NumPy.

    Args:
        boxes: ``(n, 4)`` array of ``(x1, y1, x2, y2)`` boxes.
        scores: ``(n, num_classes)`` array of per-class scores.
        nms_thr: IoU threshold passed to ``nms``.
        score_thr: Only scores strictly above this value are considered.

    Returns:
        ``(m, 6)`` array of ``x1, y1, x2, y2, score, class_index`` rows, or
        ``None`` when nothing survives.
    """
    per_class = []
    for cls_ind in range(scores.shape[1]):
        cls_scores = scores[:, cls_ind]
        confident = cls_scores > score_thr
        if confident.sum() == 0:
            continue
        kept_scores = cls_scores[confident]
        kept_boxes = boxes[confident]
        keep = nms(kept_boxes, kept_scores, nms_thr)
        if len(keep) > 0:
            labels = np.ones((len(keep), 1)) * cls_ind
            per_class.append(
                np.concatenate([kept_boxes[keep], kept_scores[keep, None], labels], 1)
            )
    if len(per_class) == 0:
        return None
    return np.concatenate(per_class, 0)
def demo_postprocess(outputs, img_size, p6=False):
    """Decode raw grid-relative predictions into input-image coordinates.

    ``outputs`` is modified in place and also returned.

    Args:
        outputs: Array whose last axis starts with ``x, y, w, h`` and whose
            second axis enumerates the grid cells of every stride level.
        img_size: ``(height, width)`` of the network input.
        p6: Add the extra stride-64 level.

    Returns:
        ``outputs`` with centres shifted by their grid cell and scaled by the
        stride, and sizes exponentiated and scaled by the stride.
    """
    strides = [8, 16, 32, 64] if p6 else [8, 16, 32]

    cell_parts = []
    stride_parts = []
    for stride in strides:
        rows = img_size[0] // stride
        cols = img_size[1] // stride
        col_idx, row_idx = np.meshgrid(np.arange(cols), np.arange(rows))
        cells = np.stack((col_idx, row_idx), 2).reshape(1, -1, 2)
        cell_parts.append(cells)
        stride_parts.append(np.full((*cells.shape[:2], 1), stride))

    all_cells = np.concatenate(cell_parts, 1)
    all_strides = np.concatenate(stride_parts, 1)

    outputs[..., :2] = (outputs[..., :2] + all_cells) * all_strides
    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * all_strides
    return outputs
def preprocess(img, input_size, swap=(2, 0, 1)):
    """Letterbox ``img`` into ``input_size`` for the detector.

    The image is resized with its aspect ratio preserved, placed in the
    top-left corner of a canvas filled with the value 114, and transposed.

    Args:
        img: Image array, 3-D (with channels) or 2-D.
        input_size: ``(height, width)`` of the target canvas.
        swap: Axis order applied to the padded image.

    Returns:
        ``(padded_img, r)``: the float32 contiguous array and the resize
        ratio that was applied.
    """
    if len(img.shape) == 3:
        canvas_shape = (input_size[0], input_size[1], 3)
    else:
        canvas_shape = input_size
    padded_img = np.ones(canvas_shape, dtype=np.uint8) * 114

    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    new_h = int(img.shape[0] * r)
    new_w = int(img.shape[1] * r)
    resized_img = cv2.resize(
        img, (new_w, new_h), interpolation=cv2.INTER_LINEAR
    ).astype(np.uint8)
    padded_img[:new_h, :new_w] = resized_img

    return np.ascontiguousarray(padded_img.transpose(swap), dtype=np.float32), r
def inference_detector(session, oriImg):
    """Run the detector session and return the boxes of class 0.

    Args:
        session: Inference session exposing ``get_inputs()`` and ``run()``.
        oriImg: Image array passed to ``preprocess``.

    Returns:
        Array of ``(x1, y1, x2, y2)`` boxes in original-image coordinates
        whose class index is 0 and whose score exceeds 0.3, or ``None`` when
        NMS keeps nothing.
    """
    input_shape = (640, 640)
    img, ratio = preprocess(oriImg, input_shape)

    feed = {session.get_inputs()[0].name: img[None, :, :, :]}
    raw = session.run(None, feed)
    predictions = demo_postprocess(raw[0], input_shape)[0]

    boxes = predictions[:, :4]
    # Objectness times per-class probability.
    scores = predictions[:, 4:5] * predictions[:, 5:]

    # Centre/size -> corner format, then undo the letterbox scaling.
    half_w = boxes[:, 2] / 2.0
    half_h = boxes[:, 3] / 2.0
    boxes_xyxy = np.ones_like(boxes)
    boxes_xyxy[:, 0] = boxes[:, 0] - half_w
    boxes_xyxy[:, 1] = boxes[:, 1] - half_h
    boxes_xyxy[:, 2] = boxes[:, 0] + half_w
    boxes_xyxy[:, 3] = boxes[:, 1] + half_h
    boxes_xyxy /= ratio

    dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
    if dets is None:
        return None

    wanted = (dets[:, 4] > 0.3) & (dets[:, 5] == 0)
    return dets[:, :4][wanted]
from typing import List, Tuple
import cv2
import numpy as np
import onnxruntime as ort
def preprocess(
    img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Crop and normalise one model input per bounding box.

    Args:
        img (np.ndarray): Input image.
        out_bbox: Sequence of ``(x0, y0, x1, y1)`` boxes. When empty, a
            single box covering the whole image is used.
        input_size (tuple): Model input size as ``(w, h)``.

    Returns:
        tuple of three lists, one entry per box:
        - the warped and mean/std-normalised image,
        - the bbox centre,
        - the bbox scale after the aspect-ratio fix.
    """
    img_shape = img.shape[:2]
    if len(out_bbox) == 0:
        out_bbox = [[0, 0, img_shape[1], img_shape[0]]]

    # Per-channel normalisation constants.
    mean = np.array([123.675, 116.28, 103.53])
    std = np.array([58.395, 57.12, 57.375])

    crops, centers, scales = [], [], []
    for box in out_bbox:
        bbox = np.array([box[0], box[1], box[2], box[3]])
        # Centre and padded scale of the box.
        center, scale = bbox_xyxy2cs(bbox, padding=1.25)
        # Warp the box region to the model input size.
        crop, scale = top_down_affine(input_size, scale, center, img)
        crops.append((crop - mean) / std)
        centers.append(center)
        scales.append(scale)
    return crops, centers, scales
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
    """Run the pose model once per preprocessed crop.

    Args:
        sess (ort.InferenceSession): ONNXRuntime session.
        img: Sequence of ``(h, w, c)`` crops; each is transposed to
            channel-first and fed as a batch of one.

    Returns:
        List with one ``sess.run`` result per crop.
    """
    results = []
    for crop in img:
        # Batch of one, channel-first.
        batch = [crop.transpose(2, 0, 1)]
        feed = {sess.get_inputs()[0].name: batch}
        wanted = [node.name for node in sess.get_outputs()]
        results.append(sess.run(wanted, feed))
    return results
def postprocess(
    outputs: List[np.ndarray],
    model_input_size: Tuple[int, int],
    center: Tuple[int, int],
    scale: Tuple[int, int],
    simcc_split_ratio: float = 2.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """Decode the model outputs and map keypoints back to image coordinates.

    Args:
        outputs: One ``(simcc_x, simcc_y)`` pair per crop.
        model_input_size (tuple): Model input size.
        center: Per-crop bbox centres ``(x, y)``, indexed like ``outputs``.
        scale: Per-crop bbox scales ``(w, h)``, indexed like ``outputs``.
        simcc_split_ratio (float): Split ratio of simcc.

    Returns:
        tuple:
        - keypoints (np.ndarray): Rescaled keypoints, first instance of each
          crop.
        - scores (np.ndarray): Matching scores.
    """
    rescaled = []
    confidences = []
    for idx, (simcc_x, simcc_y) in enumerate(outputs):
        keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
        # Model-input coordinates -> original-image coordinates.
        keypoints = (
            keypoints / model_input_size * scale[idx] + center[idx] - scale[idx] / 2
        )
        rescaled.append(keypoints[0])
        confidences.append(scores[0])
    return np.array(rescaled), np.array(confidences)
def bbox_xyxy2cs(
    bbox: np.ndarray, padding: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Convert corner-format boxes into (center, scale).

    Args:
        bbox (ndarray): Box(es) of shape (4,) or (n, 4), formatted as
            (left, top, right, bottom).
        padding (float): Factor multiplied into the scale. Default: 1.0

    Returns:
        tuple:
        - np.ndarray: Center (x, y), shape (2,) or (n, 2).
        - np.ndarray: Scale (w, h) times ``padding``, shape (2,) or (n, 2).
    """
    single = bbox.ndim == 1
    boxes = bbox[None, :] if single else bbox

    left, top, right, bottom = np.hsplit(boxes, [1, 2, 3])
    center = np.hstack([left + right, top + bottom]) * 0.5
    scale = np.hstack([right - left, bottom - top]) * padding

    if single:
        # Drop the temporary leading axis again.
        return center[0], scale[0]
    return center, scale
def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float) -> np.ndarray:
    """Enlarge one side of the scale so that ``w / h`` equals ``aspect_ratio``.

    Args:
        bbox_scale (np.ndarray): The scale (w, h) in shape (2, )
        aspect_ratio (float): The target ratio ``w/h``

    Returns:
        np.ndarray: The adjusted scale in (2, ); the side that is already
        large enough is kept.
    """
    w, h = np.hsplit(bbox_scale, [1])
    too_wide = w > h * aspect_ratio
    keep_width = np.hstack([w, w / aspect_ratio])
    keep_height = np.hstack([h * aspect_ratio, h])
    return np.where(too_wide, keep_width, keep_height)
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
    """Rotate a 2D point about the origin.

    Args:
        pt (np.ndarray): Point (x, y) in shape (2, )
        angle_rad (float): Rotation angle in radians

    Returns:
        np.ndarray: The rotated point in shape (2, )
    """
    sine = np.sin(angle_rad)
    cosine = np.cos(angle_rad)
    rotation = np.array([[cosine, -sine], [sine, cosine]])
    return rotation @ pt
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the third point needed to define an affine transform.

    The result is ``b`` plus the vector ``a - b`` turned by 90 degrees, i.e.
    ``(dx, dy)`` becomes ``(-dy, dx)``.

    Args:
        a (np.ndarray): The 1st point (x,y) in shape (2, )
        b (np.ndarray): The 2nd point (x,y) in shape (2, )

    Returns:
        np.ndarray: The 3rd point.
    """
    delta = a - b
    perpendicular = np.r_[-delta[1], delta[0]]
    return b + perpendicular
def get_warp_matrix(
    center: np.ndarray,
    scale: np.ndarray,
    rot: float,
    output_size: Tuple[int, int],
    shift: Tuple[float, float] = (0.0, 0.0),
    inv: bool = False,
) -> np.ndarray:
    """Build the affine matrix that maps a bbox region onto the output size.

    Three point pairs define the transform: the box centre, a point half a
    box width away along the rotated direction, and a third point
    perpendicular to those two.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        scale (np.ndarray[2, ]): Scale of the bounding box [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ] | list(2,)): Destination size (w, h).
        shift (0-100%): Translation as a ratio of the width/height.
            Default (0., 0.).
        inv (bool): When true return the inverse mapping (dst -> src).

    Returns:
        np.ndarray: A 2x3 transformation matrix
    """
    shift = np.array(shift)
    src_w = scale[0]
    dst_w, dst_h = output_size[0], output_size[1]

    src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), np.deg2rad(rot))
    dst_dir = np.array([0.0, dst_w * -0.5])

    # Reference points in the source image.
    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale * shift
    src[1, :] = center + src_dir + scale * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    # Matching reference points in the output image.
    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    origin, target = (dst, src) if inv else (src, dst)
    return cv2.getAffineTransform(np.float32(origin), np.float32(target))
def top_down_affine(
    input_size: dict, bbox_scale: dict, bbox_center: dict, img: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Warp the bbox region of ``img`` to the model input size.

    Args:
        input_size: Model input size, unpacked as ``(w, h)``.
        bbox_scale: Bbox scale (w, h); adjusted to the input aspect ratio
            before warping.
        bbox_center: Bbox center (x, y).
        img (np.ndarray): The original image.

    Returns:
        tuple:
        - np.ndarray: the warped image.
        - np.ndarray: the aspect-ratio-adjusted bbox scale.
    """
    w, h = input_size
    # Match the box aspect ratio to the model input before warping.
    fitted_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
    matrix = get_warp_matrix(bbox_center, fitted_scale, 0, output_size=(w, h))
    warped = cv2.warpAffine(img, matrix, (int(w), int(h)), flags=cv2.INTER_LINEAR)
    return warped, fitted_scale
def get_simcc_maximum(
    simcc_x: np.ndarray, simcc_y: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Locate the peak of each SimCC vector and score it.

    Args:
        simcc_x (np.ndarray): x-axis SimCC in shape (N, K, Wx)
        simcc_y (np.ndarray): y-axis SimCC in shape (N, K, Wy)

    Returns:
        tuple:
        - locs (np.ndarray): float32 peak positions (x, y) in shape
          (N, K, 2); set to -1 where the score is not positive.
        - vals (np.ndarray): per-keypoint score in shape (N, K), the smaller
          of the x-axis peak and the y-axis peak.
    """
    N, K, Wx = simcc_x.shape
    flat_x = simcc_x.reshape(N * K, -1)
    flat_y = simcc_y.reshape(N * K, -1)

    # Peak position along each axis.
    locs = np.stack(
        (np.argmax(flat_x, axis=1), np.argmax(flat_y, axis=1)), axis=-1
    ).astype(np.float32)

    # A keypoint is only as confident as its weaker axis.
    peak_x = np.amax(flat_x, axis=1)
    peak_y = np.amax(flat_y, axis=1)
    vals = np.where(peak_x > peak_y, peak_y, peak_x)

    locs[vals <= 0.0] = -1
    return locs.reshape(N, K, 2), vals.reshape(N, K)
def decode(
    simcc_x: np.ndarray, simcc_y: np.ndarray, simcc_split_ratio
) -> Tuple[np.ndarray, np.ndarray]:
    """Turn SimCC outputs into keypoint coordinates and scores.

    Args:
        simcc_x (np.ndarray): model predicted simcc in x.
        simcc_y (np.ndarray): model predicted simcc in y.
        simcc_split_ratio: The peak positions are divided by this value.

    Returns:
        tuple:
        - np.ndarray[float32]: keypoints as returned by
          ``get_simcc_maximum``, divided by the split ratio.
        - np.ndarray: the matching scores.
    """
    peak_locs, peak_vals = get_simcc_maximum(simcc_x, simcc_y)
    # Divide in place so the float32 dtype is kept.
    np.divide(peak_locs, simcc_split_ratio, out=peak_locs)
    return peak_locs, peak_vals
def inference_pose(session, out_bbox, oriImg):
    """Estimate keypoints for every box in ``out_bbox``.

    Args:
        session: Pose inference session; its first input's shape supplies the
            model height and width (axes 2 and 3).
        out_bbox: Person boxes handed to ``preprocess``.
        oriImg: The original image.

    Returns:
        ``(keypoints, scores)`` as produced by ``postprocess``.
    """
    in_h, in_w = session.get_inputs()[0].shape[2:]
    model_input_size = (in_w, in_h)
    crops, centers, scales = preprocess(oriImg, out_bbox, model_input_size)
    raw_outputs = inference(session, crops)
    return postprocess(raw_outputs, model_input_size, centers, scales)
import math
import numpy as np
import matplotlib
import cv2
# Lower bound for pixel coordinates in draw_handpose / draw_facepose: a point
# is drawn only when both of its coordinates are greater than this value.
eps = 0.01
def smart_resize(x, s):
    """Resize ``x`` to the target size, choosing the filter by direction.

    Area interpolation is used when shrinking and Lanczos when enlarging.
    Arrays with a channel count other than 1 or 3 are resized one channel at
    a time and stacked back together.

    Args:
        x: 2-D or 3-D image array.
        s: Target ``(height, width)``.

    Returns:
        The resized array.
    """
    target_h, target_w = s
    if x.ndim == 2:
        src_h, src_w = x.shape
        channels = 1
    else:
        src_h, src_w, channels = x.shape

    if not (channels == 3 or channels == 1):
        planes = [smart_resize(x[:, :, c], s) for c in range(channels)]
        return np.stack(planes, axis=2)

    k = float(target_h + target_w) / float(src_h + src_w)
    method = cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4
    return cv2.resize(x, (int(target_w), int(target_h)), interpolation=method)
def smart_resize_k(x, fx, fy):
    """Resize ``x`` by scale factors, choosing the filter by direction.

    Area interpolation is used when shrinking and Lanczos when enlarging.
    Arrays with a channel count other than 1 or 3 are resized one channel at
    a time and stacked back together.

    Args:
        x: 2-D or 3-D image array.
        fx: Horizontal scale factor.
        fy: Vertical scale factor.

    Returns:
        The resized array.
    """
    if x.ndim == 2:
        src_h, src_w = x.shape
        channels = 1
    else:
        src_h, src_w, channels = x.shape
    target_h, target_w = src_h * fy, src_w * fx

    if not (channels == 3 or channels == 1):
        planes = [smart_resize_k(x[:, :, c], fx, fy) for c in range(channels)]
        return np.stack(planes, axis=2)

    k = float(target_h + target_w) / float(src_h + src_w)
    method = cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4
    return cv2.resize(x, (int(target_w), int(target_h)), interpolation=method)
def padRightDownCorner(img, stride, padValue):
    """Pad the bottom and right of ``img`` up to a multiple of ``stride``.

    Args:
        img: 3-D image array ``(h, w, c)``.
        stride: Height and width are rounded up to a multiple of this value.
        padValue: Constant written into the added rows and columns.

    Returns:
        ``(img_padded, pad)`` where ``pad`` is ``[up, left, down, right]``;
        ``up`` and ``left`` are always 0.
    """
    h = img.shape[0]
    w = img.shape[1]
    # (-n) % stride is the distance from n up to the next multiple of stride.
    pad = [0, 0, (-h) % stride, (-w) % stride]

    # The fill is derived from an existing row/column (``* 0 + padValue``) so
    # the result dtype follows the same promotion rules as before. The last
    # row/column is used as the template: the former ``[-2:-1]`` slice is
    # empty for images of height or width 1, which silently skipped padding.
    row_template = img[-1:, :, :] * 0 + padValue
    img_padded = np.concatenate(
        (img, np.tile(row_template, (pad[2], 1, 1))), axis=0
    )
    col_template = img_padded[:, -1:, :] * 0 + padValue
    img_padded = np.concatenate(
        (img_padded, np.tile(col_template, (1, pad[3], 1))), axis=1
    )
    return img_padded, pad
def transfer(model, model_weights):
    """Re-key ``model_weights`` to the parameter names of ``model``.

    Each key of ``model.state_dict()`` is looked up in ``model_weights``
    under the same name with its first dot-separated component removed.

    Args:
        model: Object exposing ``state_dict()``.
        model_weights: Mapping from un-prefixed names to weights.

    Returns:
        Dict from the model's own parameter names to the matching weights.
    """
    return {
        full_name: model_weights[full_name.partition(".")[2]]
        for full_name in model.state_dict().keys()
    }
def draw_bodypose(canvas, candidate, subset):
    """Draw body limbs and joints onto ``canvas``.

    Args:
        canvas: ``(H, W, C)`` image; limbs are filled into it, then a dimmed
            copy receives the joint circles.
        candidate: Keypoint rows whose first two values are x and y as
            fractions of the canvas width and height.
        subset: One row per person holding, for each of the 18 joints, the
            row index into ``candidate`` or -1 when the joint is missing.

    Returns:
        The drawn uint8 canvas (a new array, not the input object).
    """
    H, W, _ = canvas.shape
    candidate = np.array(candidate)
    subset = np.array(subset)

    stickwidth = 4

    # Joint pairs (1-based) that form a limb.
    limbSeq = [
        [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9],
        [9, 10], [10, 11], [2, 12], [12, 13], [13, 14], [2, 1],
        [1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18],
    ]

    colors = [
        [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
        [170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85],
        [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255],
        [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
        [255, 0, 170], [255, 0, 85],
    ]

    # Only the first 17 limbs are drawn, each as a filled ellipse.
    for limb, color in zip(limbSeq[:17], colors):
        slots = np.array(limb) - 1
        for person in subset:
            index = person[slots]
            if -1 in index:
                continue
            rows = index.astype(int)
            px = candidate[rows, 0] * float(W)
            py = candidate[rows, 1] * float(H)
            mid_x = np.mean(px)
            mid_y = np.mean(py)
            length = ((py[0] - py[1]) ** 2 + (px[0] - px[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(py[0] - py[1], px[0] - px[1]))
            polygon = cv2.ellipse2Poly(
                (int(mid_x), int(mid_y)),
                (int(length / 2), stickwidth),
                int(angle),
                0,
                360,
                1,
            )
            cv2.fillConvexPoly(canvas, polygon, color)

    # Dim the limbs so the joint circles stand out.
    canvas = (canvas * 0.6).astype(np.uint8)

    for joint, color in enumerate(colors):
        for person in subset:
            index = int(person[joint])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            x = int(x * W)
            y = int(y * H)
            cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)

    return canvas
def draw_handpose(canvas, all_hand_peaks):
    """Draw hand skeletons onto ``canvas`` in place.

    Args:
        canvas: ``(H, W, C)`` image to draw on.
        all_hand_peaks: Iterable of hands, each with 21 ``(x, y)`` rows given
            as fractions of the canvas width and height.

    Returns:
        The same ``canvas`` object.
    """
    H, W, _ = canvas.shape

    # Pairs of keypoint indices joined by a line; all chains start at point 0.
    edges = [
        [0, 1], [1, 2], [2, 3], [3, 4],
        [0, 5], [5, 6], [6, 7], [7, 8],
        [0, 9], [9, 10], [10, 11], [11, 12],
        [0, 13], [13, 14], [14, 15], [15, 16],
        [0, 17], [17, 18], [18, 19], [19, 20],
    ]

    for peaks in all_hand_peaks:
        peaks = np.array(peaks)

        for ie, (start, end) in enumerate(edges):
            x1, y1 = peaks[start]
            x2, y2 = peaks[end]
            x1, y1 = int(x1 * W), int(y1 * H)
            x2, y2 = int(x2 * W), int(y2 * H)
            # Skip an edge when any coordinate of either end point is not
            # above the threshold.
            if not (x1 > eps and y1 > eps and x2 > eps and y2 > eps):
                continue
            # One hue per edge, spread evenly over the colour wheel.
            hue = ie / float(len(edges))
            cv2.line(
                canvas,
                (x1, y1),
                (x2, y2),
                matplotlib.colors.hsv_to_rgb([hue, 1.0, 1.0]) * 255,
                thickness=2,
            )

        for keypoint in peaks:
            x, y = keypoint
            x = int(x * W)
            y = int(y * H)
            if x > eps and y > eps:
                cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
    return canvas
def draw_facepose(canvas, all_lmks):
    """Draw face landmarks as white dots onto ``canvas`` in place.

    Args:
        canvas: ``(H, W, C)`` image to draw on.
        all_lmks: Iterable of faces, each a sequence of ``(x, y)`` rows given
            as fractions of the canvas width and height.

    Returns:
        The same ``canvas`` object.
    """
    H, W, _ = canvas.shape
    for face in all_lmks:
        for landmark in np.array(face):
            x, y = landmark
            x = int(x * W)
            y = int(y * H)
            # Points at or below the threshold are skipped.
            if x > eps and y > eps:
                cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
    return canvas
# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
    """Estimate square hand boxes from shoulder, elbow and wrist keypoints.

    Follows the OpenPose hand detector: the box centre lies on the
    elbow-to-wrist line, extended past the wrist, and its size depends on the
    forearm and upper-arm lengths.

    Args:
        candidate: Keypoint rows whose first two values are x and y.
        subset: One row per person with indices into ``candidate`` (-1 when
            missing). Joints 5/6/7 are the left shoulder/elbow/wrist and
            2/3/4 the right ones.
        oriImg: Image whose height and width bound the boxes.

    Returns:
        List of ``[x, y, w, is_left]`` with ``(x, y)`` the top-left corner and
        ``w`` the side of the square; boxes smaller than 20 are dropped.
    """
    ratioWristElbow = 0.33
    image_height, image_width = oriImg.shape[0:2]
    arm_joints = (([5, 6, 7], True), ([2, 3, 4], False))  # left first, then right

    detect_result = []
    for person in subset.astype(int):
        arms = []
        for joints, is_left in arm_joints:
            picked = person[joints]
            # An arm is usable only when all three joints were detected.
            if np.sum(picked == -1) != 0:
                continue
            shoulder, elbow, wrist = picked
            x1, y1 = candidate[shoulder][:2]
            x2, y2 = candidate[elbow][:2]
            x3, y3 = candidate[wrist][:2]
            arms.append([x1, y1, x2, y2, x3, y3, is_left])

        for x1, y1, x2, y2, x3, y3, is_left in arms:
            # Centre: the wrist pushed further along the elbow->wrist line.
            x = x3 + ratioWristElbow * (x3 - x2)
            y = y3 + ratioWristElbow * (y3 - y2)
            forearm = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
            upper_arm = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            width = 1.5 * max(forearm, 0.9 * upper_arm)

            # Centre -> top-left corner (the box is square).
            x -= width / 2
            y -= width / 2
            x = max(x, 0)
            y = max(y, 0)

            # Shrink the box when it would leave the image.
            fit_w = image_width - x if x + width > image_width else width
            fit_h = image_height - y if y + width > image_height else width
            width = min(fit_w, fit_h)

            # Boxes narrower than 20 pixels are discarded.
            if width >= 20:
                detect_result.append([int(x), int(y), int(width), is_left])

    return detect_result
# Written by Lvmin
def faceDetect(candidate, subset, oriImg):
    """Estimate square face boxes from the head, eye and ear keypoints.

    Args:
        candidate: Keypoint rows whose first two values are x and y.
        subset: One row per person with indices into ``candidate`` (-1 when
            missing). Joint 0 is the head; 14/15 are the eyes and 16/17 the
            ears.
        oriImg: Image whose height and width bound the boxes.

    Returns:
        List of ``[x, y, w]`` with ``(x, y)`` the top-left corner and ``w``
        the side of the square; boxes smaller than 20 are dropped.
    """
    image_height, image_width = oriImg.shape[0:2]
    # (joint slot, multiplier applied to its distance from the head).
    landmarks = ((14, 3.0), (15, 3.0), (16, 1.5), (17, 1.5))

    detect_result = []
    for person in subset.astype(int):
        if not person[0] > -1:
            continue
        present = [(slot, factor) for slot, factor in landmarks if person[slot] > -1]
        if not present:
            continue

        x0, y0 = candidate[person[0]][:2]
        width = 0.0
        for slot, factor in present:
            x1, y1 = candidate[person[slot]][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * factor)

        # ``width`` is the half side: step from the head to the top-left.
        x = x0 - width
        y = y0 - width
        if x < 0:
            x = 0
        if y < 0:
            y = 0

        fit_w = image_width - x if x + width > image_width else width * 2
        fit_h = image_height - y if y + width > image_height else width * 2
        width = min(fit_w, fit_h)

        if width >= 20:
            detect_result.append([int(x), int(y), int(width)])

    return detect_result
# get max index of 2d array
def npmax(array):
    """Return the ``(row, column)`` position of the largest entry of a 2D array."""
    best_col_per_row = np.argmax(array, axis=1)
    best_val_per_row = np.max(array, axis=1)
    row = best_val_per_row.argmax()
    return row, best_col_per_row[row]
import os
import cv2
import numpy as np
import onnxruntime as ort
from .onnxdet import inference_detector
from .onnxpose import inference_pose
class Wholebody:
    """Person detector plus whole-body pose estimator backed by ONNX Runtime.

    Two sessions are created: a detector and a pose model. Calling the object
    returns keypoints with the body joints reordered and a neck joint added.
    """

    def __init__(self):
        # One device per process, selected through LOCAL_RANK.
        rank = int(os.getenv("LOCAL_RANK", "0"))
        # The former ``device == "cpu"`` test could never be true because the
        # device string was always "cuda:<rank>". CUDA stays the preferred
        # provider; CPU is listed after it as an explicit fallback.
        providers = [
            ("CUDAExecutionProvider", {"device_id": rank}),
            "CPUExecutionProvider",
        ]
        onnx_det = "hydit/annotator/ckpts/yolox_l.onnx"
        onnx_pose = "hydit/annotator/ckpts/dw-ll_ucoco_384.onnx"

        self.session_det = ort.InferenceSession(
            path_or_bytes=onnx_det, providers=providers
        )
        self.session_pose = ort.InferenceSession(
            path_or_bytes=onnx_pose, providers=providers
        )

    def __call__(self, oriImg):
        """Detect people in ``oriImg`` and estimate their keypoints.

        Returns:
            ``(keypoints, scores)``: keypoints with an ``(x, y)`` last axis
            and the matching scores, or ``(None, None)`` when the detector
            returns nothing.
        """
        det_result = inference_detector(self.session_det, oriImg)
        if det_result is None:
            return None, None

        keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
        keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)

        # The neck is the midpoint of joints 5 and 6.
        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
        # Its score is 1 only when both of those joints score above 0.3.
        neck[:, 2:4] = np.logical_and(
            keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3
        ).astype(int)
        new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)

        # Reorder the body joints from the model's order to OpenPose order.
        mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
        openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
        new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
        keypoints_info = new_keypoints_info

        keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
        return keypoints, scores

    def to(self, device):
        """Switch both sessions to the single provider ``device``; returns self."""
        self.session_det.set_providers([device])
        self.session_pose.set_providers([device])
        return self
# MIT License
# Copyright (c) 2023 AIGText
# https://github.com/AIGText/GlyphControl-release
from PIL import Image, ImageFont, ImageDraw
import random
import numpy as np
import cv2
# resize height to image_height first, then shrink or pad to image_width
def resize_and_pad_image(pil_image, image_size):
    """Fit ``pil_image`` to ``image_size`` on a transparent RGBA canvas.

    The image is first scaled to the target height. If it is then wider than
    the target it is squeezed to the exact size; if narrower it is centred
    horizontally on a transparent canvas.

    Args:
        pil_image: PIL image to fit.
        image_size: ``(width, height)`` pair or a single int for a square.

    Returns:
        The fitted PIL image.

    Raises:
        ValueError: If ``image_size`` is neither an int nor a 2-item
            list/tuple.
    """
    if isinstance(image_size, int):
        target_w = target_h = image_size
    elif isinstance(image_size, (tuple, list)) and len(image_size) == 2:
        target_w, target_h = image_size
    else:
        raise ValueError(
            f"Image size should be int or list/tuple of int not {image_size}"
        )

    # Halve with a box filter while at least twice as tall as needed.
    while pil_image.size[1] >= 2 * target_h:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Scale to the exact target height.
    factor = target_h / pil_image.size[1]
    scaled = tuple(round(side * factor) for side in pil_image.size)
    pil_image = pil_image.resize(scaled, resample=Image.BICUBIC)

    # Too wide: squeeze to the exact target size.
    if pil_image.size[0] > target_w:
        pil_image = pil_image.resize((target_w, target_h), resample=Image.BICUBIC)

    # Too narrow: centre on a transparent canvas.
    if pil_image.size[0] < target_w:
        padded = Image.new(
            mode="RGBA", size=(target_w, target_h), color=(255, 255, 255, 0)
        )
        current_w, _ = pil_image.size
        padded.paste(pil_image, ((target_w - current_w) // 2, 0))
        pil_image = padded

    return pil_image
def resize_and_pad_image2(pil_image, image_size):
    """Fit ``pil_image`` to ``image_size`` on a white RGB canvas.

    The image is first scaled to the target height. If it is then wider than
    the target it is squeezed to the exact size; if narrower it is centred
    horizontally on a white canvas.

    Args:
        pil_image: PIL image to fit.
        image_size: ``(width, height)`` pair or a single int for a square.

    Returns:
        The fitted PIL image.

    Raises:
        ValueError: If ``image_size`` is neither an int nor a 2-item
            list/tuple.
    """
    if isinstance(image_size, int):
        target_w = target_h = image_size
    elif isinstance(image_size, (tuple, list)) and len(image_size) == 2:
        target_w, target_h = image_size
    else:
        raise ValueError(
            f"Image size should be int or list/tuple of int not {image_size}"
        )

    # Halve with a box filter while at least twice as tall as needed.
    while pil_image.size[1] >= 2 * target_h:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Scale to the exact target height.
    factor = target_h / pil_image.size[1]
    scaled = tuple(round(side * factor) for side in pil_image.size)
    pil_image = pil_image.resize(scaled, resample=Image.BICUBIC)

    # Too wide: squeeze to the exact target size.
    if pil_image.size[0] > target_w:
        pil_image = pil_image.resize((target_w, target_h), resample=Image.BICUBIC)

    # Too narrow: centre on a white canvas.
    if pil_image.size[0] < target_w:
        padded = Image.new(mode="RGB", size=(target_w, target_h), color="white")
        current_w, _ = pil_image.size
        padded.paste(pil_image, ((target_w - current_w) // 2, 0))
        pil_image = padded

    return pil_image
def draw_visual_text(
    image_size, bboxes, rendered_txt_values, num_rows_values=None, align="center"
):
    # aligns = ["center", "left", "right"]
    """Render the given texts onto a white image at their box positions.

    Each text is drawn with the "simfang.ttf" font, resized to its box,
    rotated by the box yaw and pasted onto the background. A 288x48 white
    RGB rendering of every non-empty text is collected as well.

    Args:
        image_size: Size passed to ``Image.new`` for the background; index 0
            scales ``top_left_x`` and index 1 scales ``top_left_y`` and
            ``width``.
        bboxes: One mapping per text with the keys ``width``, ``top_left_x``,
            ``top_left_y`` and ``yaw`` and an optional ``ratio``
            (width / height). A missing or near-zero ratio is measured from
            the rendered text instead.
        rendered_txt_values: The texts; empty strings are skipped.
        num_rows_values: Number of lines per text; defaults to 1 for all.
        align: Alignment passed to ``ImageDraw.text``.

    Returns:
        ``(background, text_list)``: the composed PIL image and the list of
        per-text arrays.
    """
    background = Image.new("RGB", image_size, "white")
    font = ImageFont.truetype("simfang.ttf", encoding="utf-8", size=512)
    if num_rows_values is None:
        num_rows_values = [1] * len(rendered_txt_values)
    text_list = []
    for text, bbox, num_rows in zip(rendered_txt_values, bboxes, num_rows_values):
        if len(text) == 0:
            continue
        # NOTE(review): a whitespace-only text passes the check above and is
        # empty after strip(); confirm the measuring code below copes with it.
        text = text.strip()
        if num_rows != 1:
            # Split the words into lines at randomly chosen break positions.
            word_tokens = text.split()
            num_tokens = len(word_tokens)
            index_list = range(1, num_tokens + 1)
            if num_tokens > num_rows:
                index_list = random.sample(index_list, num_rows)
                index_list.sort()
            # NOTE(review): when break positions are sampled, the words after
            # the last sampled position are left out of the text — confirm
            # this is intended.
            line_list = []
            start_idx = 0
            for index in index_list:
                line_list.append(" ".join(word_tokens[start_idx:index]))
                start_idx = index
            text = "\n".join(line_list)
        if "ratio" not in bbox or bbox["ratio"] == 0 or bbox["ratio"] < 1e-4:
            # No usable ratio supplied: measure the rendered text instead.
            image4ratio = Image.new("RGB", (512, 512), "white")
            draw = ImageDraw.Draw(image4ratio)
            _, _, w, h = draw.textbbox(xy=(0, 0), text=text, font=font)
            ratio = w / h
        else:
            ratio = bbox["ratio"]
        # NOTE(review): the width is scaled by image_size[1], the same index
        # used for top_left_y — confirm this is intended for non-square sizes.
        width = int(bbox["width"] * image_size[1])
        height = int(width / ratio)
        top_left_x = int(bbox["top_left_x"] * image_size[0])
        top_left_y = int(bbox["top_left_y"] * image_size[1])
        yaw = bbox["yaw"]
        # Measure on a scratch image, then draw on a transparent canvas of
        # exactly that size.
        text_image = Image.new("RGB", (512, 512), "white")
        draw = ImageDraw.Draw(text_image)
        x, y, w, h = draw.textbbox(xy=(0, 0), text=text, font=font)
        text_image = Image.new("RGBA", (w, h), (255, 255, 255, 0))
        draw = ImageDraw.Draw(text_image)
        draw.text((-x / 2, -y / 2), text, (0, 0, 0, 255), font=font, align=align)
        # Fixed-size 288x48 copy collected for the second return value.
        text_image_ = resize_and_pad_image2(text_image.convert("RGB"), (288, 48))
        text_list.append(np.array(text_image_))
        text_image = resize_and_pad_image(text_image, (width, height))
        text_image = text_image.rotate(
            angle=-yaw, expand=True, fillcolor=(255, 255, 255, 0)
        )
        # The RGBA image doubles as its own paste mask.
        background.paste(text_image, (top_left_x, top_left_y), mask=text_image)
    return background, text_list
# [{'width': 0.1601562201976776, 'ratio': 81.99999451637203, 'yaw': 0.0, 'top_left_x': 0.712890625, 'top_left_y': 0.0},
# {'width': 0.134765625, 'ratio': 34.5, 'yaw': 0.0, 'top_left_x': 0.4453125, 'top_left_y': 0.0},
def insert_spaces(string, nSpace):
    """Return ``string`` with ``nSpace`` spaces between adjacent characters.

    No padding is added before the first or after the last character, so an
    empty or single-character string comes back unchanged.

    Args:
        string: Text whose characters should be spread apart.
        nSpace: Number of spaces placed between neighbouring characters.
            Zero or a negative value leaves the text untouched.

    Returns:
        The spaced-out text.
    """
    if nSpace <= 0:
        # Nothing to insert. Guarding negatives as well avoids the slice
        # ``[:-nSpace]`` silently truncating the text.
        return string
    # Joining on the separator builds the result in a single pass instead of
    # repeated string concatenation followed by trimming the trailing gap.
    return (" " * nSpace).join(string)
def draw_glyph(text, font="simfang.ttf"):
    """Render ``text`` centred on a fixed 512x80 one-bit canvas.

    The font size is chosen so that the text's bounding box fills at most
    90% of the canvas width and 90% of its height.

    Args:
        text: String to draw.
        font: Either a path to a TrueType font file or an already loaded
            font object that provides ``font_variant`` and ``getbbox``.

    Returns:
        A ``float64`` array of shape ``(80, 512, 1)`` built from the 1-bit
        image (background 0, text drawn with fill "white").
    """
    if isinstance(font, str):
        font = ImageFont.truetype(font, encoding="utf-8", size=512)
    g_size = 50
    W, H = (512, 80)

    # First pass: measure the text at a reference size to derive the scale.
    new_font = font.font_variant(size=g_size)
    img = Image.new(mode="1", size=(W, H), color=0)
    draw = ImageDraw.Draw(img)
    left, top, right, bottom = new_font.getbbox(text)
    # Clamp to 5 px so empty/degenerate text cannot cause a division by zero.
    text_width = max(right - left, 5)
    text_height = max(bottom - top, 5)
    ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)
    new_font = font.font_variant(size=int(g_size * ratio))

    # Second pass: measure at the final size. ``getsize``/``getoffset`` were
    # removed in Pillow 10, so the same quantities are derived from
    # ``getbbox``: the old width was ``right - left``, the old height
    # included the vertical offset (== ``bottom``), and the old vertical
    # offset is ``top``.
    left, top, right, bottom = new_font.getbbox(text)
    text_width = right - left
    text_height = bottom
    offset_y = top

    x = (img.width - text_width) // 2
    y = (img.height - text_height) // 2 - offset_y // 2
    draw.text((x, y), text, font=new_font, fill="white")
    img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
    return img
def draw_glyph2(
    text,
    polygon,
    font="simfang.ttf",
    vertAng=10,
    scale=1,
    width=1024,
    height=1024,
    add_space=True,
):
    """Render ``text`` fitted to the minimum-area rectangle of ``polygon``.

    The text is drawn horizontally on a transparent layer centred on the
    rectangle and then rotated by the rectangle's angle, or drawn character
    by character top-to-bottom when the rectangle is close to axis-aligned
    and taller than wide.

    Args:
        text: String to draw.
        polygon: Point set accepted by ``cv2.minAreaRect`` after being
            multiplied by ``scale`` (presumably an ``(N, 2)`` array; confirm
            against the caller).
        font: Path to a TrueType font file or an already loaded font object.
        vertAng: Tolerance in degrees around multiples of 90 within which
            the rectangle is treated as axis-aligned for the vertical check.
        scale: Factor applied to the polygon and to the canvas size.
        width: Unscaled canvas width.
        height: Unscaled canvas height.
        add_space: When True, horizontal text shorter than the rectangle is
            padded with spaces between characters to fill its long side.

    Returns:
        A ``float64`` array of shape ``(height * scale, width * scale, 1)``
        obtained from the 1-bit version of the rendered image.
    """
    if isinstance(font, str):
        font = ImageFont.truetype(font, encoding="utf-8", size=60)
    enlarge_polygon = polygon * scale
    rect = cv2.minAreaRect(enlarge_polygon)
    box = cv2.boxPoints(rect)
    # ``np.int0`` was an alias of ``np.intp`` and was removed in NumPy 2.0.
    box = np.intp(box)
    w, h = rect[1]
    angle = rect[2]
    if angle < -45:
        angle += 90
    angle = -angle
    if w < h:
        angle += 90

    # Decide whether the text should be laid out vertically: only when the
    # box is (nearly) axis-aligned and at least as tall as it is wide.
    vert = False
    if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
        _w = max(box[:, 0]) - min(box[:, 0])
        _h = max(box[:, 1]) - min(box[:, 1])
        if _h >= _w:
            vert = True
            angle = 0

    img = np.zeros((height * scale, width * scale, 3), np.uint8)
    img = Image.fromarray(img)

    # infer font size
    image4ratio = Image.new("RGB", img.size, "white")
    draw = ImageDraw.Draw(image4ratio)
    _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
    # Width the text would have if its height matched the short side.
    text_w = min(w, h) * (_tw / _th)
    if text_w <= max(w, h):
        # add space
        if len(text) > 1 and not vert and add_space:
            # Grow the inter-character gap until the text would overflow the
            # long side, then keep the last gap that still fitted.
            for i in range(1, 100):
                text_space = insert_spaces(text, i)
                _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
                if min(w, h) * (_tw2 / _th2) > max(w, h):
                    break
            text = insert_spaces(text, i - 1)
        font_size = min(w, h) * 0.80
    else:
        # shrink = 0.75 if vert else 0.85
        shrink = 1.0
        font_size = min(w, h) / (text_w / max(w, h)) * shrink
    new_font = font.font_variant(size=int(font_size))

    left, top, right, bottom = new_font.getbbox(text)
    text_width = right - left
    text_height = bottom - top

    layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(layer)
    if not vert:
        # Centre the text's bounding box on the rectangle centre.
        draw.text(
            (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
            text,
            font=new_font,
            fill=(255, 255, 255, 255),
        )
    else:
        # Stack the characters downwards from the top edge of the box.
        x_s = min(box[:, 0]) + _w // 2 - text_height // 2
        y_s = min(box[:, 1])
        for c in text:
            draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
            _, _t, _, _b = new_font.getbbox(c)
            y_s += _b

    rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))

    # ``expand=1`` may enlarge the layer; re-centre it on the canvas.
    x_offset = int((img.width - rotated_layer.width) / 2)
    y_offset = int((img.height - rotated_layer.height) / 2)
    img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
    img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
    return img
import random
import numpy as np
import cv2
import os
# Path of the "ckpts" directory located next to this module file.
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), "ckpts")
def HWC3(x):
    """Return a uint8 image as a 3-channel ``(H, W, 3)`` array.

    * 2-D input is treated as a single-channel image.
    * 3-channel input is returned as-is (same object).
    * 1-channel input is replicated across three channels.
    * 4-channel input is alpha-composited over a white background.

    Args:
        x: ``uint8`` array with 2 dimensions, or 3 dimensions whose last
            axis has 1, 3 or 4 entries.

    Returns:
        A ``uint8`` array whose last axis has length 3.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    _, _, channels = x.shape
    assert channels in (1, 3, 4)

    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate((x, x, x), axis=2)

    # channels == 4: blend the colour planes against white using the alpha.
    rgb = x[:, :, 0:3].astype(np.float32)
    opacity = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * opacity + 255.0 * (1.0 - opacity)
    return blended.clip(0, 255).astype(np.uint8)
def resize_image(input_image, resolution):
    """Resize an image so its short side is about ``resolution`` pixels.

    Both output dimensions are rounded to the nearest multiple of 64.
    Upscaling uses Lanczos interpolation, downscaling (or no scaling) uses
    area interpolation.

    Args:
        input_image: Array with exactly three dimensions ``(H, W, C)``.
        resolution: Target length of the shorter side before rounding.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    src_h, src_w, _ = input_image.shape
    src_h = float(src_h)
    src_w = float(src_w)
    factor = float(resolution) / min(src_h, src_w)

    # Snap each scaled dimension to the nearest multiple of 64.
    dst_h = int(np.round(src_h * factor / 64.0)) * 64
    dst_w = int(np.round(src_w * factor / 64.0)) * 64

    if factor > 1:
        method = cv2.INTER_LANCZOS4
    else:
        method = cv2.INTER_AREA
    return cv2.resize(input_image, (dst_w, dst_h), interpolation=method)
def nms(x, t, s):
    """Directional non-maximum suppression on a blurred response map.

    The input is Gaussian-blurred with sigma ``s``; a pixel is kept when it
    equals the dilation of the blurred map along at least one of four line
    directions (horizontal, vertical, both diagonals). Kept pixels whose
    blurred value exceeds ``t`` are set to 255 in the output.

    Args:
        x: Input array; converted to ``float32`` before blurring.
        t: Threshold applied to the surviving blurred values.
        s: Sigma passed to ``cv2.GaussianBlur``.

    Returns:
        A ``uint8`` array with 255 at kept pixels and 0 elsewhere.
    """
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    # 3x3 line-shaped structuring elements, one per direction.
    line_patterns = (
        [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
        [[0, 1, 0], [0, 1, 0], [0, 1, 0]],
        [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
        [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
    )
    kernels = [np.array(pattern, dtype=np.uint8) for pattern in line_patterns]

    peaks = np.zeros_like(blurred)
    for kernel in kernels:
        is_directional_max = cv2.dilate(blurred, kernel=kernel) == blurred
        np.putmask(peaks, is_directional_max, blurred)

    result = np.zeros_like(peaks, dtype=np.uint8)
    result[peaks > t] = 255
    return result
def make_noise_disk(H, W, C, F):
    """Create smooth random noise of size ``H x W x C`` scaled to start at 0.

    Uniform noise is sampled on a coarse grid (cell size ``F``), upsampled
    with cubic interpolation to a canvas padded by ``F`` on each side,
    cropped back to ``H x W``, shifted so its minimum is 0 and divided by
    its maximum.

    Args:
        H: Output height.
        W: Output width.
        C: Number of channels.
        F: Coarse grid cell size in pixels.

    Returns:
        A float array of shape ``(H, W, C)``.
    """
    coarse_shape = ((H // F) + 2, (W // F) + 2, C)
    noise = np.random.uniform(low=0, high=1, size=coarse_shape)

    padded_size = (W + 2 * F, H + 2 * F)
    noise = cv2.resize(noise, padded_size, interpolation=cv2.INTER_CUBIC)

    # Drop the F-pixel border on every side.
    noise = noise[F : F + H, F : F + W]
    noise -= np.min(noise)
    noise /= np.max(noise)
    if C == 1:
        # Re-add the channel axis for single-channel noise.
        noise = noise[:, :, None]
    return noise
def min_max_norm(x):
    """Rescale ``x`` in place so its minimum is 0 and its maximum at most 1.

    The divisor is floored at ``1e-5`` so a constant input does not divide
    by zero.

    Args:
        x: Array supporting in-place subtraction and division.

    Returns:
        The same object ``x`` after in-place normalisation.
    """
    lowest = np.min(x)
    x -= lowest
    span = np.maximum(np.max(x), 1e-5)
    x /= span
    return x
def safe_step(x, step=2):
    """Quantise ``x`` into discrete levels spaced ``1 / step`` apart.

    Values are multiplied by ``step + 1``, truncated to integers and then
    divided by ``step``.

    Args:
        x: Input array.
        step: Controls the number and spacing of the output levels.

    Returns:
        A ``float32`` array of quantised values.
    """
    scaled = x.astype(np.float32) * float(step + 1)
    truncated = scaled.astype(np.int32)
    return truncated.astype(np.float32) / float(step)
def img2mask(img, H, W, low=10, high=90):
    """Derive a random boolean mask of size ``H x W`` from a uint8 image.

    One channel is picked at random (for 3-D input), resized with cubic
    interpolation, inverted with probability 0.5 and thresholded at a
    randomly chosen percentile in ``[low, high)``.

    Args:
        img: ``uint8`` array with 2 or 3 dimensions.
        H: Mask height.
        W: Mask width.
        low: Lower bound (inclusive) of the percentile range.
        high: Upper bound (exclusive) of the percentile range.

    Returns:
        A boolean array that is True where the plane lies below the chosen
        percentile.
    """
    assert img.ndim in (2, 3)
    assert img.dtype == np.uint8

    if img.ndim == 3:
        channel = random.randrange(0, img.shape[2])
        plane = img[:, :, channel]
    else:
        plane = img

    plane = cv2.resize(plane, (W, H), interpolation=cv2.INTER_CUBIC)

    invert = random.uniform(0, 1) < 0.5
    if invert:
        plane = 255 - plane

    percentile = random.randrange(low, high)
    cutoff = np.percentile(plane, percentile)
    return plane < cutoff
import argparse
import deepspeed
from .constants import *
from .diffusion.gaussian_diffusion import ModelVarType
from .modules.models import HUNYUAN_DIT_CONFIG
def model_var_type(value):
    """Look up a ``ModelVarType`` member by name.

    Args:
        value: Name of the enum member.

    Returns:
        The matching ``ModelVarType`` member.

    Raises:
        ValueError: If ``value`` is not the name of any member; the message
            lists the valid names.
    """
    try:
        member = ModelVarType[value]
    except KeyError:
        valid_choices = [candidate.name for candidate in ModelVarType]
        raise ValueError(f"Invalid choice '{value}', valid choices are {valid_choices}")
    return member
def get_args(default_args=None):
    """Build the command-line parser and parse the arguments.

    Covers general, model, LoRA, ControlNet, IP-Adapter, diffusion,
    inference, training, DeepSpeed and Gradio options.

    Boolean-valued options (``--is-ipa``, ``--resume-ipa``,
    ``--gradio_share``) are parsed from text ("true"/"false", "1"/"0", ...)
    rather than with ``type=bool``, which would turn any non-empty string,
    including "False", into ``True``. They may also be given without a
    value, which means ``True``.

    Args:
        default_args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """

    def _str2bool(value):
        # Convert a textual flag value into a real boolean for argparse.
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got '{value}'")

    parser = argparse.ArgumentParser()
    parser.add_argument("--task-flag", type=str)

    # General Setting
    parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size")
    parser.add_argument("--seed", type=int, default=42, help="A seed for all the prompts.")
    parser.add_argument("--use-fp16", action="store_true", help="Use FP16 precision.")
    parser.add_argument("--no-fp16", dest="use_fp16", action="store_false")
    parser.set_defaults(use_fp16=True)
    parser.add_argument("--extra-fp16", action="store_true", help="Use extra fp16 for vae and text_encoder.")

    # HunYuan-DiT
    parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default="DiT-g/2")
    parser.add_argument("--image-size", type=int, nargs="+", default=[1024, 1024], help="Image size (h, w). If a single value is provided, the image will be treated to (value, value).")
    parser.add_argument("--qk-norm", action="store_true", help="Query Key normalization. See http://arxiv.org/abs/2302.05442 for details.")
    # The default is True, so without an explicit negative flag the option
    # could never be switched off from the command line.
    parser.add_argument("--no-qk-norm", dest="qk_norm", action="store_false", help="Disable Query Key normalization.")
    parser.set_defaults(qk_norm=True)
    parser.add_argument("--norm", type=str, choices=["rms", "layer"], default="layer", help="Normalization layer type")
    parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.")
    parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.")
    parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of T5 text encoder.")
    parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.")

    # LoRA config
    parser.add_argument("--training-parts", type=str, default="all", choices=["all", "lora", "ipadapter"], help="Training parts")
    parser.add_argument("--rank", type=int, default=64, help="Rank of LoRA")
    parser.add_argument("--lora-ckpt", type=str, default=None, help="LoRA checkpoint")
    parser.add_argument("--target-modules", type=str, nargs="+", default=["Wqkv", "q_proj", "kv_proj", "out_proj"], help="Target modules for LoRA fine tune")
    parser.add_argument("--output-merge-path", type=str, default=None, help="Output path for merged model")

    # controlnet config
    parser.add_argument("--control-type", type=str, default="canny", choices=["canny", "depth", "pose"], help="Controlnet condition type")
    parser.add_argument("--control-weight", type=str, default="1.0", help="Controlnet weight, You can use a float to specify the weight for all layers, or use a list to separately specify the weight for each layer, for example, '[1.0 * (0.825 ** float(19 - i)) for i in range(19)]'")
    parser.add_argument("--condition-image-path", type=str, default=None, help="Inference condition image path")

    # IP-Adapter config
    parser.add_argument("--is-ipa", type=_str2bool, nargs="?", const=True, default=False, help="inference with IP-Adapter or not")
    parser.add_argument("--resume-ipa", type=_str2bool, nargs="?", const=True, default=False, help="train with resume IP-Adapter model or not")
    parser.add_argument("--resume-ipa-root", type=str, default=None, help="ipa model path")
    parser.add_argument("--ref-image-path", type=str, default=None, help="Inference ref image path")
    parser.add_argument("--i-scale", type=float, default=1.0, help="IP-Adapter weight")

    # Diffusion
    parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
    parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false")
    parser.set_defaults(learn_sigma=True)
    parser.add_argument("--predict-type", type=str, choices=list(PREDICT_TYPE), default="v_prediction", help="Diffusion predict type")
    parser.add_argument("--noise-schedule", type=str, choices=list(NOISE_SCHEDULES), default="scaled_linear", help="Noise schedule")
    parser.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value")
    parser.add_argument("--beta-end", type=float, default=0.02, help="Beta end value")
    parser.add_argument("--sigma-small", action="store_true")
    parser.add_argument("--mse-loss-weight-type", type=str, default="constant", help="Min-SNR-gamma. Can be constant or min_snr_<gamma> where gamma is a integer. 5 is recommended in the paper.")
    parser.add_argument("--model-var-type", type=model_var_type, default=None, help="Specify the model variable type.")
    parser.add_argument("--noise-offset", type=float, default=0.0, help="Add extra noise to the input image.")

    # ========================================================================================================
    # Inference
    # ========================================================================================================

    # Basic Setting
    parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.")
    parser.add_argument("--model-root", type=str, default="ckpts", help="Root path of all the models, including t2i model and dialoggen model.")
    parser.add_argument("--dit-weight", type=str, default=None, help="Path to the HunYuan-DiT model. If None, search the model in the args.model_root.")
    parser.add_argument("--controlnet-weight", type=str, default=None, help="Path to the HunYuan-DiT controlnet model. If None, search the model in the args.model_root.")

    # Model setting
    parser.add_argument("--load-key", type=str, choices=["ema", "module", "distill", "merge"], default="ema", help="Load model key for HunYuanDiT checkpoint.")
    parser.add_argument("--use-style-cond", action="store_true", help="Use style condition in hydit. Only for hydit version <= 1.1")
    parser.add_argument("--size-cond", type=int, nargs="+", default=None, help="Size condition used in sampling. 2 values are required for height and width. If a single value is provided, the image will be treated to (value, value). Recommended values are [1024, 1024]. Only for hydit version <= 1.1")
    parser.add_argument("--target-ratios", type=str, nargs="+", default=None, help="Target ratios for multi-resolution training.")
    parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.")
    parser.add_argument("--negative", type=str, default=None, help="Negative prompt.")

    # Acceleration
    parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="fa", help="Inference mode")
    parser.add_argument("--onnx-workdir", type=str, default="onnx_model", help="Path to save ONNX model")

    # Sampling
    parser.add_argument("--sampler", type=str, choices=SAMPLER_FACTORY, default="ddpm", help="Diffusion sampler")
    parser.add_argument("--infer-steps", type=int, default=100, help="Inference steps")

    # Prompt enhancement
    parser.add_argument("--enhance", action="store_true", help="Enhance prompt with mllm.")
    parser.add_argument("--no-enhance", dest="enhance", action="store_false")
    parser.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true")
    parser.set_defaults(enhance=True)

    # App
    parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language")

    # ========================================================================================================
    # Training
    # ========================================================================================================

    # Basic Setting
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--max-training-steps", type=int, default=10_000_000)
    parser.add_argument("--gc-interval", type=int, default=40, help="To address the memory bottleneck encountered during the preprocessing of the dataset, memory fragments are reclaimed here by invoking the gc.collect() function.")
    parser.add_argument("--log-every", type=int, default=100)
    parser.add_argument("--ckpt-every", type=int, default=100_000, help="Create a ckpt every a few steps.")
    parser.add_argument("--ckpt-latest-every", type=int, default=10_000, help="Create a ckpt named `latest.pt` every a few steps.")
    parser.add_argument("--ckpt-every-n-epoch", type=int, default=0, help="Create a ckpt every a few epochs. If 0, do not create ckpt based on epoch. Default is 0.")
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--global-seed", type=int, default=1234)
    parser.add_argument("--warmup-min-lr", type=float, default=1e-6)
    parser.add_argument("--warmup-num-steps", type=float, default=0)
    parser.add_argument("--weight-decay", type=float, default=0, help="weight-decay in optimizer")
    parser.add_argument("--rope-img", type=str, default=None, choices=["extend", "base512", "base1024"], help="Extend or interpolate the positional embedding of the image.")
    parser.add_argument("--rope-real", action="store_true", help="Use real part and imaginary part separately for RoPE.")

    # Classifier-free
    parser.add_argument("--uncond-p", type=float, default=0.2, help="The probability of dropping training text used for CLIP feature extraction")
    parser.add_argument("--uncond-p-t5", type=float, default=0.2, help="The probability of dropping training text used for mT5 feature extraction")
    parser.add_argument("--uncond-p-img", type=float, default=0.05, help="The probability of dropping the reference image used for IP-Adapter image feature extraction")

    # Directory
    parser.add_argument("--results-dir", type=str, default="results")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--resume-module-root", type=str, default=None, help="Resume model states.")
    parser.add_argument("--resume-ema-root", type=str, default=None, help="Resume ema states.")
    parser.add_argument("--no-strict", dest="strict", action="store_false", help="Strict loading of checkpoint")
    parser.set_defaults(strict=True)

    # Dataset
    parser.add_argument("--index-file", type=str, nargs="+", help="During training, provide a JSON file with data indices.")
    parser.add_argument("--random-flip", action="store_true", help="Random flip image")
    parser.add_argument("--reset-loader", action="store_true", help="Reset the data loader. It is useful when resuming from a checkpoint but switch to a new dataset.")
    parser.add_argument("--multireso", action="store_true", help="Use multi-resolution training.")
    parser.add_argument("--reso-step", type=int, default=None, help="Step size for multi-resolution training.")

    # Additional condition
    parser.add_argument("--random-shrink-size-cond", action="store_true", help="Randomly shrink the original size condition.")
    parser.add_argument("--merge-src-cond", action="store_true", help="Merge the source condition into a single value.")

    # EMA Model
    parser.add_argument("--use-ema", action="store_true", help="Use EMA model")
    parser.add_argument("--ema-dtype", type=str, choices=["fp16", "fp32", "none"], default="none", help="EMA data type. If none, use the same data type as the model.")
    parser.add_argument("--ema-decay", type=float, default=None, help="EMA decay rate. If None, use the default value of the model.")
    parser.add_argument("--ema-warmup", action="store_true", help="EMA warmup. If True, perform ema_decay warmup from 0 to ema_decay.")
    parser.add_argument("--ema-warmup-power", type=float, default=None, help="EMA power. If None, use the default value of the model.")
    parser.add_argument("--ema-reset-decay", action="store_true", help="Reset EMA decay to 0 and restart increasing the EMA decay. Only works when --ema-warmup is enabled.")

    # Acceleration
    parser.add_argument("--use-flash-attn", action="store_true", help="During training, flash attention is used to accelerate training.")
    parser.add_argument("--no-flash-attn", dest="use_flash_attn", action="store_false", help="During training, flash attention is not used to accelerate training.")
    parser.add_argument("--use-zero-stage", type=int, default=1, help="Use AngelPTM zero stage. Support 2 and 3")
    parser.add_argument("--grad-accu-steps", type=int, default=1, help="Gradient accumulation steps.")
    parser.add_argument("--gradient-checkpointing", action="store_true", help="Use gradient checkpointing.")
    parser.add_argument("--gc-rate", default=1.0, type=float, help="set the rate of blocks with gradient checkpointing.")
    parser.add_argument("--cpu-offloading", action="store_true", help="Use cpu offloading for parameters and optimizer states.")
    parser.add_argument("--save-optimizer-state", action="store_true", help="Save optimizer state in the checkpoint.")

    # ========================================================================================================
    # Deepspeed config
    # ========================================================================================================
    parser = deepspeed.add_config_arguments(parser)
    parser.add_argument("--local_rank", type=int, default=None, help="local rank passed from distributed launcher.")
    parser.add_argument("--deepspeed-optimizer", action="store_true", help="Switching to the optimizers in DeepSpeed")
    parser.add_argument("--remote-device", type=str, default="none", choices=["none", "cpu", "nvme"], help="Remote device for ZeRO-3 initialized parameters.")
    parser.add_argument("--zero-stage", type=int, default=1)

    # ========================================================================================================
    # Gradio App config
    # ========================================================================================================
    parser.add_argument("--server_name", type=str, default="0.0.0.0")
    parser.add_argument("--server_port", type=int, default=443)
    parser.add_argument("--gradio_share", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--model_path", type=str)
    parser.add_argument("--output_img_path", type=str, default="app/output")

    args = parser.parse_args(default_args)
    return args
import torch
# =======================================================
# Names accepted as noise (beta) schedules.
NOISE_SCHEDULES = {
    "linear",
    "scaled_linear",
    "squaredcos_cap_v2",
}
# Names accepted as diffusion prediction targets.
PREDICT_TYPE = {
    "epsilon",
    "sample",
    "v_prediction",
}
# =======================================================
# Default negative prompt (Chinese, comma-separated phrases).
NEGATIVE_PROMPT = "错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,"
# =======================================================
# Upper bounds for the TensorRT path (presumably the limits the engine was
# built for; confirm against the TRT export code).
TRT_MAX_BATCH_SIZE = 1
TRT_MAX_WIDTH = 1280
TRT_MAX_HEIGHT = 1280
# =======================================================
# Constants about models
# =======================================================
# Checkpoint locations, relative to the working directory.
VAE_EMA_PATH = "ckpts/t2i/sdxl-vae-fp16-fix"
TOKENIZER = "ckpts/t2i/tokenizer"
TEXT_ENCODER = "ckpts/t2i/clip_text_encoder"
# Settings for the T5 text encoder: checkpoint path under "MT5" plus option
# flags whose exact meaning is defined by the consumer of this dict.
T5_ENCODER = {
    "MT5": "ckpts/t2i/mt5",
    "attention_mask": True,
    "layer_index": -1,
    "attention_pool": True,
    "torch_dtype": torch.float16,
    "learnable_replace": True,
}
# Beta-schedule / prediction settings shared by every sampler below.
_SAMPLER_SCHEDULE_KWARGS = {
    "beta_schedule": "scaled_linear",
    "beta_start": 0.00085,
    "beta_end": 0.02,
    "prediction_type": "v_prediction",
}
# Extra settings used by the DDPM and DDIM entries only.
_SAMPLER_CLIP_KWARGS = {
    "steps_offset": 1,
    "clip_sample": False,
    "clip_sample_range": 1.0,
}

# Sampler key -> scheduler class name, display name and constructor kwargs.
# Every entry gets its own kwargs dict (the fragments above are copied via
# unpacking), and the key order matches the hand-written literal.
SAMPLER_FACTORY = {
    "ddpm": {
        "scheduler": "DDPMScheduler",
        "name": "DDPM",
        "kwargs": {**_SAMPLER_CLIP_KWARGS, **_SAMPLER_SCHEDULE_KWARGS},
    },
    "ddim": {
        "scheduler": "DDIMScheduler",
        "name": "DDIM",
        "kwargs": {**_SAMPLER_CLIP_KWARGS, **_SAMPLER_SCHEDULE_KWARGS},
    },
    "dpmms": {
        "scheduler": "DPMSolverMultistepScheduler",
        "name": "DPMMS",
        "kwargs": {
            **_SAMPLER_SCHEDULE_KWARGS,
            "trained_betas": None,
            "solver_order": 2,
            "algorithm_type": "dpmsolver++",
        },
    },
}
import pickle
import random
from pathlib import Path
import ast
import numpy as np
import re
import json
import time
from functools import partial
from PIL import Image
import torch
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.transforms import functional as TF
from torch.utils.data import Dataset
from IndexKits.index_kits import (
ArrowIndexV2,
MultiResolutionBucketIndexV2,
MultiIndexV2,
)
class TextImageArrowStream(Dataset):
    """Dataset producing tokenised text/image samples from Arrow indices.

    Images and captions are read through an index manager (``ArrowIndexV2``,
    ``MultiIndexV2`` or ``MultiResolutionBucketIndexV2``). Each item consists
    of the normalised image tensor, token ids and attention mask from
    ``tokenizer``, token ids and attention mask from ``tokenizer_t5``, in
    IP-Adapter mode an additional 224x224 reference-image tensor, and a dict
    of extra conditions (``image_meta_size`` and ``style``).
    """

    def __init__(
        self,
        args,
        resolution=512,
        random_flip=None,
        enable_CN=True,
        log_fn=print,
        index_file=None,
        multireso=False,
        batch_size=-1,
        world_size=1,
        random_shrink_size_cond=False,
        merge_src_cond=False,
        uncond_p=0.0,
        uncond_p_img=0.0,
        text_ctx_len=77,
        tokenizer=None,
        uncond_p_t5=0.0,
        text_ctx_len_t5=256,
        tokenizer_t5=None,
    ):
        """Store the configuration, load the index and build the transforms.

        Args:
            args: Parsed options; only ``args.is_ipa`` is read here.
            resolution: Square target size used when ``multireso`` is False.
            random_flip: Enables random horizontal flipping when truthy.
            enable_CN: Read captions from ``text_zh`` when True, otherwise
                from ``text_en``.
            log_fn: Callable receiving log messages.
            index_file: Index file path, or a list/tuple of paths.
            multireso: Use the multi-resolution bucket index.
            batch_size: Forwarded to the multi-resolution index.
            world_size: Forwarded to the multi-resolution index.
            random_shrink_size_cond: Raise each side of the reported
                original size to at least 1024.
            merge_src_cond: Replace the reported original size by
                ``(sqrt(W * H), sqrt(W * H))``.
            uncond_p: Probability of replacing the caption fed to
                ``tokenizer`` with an empty string.
            uncond_p_img: Probability of zeroing the reference-image tensor
                (IP-Adapter mode only).
            text_ctx_len: Padded token length for ``tokenizer``.
            tokenizer: Tokenizer callable for the first text branch.
            uncond_p_t5: Probability of replacing the caption fed to
                ``tokenizer_t5`` with an empty string.
            text_ctx_len_t5: Padded token length for ``tokenizer_t5``.
            tokenizer_t5: Tokenizer callable for the T5 text branch.
        """
        self.args = args
        self.resolution = resolution
        self.log_fn = lambda x: log_fn(f" {Path(__file__).stem} | " + x)
        self.random_flip = random_flip

        # If true, the Chinese prompt from the `text_zh` column will be taken from the arrow file;
        # otherwise, the English prompt from the `text_en` column will be taken,
        # provided that `text_zh` or `text_en` exists in the arrow file.
        self.enable_CN = enable_CN
        self.index_file = index_file
        self.multireso = multireso
        self.batch_size = batch_size
        self.world_size = world_size
        self.index_manager = self.load_index()

        # clip params
        self.uncond_p = uncond_p
        self.text_ctx_len = text_ctx_len
        self.tokenizer = tokenizer
        self.uncond_p_img = uncond_p_img

        # t5 params
        self.uncond_p_t5 = uncond_p_t5
        self.text_ctx_len_t5 = text_ctx_len_t5
        self.tokenizer_t5 = tokenizer_t5

        # size condition
        self.random_shrink_size_cond = random_shrink_size_cond
        self.merge_src_cond = merge_src_cond
        self.is_ipa = args.is_ipa

        assert isinstance(
            resolution, int
        ), f"resolution must be an integer, got {resolution}"

        # Training image pipeline: optional flip, to tensor, scale to [-1, 1].
        self.flip_norm = T.Compose(
            [
                T.RandomHorizontalFlip() if self.random_flip else T.Lambda(lambda x: x),
                T.ToTensor(),
                T.Normalize([0.5], [0.5]),
            ]
        )
        # Reference-image pipeline used in IP-Adapter mode.
        self.ti2i_transform = T.Compose(
            [
                T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
                lambda x: x.convert("RGB"),
                T.ToTensor(),
                T.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )

        # show info
        if self.merge_src_cond:
            self.log_fn(
                "Enable merging src condition: (oriW, oriH) --> ((WH)**0.5, (WH)**0.5)"
            )
        self.log_fn(
            "Enable image_meta_size condition (original_size, target_size, crop_coords)"
        )
        self.log_fn(f"Image_transforms: {self.flip_norm}")

    def load_index(self):
        """Create the index manager matching the configured index file(s).

        Returns:
            A ``MultiResolutionBucketIndexV2`` when ``multireso`` is set,
            an ``ArrowIndexV2`` for a single index file, or a
            ``MultiIndexV2`` for several index files.

        Raises:
            ValueError: If ``multireso`` is set and more than one index file
                was given.
        """
        multireso = self.multireso
        index_file = self.index_file
        batch_size = self.batch_size
        world_size = self.world_size

        if multireso:
            if isinstance(index_file, (list, tuple)):
                if len(index_file) > 1:
                    raise ValueError(
                        f"When enabling multireso, index_file should be a single file, but got {index_file}"
                    )
                index_file = index_file[0]
            index_manager = MultiResolutionBucketIndexV2(
                index_file, batch_size, world_size
            )
            self.log_fn(f"Using MultiResolutionBucketIndexV2: {len(index_manager):,}")
        else:
            if isinstance(index_file, str):
                index_file = [index_file]
            if len(index_file) == 1:
                index_manager = ArrowIndexV2(index_file[0])
                self.log_fn(f"Using ArrowIndexV2: {len(index_manager):,}")
            else:
                index_manager = MultiIndexV2(index_file)
                self.log_fn(f"Using MultiIndexV2: {len(index_manager):,}")
        return index_manager

    def shuffle(self, seed, fast=False):
        """Shuffle the underlying index with the given seed."""
        self.index_manager.shuffle(seed, fast=fast)

    def get_raw_image(self, index, image_key="image"):
        """Fetch the image for ``index``.

        A failure is logged and a blank white 256x256 RGB image is returned
        instead, so one unreadable record does not abort data loading.
        """
        try:
            ret = self.index_manager.get_image(index, image_key)
        except Exception as e:
            self.log_fn(f"get_raw_image | Error: {e}")
            ret = Image.new("RGB", (256, 256), (255, 255, 255))
        return ret

    @staticmethod
    def random_crop_image(image, origin_size, target_size):
        """Resize ``image`` to cover ``target_size`` and take a random crop.

        Args:
            image: PIL image to process.
            origin_size: ``(width, height)`` of ``image``.
            target_size: ``(width, height)`` of the crop.

        Returns:
            ``(cropped_image, (x_start, y_start))`` where the second item is
            the crop's top-left corner in the resized image.
        """
        aspect_ratio = float(origin_size[0]) / float(origin_size[1])
        if origin_size[0] < origin_size[1]:
            # Portrait: match the width, let the height overflow.
            new_width = target_size[0]
            new_height = int(new_width / aspect_ratio)
        else:
            # Landscape or square: match the height, let the width overflow.
            new_height = target_size[1]
            new_width = int(new_height * aspect_ratio)

        image = image.resize((new_width, new_height), Image.LANCZOS)

        if new_width > target_size[0]:
            x_start = random.randint(0, new_width - target_size[0])
            y_start = 0
        else:
            x_start = 0
            y_start = random.randint(0, new_height - target_size[1])
        image_crop = image.crop(
            (x_start, y_start, x_start + target_size[0], y_start + target_size[1])
        )
        crops_coords_top_left = (x_start, y_start)
        return image_crop, crops_coords_top_left

    def get_style(self, index):
        "Here we use a default learned embedder layer for future extension."
        style = 0
        return style

    def get_image_with_hwxy(self, index, image_key="image"):
        """Load, crop and normalise one image and collect its size metadata.

        Returns:
            ``(image_tensor, kwargs)``, or in IP-Adapter mode
            ``(image_tensor, img_for_clip_tensor, kwargs)``. ``kwargs`` holds
            ``image_meta_size`` (original size + target size + crop top-left
            corner, concatenated) and ``style``.
        """
        image = self.get_raw_image(index, image_key=image_key)
        origin_size = image.size

        if self.multireso:
            target_size = self.index_manager.get_target_size(index)
            image, crops_coords_top_left = self.index_manager.resize_and_crop(
                image, target_size, resample=Image.LANCZOS, crop_type="random"
            )
            image_tensor = self.flip_norm(image)
            if self.is_ipa:
                # Here the reference image is the resized-and-cropped image.
                img_for_clip_tensor = self.ti2i_transform(image)
        else:
            target_size = (self.resolution, self.resolution)
            image_crop, crops_coords_top_left = self.random_crop_image(
                image, origin_size, target_size
            )
            image_tensor = self.flip_norm(image_crop)
            if self.is_ipa:
                # Here the reference image is the full, uncropped image.
                img_for_clip_tensor = self.ti2i_transform(image)

        if self.random_shrink_size_cond:
            # Report at least 1024 for each side of the original size.
            origin_size = (
                1024 if origin_size[0] < 1024 else origin_size[0],
                1024 if origin_size[1] < 1024 else origin_size[1],
            )
        if self.merge_src_cond:
            val = (origin_size[0] * origin_size[1]) ** 0.5
            origin_size = (val, val)

        image_meta_size = (
            tuple(origin_size) + tuple(target_size) + tuple(crops_coords_top_left)
        )
        kwargs = {
            "image_meta_size": image_meta_size,
        }
        style = self.get_style(index)
        kwargs["style"] = style

        if self.is_ipa:
            return image_tensor, img_for_clip_tensor, kwargs
        else:
            return image_tensor, kwargs

    def get_text_info_with_encoder(self, description):
        """Tokenise ``description`` with ``self.tokenizer``.

        Returns:
            ``(description, input_ids, attention_mask)`` for the single
            sequence, padded/truncated to ``text_ctx_len``; the mask is
            boolean.
        """
        text_inputs = self.tokenizer(
            description,
            padding="max_length",
            max_length=self.text_ctx_len,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids[0]
        attention_mask = text_inputs.attention_mask[0].bool()
        return description, text_input_ids, attention_mask

    def fill_t5_token_mask(self, fill_tensor, fill_number, setting_length):
        """Right-pad ``fill_tensor`` along dim 1 up to ``setting_length``.

        Args:
            fill_tensor: 2-D tensor ``(batch, length)``.
            fill_number: Value written into the padded positions.
            setting_length: Desired size of dim 1.

        Returns:
            The padded tensor, keeping the input's dtype and device, or the
            input itself when it is already long enough.
        """
        fill_length = setting_length - fill_tensor.shape[1]
        if fill_length > 0:
            # Build the padding with the input's dtype/device/batch size so
            # the concatenation neither promotes ids to float nor fails for
            # batches larger than one.
            padding = torch.full(
                (fill_tensor.shape[0], fill_length),
                fill_number,
                dtype=fill_tensor.dtype,
                device=fill_tensor.device,
            )
            fill_tensor = torch.cat((fill_tensor, padding), dim=1)
        return fill_tensor

    def get_text_info_with_encoder_t5(self, description_t5):
        """Tokenise ``description_t5`` with ``self.tokenizer_t5``.

        The tokenizer output is truncated to ``text_ctx_len_t5`` and then
        padded manually: ids with 1, the attention mask with 0.

        Returns:
            ``(description_t5, input_ids, attention_mask)`` with ids as
            ``long`` and the mask as ``bool``; both keep the leading batch
            dimension returned by the tokenizer.
        """
        text_tokens_and_mask = self.tokenizer_t5(
            description_t5,
            max_length=self.text_ctx_len_t5,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids_t5 = self.fill_t5_token_mask(
            text_tokens_and_mask["input_ids"],
            fill_number=1,
            setting_length=self.text_ctx_len_t5,
        ).long()
        attention_mask_t5 = self.fill_t5_token_mask(
            text_tokens_and_mask["attention_mask"],
            fill_number=0,
            setting_length=self.text_ctx_len_t5,
        ).bool()
        return description_t5, text_input_ids_t5, attention_mask_t5

    def get_original_text(self, ind):
        """Return the stripped caption for ``ind``.

        The caption comes from ``text_zh`` or ``text_en`` depending on
        ``enable_CN``. If the lookup fails, the failure is logged and an
        empty string is returned.
        """
        column = "text_zh" if self.enable_CN else "text_en"
        text = ""
        try:
            text = self.index_manager.get_attribute(ind, column)
        except Exception as e:
            # Best effort: a missing caption falls back to "" instead of
            # aborting, but KeyboardInterrupt/SystemExit still propagate.
            self.log_fn(f"Warning! Fail get {column} columns: {e}")
        text = str(text).strip()
        return text

    def get_text(self, ind):
        """Return the caption for ``ind``, or a default prompt when empty."""
        text = self.get_original_text(ind)
        if text == "":
            text = "随机生成一张图片"
        return text

    def __getitem__(self, ind):
        """Assemble one training sample.

        Returns:
            ``(image, ids, mask, ids_t5, mask_t5, extra)`` or, in IP-Adapter
            mode, ``(image, ids, mask, ids_t5, mask_t5, ref_image, extra)``,
            where ``extra`` maps condition names to tensors.
        """
        # Get text
        if random.random() < self.uncond_p:
            description = ""
        else:
            description = self.get_text(ind)

        # Get text for t5
        if random.random() < self.uncond_p_t5:
            description_t5 = ""
        else:
            description_t5 = self.get_text(ind)

        # Decide now whether to drop the reference image (keeps the RNG call
        # order text -> T5 text -> image drop -> crop), but apply it only
        # after the image has actually been loaded.
        drop_ref_image = random.random() < self.uncond_p_img

        if self.is_ipa:
            image_tensor, img_for_clip_tensor, kwargs = self.get_image_with_hwxy(ind)
            if drop_ref_image:
                img_for_clip_tensor = torch.zeros_like(img_for_clip_tensor)
        else:
            image_tensor, kwargs = self.get_image_with_hwxy(ind)

        # Use encoder to embed tokens online
        text, text_embedding, text_embedding_mask = self.get_text_info_with_encoder(
            description
        )
        text_t5, text_embedding_t5, text_embedding_mask_t5 = (
            self.get_text_info_with_encoder_t5(description_t5)
        )

        sample = [
            image_tensor,
            text_embedding.clone().detach(),
            text_embedding_mask.clone().detach(),
            text_embedding_t5.clone().detach(),
            text_embedding_mask_t5.clone().detach(),
        ]
        if self.is_ipa:
            sample.append(img_for_clip_tensor.clone().detach())
        sample.append(
            {
                k: torch.tensor(np.array(v)).clone().detach()
                for k, v in kwargs.items()
            }
        )
        return tuple(sample)

    def __len__(self):
        """Return the number of entries in the index."""
        return len(self.index_manager)
# -*- coding: utf-8 -*-
import datetime
import gc
import os
import time
from multiprocessing import Pool
import subprocess
import pandas as pd
import pyarrow as pa
from tqdm import tqdm
import hashlib
from PIL import Image
import sys
def parse_data(data):
    """Read one image file and build its Arrow row.

    Args:
        data: Sequence whose first item is the image path and whose second
            item is the caption.

    Returns:
        ``[caption, md5_hex, width, height, raw_bytes]``, or ``None`` when
        anything goes wrong (the error is printed).
    """
    try:
        path = data[0]
        with open(path, "rb") as stream:
            raw_bytes = stream.read()
        digest = hashlib.md5(raw_bytes).hexdigest()
        # Open the file as an image to obtain its pixel dimensions.
        with Image.open(path) as picture:
            width, height = picture.size
        row = [data[1], digest, width, height, raw_bytes]
    except Exception as e:
        print(f"error: {e}")
        return None
    return row
def make_arrow(csv_root, dataset_root, start_id=0, end_id=-1):
    """Convert rows of a CSV file into numbered Arrow shard files.

    The ``img_path`` and ``text_zh`` columns are read from ``csv_root``,
    rows ``[start_id:end_id]`` are selected and written in shards of up to
    5000 rows to ``<dataset_root>/<shard number, 5 digits>.arrow``. Shards
    that already exist on disk are skipped, and rows whose image cannot be
    parsed are dropped.

    Relies on the module-level ``pool`` (a multiprocessing pool created by
    the script entry point) to parse images in parallel.

    Args:
        csv_root: Path of the CSV file.
        dataset_root: Directory receiving the ``.arrow`` files; created if
            missing.
        start_id: Index of the first row to convert.
        end_id: Index one past the last row to convert; a negative value
            means "until the end".
    """
    print(csv_root)
    arrow_dir = dataset_root
    print(arrow_dir)
    # exist_ok avoids the check-then-create race and makes the per-shard
    # makedirs call unnecessary.
    os.makedirs(arrow_dir, exist_ok=True)

    data = pd.read_csv(csv_root)
    data = data[["img_path", "text_zh"]]
    if end_id < 0:
        end_id = len(data)
    print(f"start_id:{start_id} end_id:{end_id}")
    data = data[start_id:end_id]

    num_slice = 5000
    start_sub = int(start_id / num_slice)
    # Ceiling division: one shard per started block of ``num_slice`` rows, so
    # no empty trailing shard is written when the row count is an exact
    # multiple of ``num_slice`` (or zero).
    num_subs = (len(data) + num_slice - 1) // num_slice
    columns_list = ["text_zh", "md5", "width", "height", "image"]

    for sub in tqdm(range(num_subs)):
        arrow_path = os.path.join(
            arrow_dir, "{}.arrow".format(str(sub + start_sub).zfill(5))
        )
        if os.path.exists(arrow_path):
            continue
        print(
            f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} start {sub + start_sub}"
        )
        sub_data = data[sub * num_slice : (sub + 1) * num_slice].values
        bs = pool.map(parse_data, sub_data)
        # parse_data returns None for rows that failed; drop them.
        bs = [b for b in bs if b]
        print(f"length of this arrow:{len(bs)}")

        dataframe = pd.DataFrame(bs, columns=columns_list)
        table = pa.Table.from_pandas(dataframe)
        with pa.OSFile(arrow_path, "wb") as sink:
            with pa.RecordBatchFileWriter(sink, table.schema) as writer:
                writer.write_table(table)

        # Free the shard's buffers before building the next one.
        del dataframe
        del table
        del bs
        gc.collect()
# Script entry point: convert a CSV file into Arrow shards using a process pool.
if __name__ == "__main__":
    if len(sys.argv) != 4:
        print(
            "Usage: python hydit/data_loader/csv2arrow.py ${csv_root} ${output_arrow_data_path} ${pool_num}"
        )
        print(
            "csv_root: The path to your created CSV file. For more details, see https://github.com/Tencent/HunyuanDiT?tab=readme-ov-file#truck-training"
        )
        print("output_arrow_data_path: The path for storing the created Arrow file")
        print(
            "pool_num: The number of processes, used for multiprocessing. If you encounter memory issues, you can set pool_num to 1"
        )
        sys.exit(1)
    csv_root = sys.argv[1]
    output_arrow_data_path = sys.argv[2]
    pool_num = int(sys.argv[3])
    # ``pool`` stays a module-level name because make_arrow() reads it as a
    # global; the context manager shuts the worker processes down once the
    # conversion has finished (or failed).
    with Pool(pool_num) as pool:
        make_arrow(csv_root, output_arrow_data_path)
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
def create_diffusion(
    *,
    steps=1000,
    learn_sigma=True,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_type="epsilon",
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
    mse_loss_weight_type="constant",
    beta_start=0.0001,
    beta_end=0.02,
    noise_offset=0.0,
):
    """
    Build a SpacedDiffusion from keyword-only configuration values.

    :param steps: number of diffusion steps of the base schedule.
    :param learn_sigma: if True the variance type is LEARNED_RANGE; otherwise
                        a fixed variance is chosen through ``sigma_small``.
    :param sigma_small: with ``learn_sigma`` False, pick FIXED_SMALL instead
                        of FIXED_LARGE.
    :param noise_schedule: schedule name passed to get_named_beta_schedule().
    :param use_kl: if True use the RESCALED_KL loss (takes precedence over
                   ``rescale_learned_sigmas``).
    :param predict_type: key into ``gd.predict_type_dict``.
    :param rescale_timesteps: forwarded to SpacedDiffusion.
    :param rescale_learned_sigmas: if True (and ``use_kl`` False) use the
                                   RESCALED_MSE loss, else plain MSE.
    :param timestep_respacing: respacing spec for space_timesteps(); None or
                               "" means "use all ``steps``".
    :param mse_loss_weight_type: forwarded to SpacedDiffusion.
    :param beta_start: forwarded to get_named_beta_schedule().
    :param beta_end: forwarded to get_named_beta_schedule().
    :param noise_offset: forwarded to SpacedDiffusion.
    :return: the configured SpacedDiffusion instance.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, steps, beta_start, beta_end)

    # Loss: KL wins over rescaled MSE, plain MSE is the fallback.
    loss_type = gd.LossType.MSE
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE

    no_respacing = timestep_respacing is None or timestep_respacing == ""
    if no_respacing:
        timestep_respacing = [steps]

    mean_type = gd.predict_type_dict[predict_type]

    # Variance: learned range, or one of the two fixed variants.
    if learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
        mse_loss_weight_type=mse_loss_weight_type,
        noise_offset=noise_offset,
    )
import torch as th
import numpy as np
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Return the KL divergence between two Gaussians given by mean/log-variance.

    Arguments broadcast against each other, so tensors may be mixed with
    Python scalars as long as at least one argument is a Tensor.
    """
    reference = next(
        (
            arg
            for arg in (mean1, logvar1, mean2, logvar2)
            if isinstance(arg, th.Tensor)
        ),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # th.exp() needs a Tensor; scalars are moved to the reference
        # tensor's device/dtype.
        if isinstance(value, th.Tensor):
            return value
        return th.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    squared_diff = (mean1 - mean2) ** 2
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + squared_diff * th.exp(-logvar2)
    )
def approx_standard_normal_cdf(x):
    """
    Approximate the standard normal CDF of ``x`` with a tanh-based formula.
    """
    scale = np.sqrt(2.0 / np.pi)
    cubic = x + 0.044715 * th.pow(x, 3)
    return 0.5 * (1.0 + th.tanh(scale * cubic))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of a Gaussian discretized onto the image grid.

    :param x: the target images, assumed to be uint8 values rescaled to
              the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor shaped like x holding log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape

    residual = x - means
    precision = th.exp(-log_scales)
    # CDF evaluated 1/255 above and below the target value.
    upper_cdf = approx_standard_normal_cdf(precision * (residual + 1.0 / 255.0))
    lower_cdf = approx_standard_normal_cdf(precision * (residual - 1.0 / 255.0))

    # Clamp before log to avoid log(0).
    log_upper = th.log(upper_cdf.clamp(min=1e-12))
    log_lower_tail = th.log((1.0 - lower_cdf).clamp(min=1e-12))
    log_interior = th.log((upper_cdf - lower_cdf).clamp(min=1e-12))

    # The extreme values use the one-sided tails instead of the CDF difference.
    log_probs = th.where(
        x < -0.999,
        log_upper,
        th.where(x > 0.999, log_lower_tail, log_interior),
    )
    assert log_probs.shape == x.shape
    return log_probs
import math
import numpy as np
import torch as th
import enum
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
from ..utils.tools import assert_shape
def mean_flat(tensor):
    """
    Average ``tensor`` over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)
class ModelMeanType(enum.Enum):
    """
    The quantity the model output is interpreted as.
    """

    # The model predicts x_{t-1}.
    PREVIOUS_X = enum.auto()
    # The model predicts x_0.
    START_X = enum.auto()
    # The model predicts epsilon.
    EPSILON = enum.auto()
    # The model predicts v.
    VELOCITY = enum.auto()
# Lookup from a ``predict_type`` string to the matching ModelMeanType member.
predict_type_dict = dict(
    epsilon=ModelMeanType.EPSILON,
    sample=ModelMeanType.START_X,
    v_prediction=ModelMeanType.VELOCITY,
)
class ModelVarType(enum.Enum):
    """
    How the variance of the model's output distribution is obtained.

    LEARNED_RANGE lets the model predict a value between FIXED_SMALL and
    FIXED_LARGE, which makes its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
    """
    The loss used for training.
    """

    # Raw MSE loss (and KL when learning variances).
    MSE = enum.auto()
    # Raw MSE loss (with RESCALED_KL when learning variances).
    RESCALED_MSE = enum.auto()
    # The variational lower-bound.
    KL = enum.auto()
    # Like KL, but rescaled to estimate the full VLB.
    RESCALED_KL = enum.auto()

    def is_vb(self):
        """Return True for the two variational-bound losses (KL, RESCALED_KL)."""
        return self in (LossType.KL, LossType.RESCALED_KL)
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
warmup_time = int(num_diffusion_timesteps * warmup_frac)
betas[:warmup_time] = np.linspace(
beta_start, beta_end, warmup_time, dtype=np.float64
)
return betas
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    Deprecated API for creating beta schedules.

    See get_named_beta_schedule() for the new library of schedules.

    :param beta_schedule: one of "quad", "linear", "warmup10", "warmup50",
                          "const" or "jsd".
    :param beta_start: first beta value (unused by "const" and "jsd").
    :param beta_end: last beta value (unused by "jsd").
    :param num_diffusion_timesteps: number of betas to produce.
    :return: a 1-D float64 numpy array of betas.
    :raises NotImplementedError: for an unknown ``beta_schedule``.
    """
    n = num_diffusion_timesteps

    # Each entry is lazy so only the requested schedule is computed.
    builders = {
        "quad": lambda: np.linspace(
            beta_start**0.5, beta_end**0.5, n, dtype=np.float64
        )
        ** 2,
        "linear": lambda: np.linspace(beta_start, beta_end, n, dtype=np.float64),
        "warmup10": lambda: _warmup_beta(beta_start, beta_end, n, 0.1),
        "warmup50": lambda: _warmup_beta(beta_start, beta_end, n, 0.5),
        "const": lambda: beta_end * np.ones(n, dtype=np.float64),
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        "jsd": lambda: 1.0 / np.linspace(n, 1, n, dtype=np.float64),
    }

    build = builders.get(beta_schedule)
    if build is None:
        raise NotImplementedError(beta_schedule)

    betas = build()
    assert_shape(betas, (n,))
    return betas
def get_named_beta_schedule(
    schedule_name, num_diffusion_timesteps, beta_start=0.0001, beta_end=0.02
):
    """
    Return the pre-defined beta schedule registered under ``schedule_name``.

    The schedules in this library stay similar in the limit of
    num_diffusion_timesteps. New schedules may be added, but existing ones
    should not be removed or changed, to maintain backwards compatibility.

    :raises NotImplementedError: for an unknown ``schedule_name``.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al (DDPM), extended to work for any
        # number of diffusion steps by rescaling both endpoints.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * beta_start,
            beta_end=scale * beta_end,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )

    if schedule_name == "scaled_linear":
        # StableDiffusion schedule; beta_start should be 0.00085 and
        # beta_end should be 0.012.
        return get_beta_schedule(
            "quad",
            beta_start=beta_start,
            beta_end=beta_end,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )

    if schedule_name == "squaredcos_cap_v2":

        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)

    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize an alpha_t_bar function into a beta schedule.

    ``alpha_bar`` defines the cumulative product of (1-beta) over time for
    t in [0, 1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t from 0 to 1 and returning the
                      cumulative product of (1-beta) up to that part of the
                      diffusion process.
    :param max_beta: upper bound applied to every beta; use values lower
                     than 1 to prevent singularities.
    :return: a numpy array with ``num_diffusion_timesteps`` betas.
    """
    steps = num_diffusion_timesteps
    return np.array(
        [
            min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
            for i in range(steps)
        ]
    )
class GaussianDiffusion:
"""
Utilities for training and sampling diffusion models.
Ported directly from here, and then adapted over time to further experimentation.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
:param model_mean_type: a ModelMeanType determining what the model outputs.
:param model_var_type: a ModelVarType determining how variance is output.
:param loss_type: a LossType determining the loss function to use.
:param rescale_timesteps: if True, pass floating point timesteps into the
model so that they are always scaled like in the
original paper (0 to 1000).
"""
def __init__(
    self,
    *,
    betas,
    model_mean_type,
    model_var_type,
    loss_type,
    rescale_timesteps=False,
    mse_loss_weight_type="constant",
    noise_offset=0.0,
):
    """
    Store the configuration and precompute every per-timestep coefficient
    (as float64 numpy arrays) derived from ``betas``.
    """
    self.model_mean_type = model_mean_type
    self.model_var_type = model_var_type
    self.loss_type = loss_type
    self.rescale_timesteps = rescale_timesteps
    self.mse_loss_weight_type = mse_loss_weight_type
    self.noise_offset = noise_offset

    # float64 keeps the derived coefficients accurate.
    betas = np.array(betas, dtype=np.float64)
    alphas = 1.0 - betas
    cumprod = np.cumprod(alphas, axis=0)

    self.betas = betas
    assert len(betas.shape) == 1, "betas must be 1-D"
    assert (betas > 0).all() and (betas <= 1).all()
    self.num_timesteps = int(betas.shape[0])

    # Cumulative alpha products, plus copies shifted by one step either way.
    self.alphas_cumprod = cumprod
    self.alphas_cumprod_prev = np.append(1.0, cumprod[:-1])
    self.alphas_cumprod_next = np.append(cumprod[1:], 0.0)
    assert_shape(self.alphas_cumprod_prev, (self.num_timesteps,))

    # Coefficients for the forward process q(x_t | x_{t-1}) and others.
    self.sqrt_alphas_cumprod = np.sqrt(cumprod)
    self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - cumprod)
    self.log_one_minus_alphas_cumprod = np.log(1.0 - cumprod)
    self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / cumprod)
    self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / cumprod - 1)

    # Coefficients for the posterior q(x_{t-1} | x_t, x_0).
    prev = self.alphas_cumprod_prev
    self.posterior_variance = betas * (1.0 - prev) / (1.0 - cumprod)

    # The posterior variance is 0 at the beginning of the diffusion chain,
    # so its log is "clipped" by reusing the second entry for the first.
    if len(self.posterior_variance) > 1:
        clipped = np.append(self.posterior_variance[1], self.posterior_variance[1:])
        self.posterior_log_variance_clipped = np.log(clipped)
    else:
        self.posterior_log_variance_clipped = np.array([])

    self.posterior_mean_coef1 = betas * np.sqrt(prev) / (1.0 - cumprod)
    self.posterior_mean_coef2 = (1.0 - prev) * np.sqrt(alphas) / (1.0 - cumprod)

    # Sampler name -> sampling loop.
    self.sampler = {
        "ddpm": self.p_sample_loop,
        "ddim": self.ddim_sample_loop,
        "plms": self.plms_sample_loop,
    }
def q_mean_variance(self, x_start, t):
    """
    Return the parameters of the distribution q(x_t | x_0).

    :param x_start: the [N x C x ...] tensor of noiseless inputs.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :return: A tuple (mean, variance, log_variance), all of x_start's shape.
    """
    shape = x_start.shape
    mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, shape) * x_start
    variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, shape)
    log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, shape)
    return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
    """
    Sample from q(x_t | x_0), i.e. diffuse the data for ``t`` steps.

    :param x_start: the initial data batch.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :param noise: if specified, the split-out normal noise; drawn with
                  th.randn_like(x_start) otherwise.
    :return: A noisy version of x_start.
    """
    if noise is None:
        noise = th.randn_like(x_start)
    assert_shape(noise, x_start)

    signal_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = _extract_into_tensor(
        self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    return signal_scale * x_start + noise_scale * noise
def q_posterior_mean_variance(self, x_start, x_t, t):
    """
    Return mean, variance and clipped log-variance of the diffusion
    posterior q(x_{t-1} | x_t, x_0).
    """
    assert_shape(x_start, x_t)
    shape = x_t.shape

    coef_start = _extract_into_tensor(self.posterior_mean_coef1, t, shape)
    coef_t = _extract_into_tensor(self.posterior_mean_coef2, t, shape)
    mean = coef_start * x_start + coef_t * x_t

    variance = _extract_into_tensor(self.posterior_variance, t, shape)
    log_variance_clipped = _extract_into_tensor(
        self.posterior_log_variance_clipped, t, shape
    )

    # All outputs must agree with x_start on the batch dimension.
    assert_shape(
        mean.shape[:1],
        variance.shape[:1],
        log_variance_clipped.shape[:1],
        x_start.shape[:1],
    )
    return mean, variance, log_variance_clipped
def p_mean_variance(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    model_kwargs=None,
    model_var_type=None,
):
    """
    Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
    the initial x, x_0.

    :param model: the model, which takes a signal and a batch of timesteps
                  as input and returns a dict whose "x" entry is the output.
    :param x: the [N x C x ...] tensor at time t.
    :param t: a 1-D Tensor of timesteps.
    :param clip_denoised: if True, clip the denoised signal into [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample. Applies before
        clip_denoised.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param model_var_type: if not None, overrides the default
        self.model_var_type. It is useful when training with learned var
        but sampling with fixed var.
    :return: a dict with the following keys:
             - 'mean': the model mean output.
             - 'variance': the model variance output.
             - 'log_variance': the log of 'variance'.
             - 'pred_xstart': the prediction for x_0.
             - 'extra': the model's other outputs (every key except "x"),
               or None when the model returned only "x".
    :raises ValueError: if a learned ``model_var_type`` is requested but
        self.model_var_type is fixed, so the model output carries no
        variance channels.
    """
    if model_kwargs is None:
        model_kwargs = {}
    if model_var_type is None:
        model_var_type = self.model_var_type
    learned_var_types = (ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE)

    B, C = x.shape[:2]
    assert_shape(t, (B,))
    out_dict = model(x, t, **model_kwargs)
    model_output = out_dict["x"]
    if len(out_dict) > 1:
        extra = {k: v for k, v in out_dict.items() if k != "x"}
    else:
        extra = None

    # self.model_var_type describes the layout of the model output: with a
    # learned variance the channels hold [mean part, variance part].
    model_var_values = None
    if self.model_var_type in learned_var_types:
        assert_shape(model_output, (B, C * 2, *x.shape[2:]))
        model_output, model_var_values = th.split(model_output, C, dim=1)

    # model_var_type describes the variance used by the reverse process.
    if model_var_type in learned_var_types:
        if model_var_values is None:
            # Previously surfaced as an UnboundLocalError further down.
            raise ValueError(
                f"model_var_type={model_var_type} requires variance channels, "
                f"but self.model_var_type={self.model_var_type} provides none"
            )
        if model_var_type == ModelVarType.LEARNED:
            model_log_variance = model_var_values
            model_variance = th.exp(model_log_variance)
        else:
            min_log = _extract_into_tensor(
                self.posterior_log_variance_clipped, t, x.shape
            )
            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
            # The model_var_values is [-1, 1] for [min_var, max_var].
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = th.exp(model_log_variance)
    else:
        model_variance, model_log_variance = {
            # for fixedlarge, we set the initial (log-)variance like so
            # to get a better decoder log likelihood.
            ModelVarType.FIXED_LARGE: (
                np.append(self.posterior_variance[1], self.betas[1:]),
                np.log(np.append(self.posterior_variance[1], self.betas[1:])),
            ),
            ModelVarType.FIXED_SMALL: (
                self.posterior_variance,
                self.posterior_log_variance_clipped,
            ),
        }[model_var_type]
        model_variance = _extract_into_tensor(model_variance, t, x.shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

    def process_xstart(x_pred):
        # denoised_fn runs first, clipping second.
        if denoised_fn is not None:
            x_pred = denoised_fn(x_pred)
        if clip_denoised:
            return x_pred.clamp(-1, 1)
        return x_pred

    if self.model_mean_type == ModelMeanType.PREVIOUS_X:
        pred_xstart = process_xstart(
            self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
        )
        model_mean = model_output
    elif self.model_mean_type in [
        ModelMeanType.START_X,
        ModelMeanType.EPSILON,
        ModelMeanType.VELOCITY,
    ]:
        if self.model_mean_type == ModelMeanType.START_X:
            pred_xstart = process_xstart(model_output)
        elif self.model_mean_type == ModelMeanType.EPSILON:
            pred_xstart = process_xstart(
                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            )
        else:
            pred_xstart = process_xstart(
                self._predict_xstart_from_v(x_t=x, t=t, v=model_output)
            )
        # The mean of p(x_{t-1} | x_t) is the posterior mean for the
        # predicted x_0.
        model_mean, _, _ = self.q_posterior_mean_variance(
            x_start=pred_xstart, x_t=x, t=t
        )
    else:
        raise NotImplementedError(self.model_mean_type)

    assert_shape(model_mean, model_log_variance, pred_xstart, x)
    return {
        "mean": model_mean,
        "variance": model_variance,
        "log_variance": model_log_variance,
        "pred_xstart": pred_xstart,
        "extra": extra,
    }
def _predict_xstart_from_eps(self, x_t, t, eps):
    """Recover the x_0 prediction from x_t and a predicted noise ``eps``."""
    assert_shape(x_t, eps)
    x_scale = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    eps_scale = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return x_scale * x_t - eps_scale * eps
def _predict_xstart_from_v(self, x_t, t, v):
    """Recover the x_0 prediction from x_t and a predicted velocity ``v``."""
    assert_shape(x_t, v)
    x_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape)
    v_scale = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
    return x_scale * x_t - v_scale * v
def _predict_xstart_from_xprev(self, x_t, t, xprev):
    """
    Recover the x_0 prediction from x_t and a predicted x_{t-1} by
    inverting the posterior mean: (xprev - coef2 * x_t) / coef1.
    """
    assert_shape(x_t, xprev)
    inv_coef1 = _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape)
    coef_ratio = _extract_into_tensor(
        self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
    )
    return inv_coef1 * xprev - coef_ratio * x_t
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    """Derive the noise implied by x_t and an x_0 prediction."""
    x_scale = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    numerator = x_scale * x_t - pred_xstart
    return numerator / _extract_into_tensor(
        self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
    )
def _velocity_from_xstart_and_noise(self, x_start, t, noise):
    """Compute the velocity target from x_0 and the sampled noise."""
    shape = x_start.shape
    noise_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, shape)
    start_scale = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, shape)
    return noise_scale * noise - start_scale * x_start
def _vb_terms_bpd(
    self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
    """
    Compute one term of the variational lower-bound.

    The result is expressed in bits (nats divided by log 2), which allows
    for comparison to other papers.

    :return: a dict with the following keys:
             - 'output': a shape [N] tensor of NLLs or KLs.
             - 'pred_xstart': the x_0 predictions.
             - 'extra': the 'extra' entry returned by p_mean_variance().
    """
    nats_per_bit = np.log(2.0)

    q_mean, _, q_log_variance = self.q_posterior_mean_variance(
        x_start=x_start, x_t=x_t, t=t
    )
    p_out = self.p_mean_variance(
        model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
    )

    # KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)), averaged per sample.
    kl = normal_kl(q_mean, q_log_variance, p_out["mean"], p_out["log_variance"])
    kl = mean_flat(kl) / nats_per_bit

    # Decoder negative log-likelihood, averaged per sample.
    decoder_nll = -discretized_gaussian_log_likelihood(
        x_start, means=p_out["mean"], log_scales=0.5 * p_out["log_variance"]
    )
    assert_shape(decoder_nll, x_start)
    decoder_nll = mean_flat(decoder_nll) / nats_per_bit

    # The first timestep uses the decoder NLL, every other one the KL term.
    return {
        "output": th.where((t == 0), decoder_nll, kl),
        "pred_xstart": p_out["pred_xstart"],
        "extra": p_out["extra"],
    }
def training_losses(
    self, model, x_start, model_kwargs=None, controlnet=None, noise=None
):
    """
    Compute training losses for a single, randomly drawn timestep.

    :param model: the model to evaluate loss on; called as
                  ``model(x_t, t, **model_kwargs)`` and expected to return a
                  dict whose "x" entry is the prediction.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning. The dict is
        not modified.
    :param controlnet: if not None, called as
        ``controlnet(x_t, t, **model_kwargs)`` (MSE losses only); its result
        replaces the "condition" entry of the model keyword arguments.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             MSE losses also provide "mse" and "raw_loss" (and "vb" when the
             variance is learned); any extra model outputs are merged in.
    :raises ValueError: for an unknown ``mse_loss_weight_type``.
    :raises NotImplementedError: for an unknown loss type.
    """
    if model_kwargs is None:
        model_kwargs = {}

    # Time steps
    t = th.randint(
        0, self.num_timesteps, (x_start.shape[0],), device=x_start.device
    )

    # Noise
    if noise is None:
        noise = th.randn_like(x_start)
    if self.noise_offset > 0:
        # Add channel wise noise offset
        # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        noise = noise + self.noise_offset * th.randn(
            *x_start.shape[:2], 1, 1, device=x_start.device
        )
    x_t = self.q_sample(x_start, t, noise=noise)

    terms = {}

    if self.mse_loss_weight_type == "constant":
        mse_loss_weight = th.ones_like(t)
    elif self.mse_loss_weight_type.startswith("min_snr_"):
        alpha = _extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape)
        sigma = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape)
        snr = (alpha / sigma) ** 2
        # The clamp value k is encoded in the type name: "min_snr_<k>".
        k = float(self.mse_loss_weight_type.split("min_snr_")[-1])
        # min{snr, k}
        mse_loss_weight = (
            th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] / snr
        )
    else:
        raise ValueError(self.mse_loss_weight_type)

    if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
        out_dict = self._vb_terms_bpd(
            model=model,
            x_start=x_start,
            x_t=x_t,
            t=t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )
        terms["loss"] = out_dict["output"]
        if self.loss_type == LossType.RESCALED_KL:
            terms["loss"] *= self.num_timesteps
        # May be None when the model returns nothing besides "x".
        extra = out_dict["extra"]
    elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
        if controlnet is not None:
            controls = controlnet(x_t, t, **model_kwargs)
            # Work on a copy so the caller's dict keeps its "condition" entry.
            model_kwargs = dict(model_kwargs)
            model_kwargs.pop("condition")
            model_kwargs.update(controls)
        out_dict = model(x_t, t, **model_kwargs)
        model_output = out_dict["x"]
        extra = {k: v for k, v in out_dict.items() if k != "x"}

        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            B, C = x_t.shape[:2]
            assert_shape(model_output, (B, C * 2, *x_t.shape[2:]))
            model_output, model_var_values = th.split(model_output, C, dim=1)
            # Learn the variance using the variational bound, but don't let
            # it affect our mean prediction.
            frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
            terms["vb"] = self._vb_terms_bpd(
                model=lambda *args, r=frozen_out: dict(x=r),
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
            )["output"]
            if self.loss_type == LossType.RESCALED_MSE:
                # Divide by 1000 for equivalence with initial implementation.
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= self.num_timesteps / 1000.0

        if self.model_mean_type == ModelMeanType.VELOCITY:
            target = self._velocity_from_xstart_and_noise(x_start, t, noise)
        else:
            # PREVIOUS_X is not supported as a training target here.
            target = {
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
        assert_shape(model_output, target, x_start)

        # Compute the per-sample MSE once; the unweighted copy is detached.
        per_sample_mse = mean_flat((target - model_output) ** 2)
        raw_mse = per_sample_mse.detach()
        terms["mse"] = mse_loss_weight * per_sample_mse
        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
            terms["raw_loss"] = raw_mse + terms["vb"].detach()
        else:
            terms["loss"] = terms["mse"]
            terms["raw_loss"] = raw_mse
    else:
        raise NotImplementedError(self.loss_type)

    # dict.update(None) raises TypeError, so skip a missing ``extra``.
    if extra is not None:
        terms.update(extra)
    return terms
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.

    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

    :param cond_fn: called as ``cond_fn(x, t, **model_kwargs)``.
    :param p_mean_var: dict providing the "mean" and "variance" tensors.
    :param x: the tensor passed to cond_fn.
    :param t: the timesteps passed to cond_fn.
    :param model_kwargs: extra keyword arguments for cond_fn; None means
                         no extra arguments.
    :return: ``mean + variance * gradient`` computed in float precision.
    """
    if model_kwargs is None:
        # ``**None`` would raise TypeError; the default means "no kwargs".
        model_kwargs = {}
    gradient = cond_fn(x, t, **model_kwargs)
    new_mean = (
        p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    )
    return new_mean
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute what the p_mean_variance output would have been, should the
    model's score function be conditioned by cond_fn.

    See condition_mean() for details on cond_fn.

    Unlike condition_mean(), this instead uses the conditioning strategy
    from Song et al (2020).

    :param cond_fn: called as ``cond_fn(x, t, **model_kwargs)``.
    :param p_mean_var: the dict returned by p_mean_variance(); it is
                       copied (shallowly), not modified.
    :param model_kwargs: extra keyword arguments for cond_fn; None means
                         no extra arguments.
    :return: a copy of ``p_mean_var`` with "pred_xstart" and "mean" replaced.
    """
    if model_kwargs is None:
        # ``**None`` would raise TypeError; the default means "no kwargs".
        model_kwargs = {}
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

    # Shift the implied noise by the scaled conditioning gradient.
    eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
    eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

    out = p_mean_var.copy()
    out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
    out["mean"], _, _ = self.q_posterior_mean_variance(
        x_start=out["pred_xstart"], x_t=x, t=t
    )
    return out
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    model_var_type=None,
    **kwargs,
):
    """
    Draw x_{t-1} from the model at the given timestep.

    :param model: the model to sample from.
    :param x: the current tensor, passed to p_mean_variance() as its ``x``.
    :param t: the value of t, starting at 0 for the first diffusion step.
    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param model_var_type: if not None, overrides the default
        self.model_var_type. It is useful when training with learned var
        but sampling with fixed var.
    :param kwargs: accepted and ignored.
    :return: a dict containing the following keys:
             - 'sample': a random sample from the model.
             - 'pred_xstart': a prediction of x_0.
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
        model_var_type=model_var_type,
    )
    noise = th.randn_like(x)

    # Broadcastable 0/1 mask: no noise is added where t == 0.
    mask_shape = (-1,) + (1,) * (len(x.shape) - 1)
    nonzero_mask = (t != 0).float().view(*mask_shape)

    if cond_fn is not None:
        out["mean"] = self.condition_mean(
            cond_fn, out, x, t, model_kwargs=model_kwargs
        )

    std = th.exp(0.5 * out["log_variance"])
    return {
        "sample": out["mean"] + nonzero_mask * std * noise,
        "pred_xstart": out["pred_xstart"],
    }
def p_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    model_var_type=None,
    device=None,
    progress=False,
    progress_leave=True,
    **kwargs,
):
    """
    Run the full sampling loop and return the final sample.

    :param model: the model module.
    :param shape: the shape of the samples, (N, C, H, W).
    :param noise: if specified, the noise from the encoder to sample.
                  Should be of the same shape as `shape`.
    :param clip_denoised: if True, clip x_start predictions to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param model_var_type: if not None, overrides the default
        self.model_var_type. It is useful when training with learned var
        but sampling with fixed var.
    :param device: if specified, the device to create the samples on.
                   If not specified, use a model parameter's device.
    :param progress: if True, show a tqdm progress bar.
    :param progress_leave: forwarded as tqdm's ``leave`` argument.
    :return: a non-differentiable batch of samples.
    """
    steps = self.p_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        model_var_type=model_var_type,
        device=device,
        progress=progress,
        progress_leave=progress_leave,
        **kwargs,
    )
    # Exhaust the generator; only the last yielded step is kept.
    last = None
    for last in steps:
        pass
    return last["sample"]
def p_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    model_var_type=None,
    device=None,
    progress=False,
    progress_leave=True,
    **kwargs,
):
    """
    Sample from the model, yielding the intermediate result of every
    diffusion timestep.

    Arguments are the same as p_sample_loop().
    Returns a generator over dicts, where each dict is the return value of
    p_sample().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))

    img = noise if noise is not None else th.randn(*shape, device=device)

    # Walk the timesteps from the last one down to 0.
    timesteps = list(range(self.num_timesteps - 1, -1, -1))
    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        timesteps = tqdm(timesteps, leave=progress_leave)

    batch_size = shape[0]
    for step in timesteps:
        t = th.tensor([step] * batch_size, device=device)
        with th.no_grad():
            out = self.p_sample(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                model_var_type=model_var_type,
                **kwargs,
            )
            yield out
            img = out["sample"]
def ddim_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Draw x_{t-1} from the model using DDIM.

    Same usage as p_sample(); ``eta`` scales the stochastic part (sigma).
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

    pred_xstart = out["pred_xstart"]
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    eps = self._predict_eps_from_xstart(x, t, pred_xstart)

    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
    sigma = (
        eta
        * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
        * th.sqrt(1 - alpha_bar / alpha_bar_prev)
    )

    # Equation 12.
    noise = th.randn_like(x)
    mean_pred = (
        pred_xstart * th.sqrt(alpha_bar_prev)
        + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
    )

    # Broadcastable 0/1 mask: no noise is added where t == 0.
    mask_shape = (-1,) + (1,) * (len(x.shape) - 1)
    nonzero_mask = (t != 0).float().view(*mask_shape)

    return {
        "sample": mean_pred + nonzero_mask * sigma * noise,
        "pred_xstart": pred_xstart,
    }
def ddim_reverse_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Compute x_{t+1} from the model using the DDIM reverse ODE.

    Only the deterministic path is supported, so ``eta`` must be 0.0.
    """
    assert eta == 0.0, "Reverse ODE only for deterministic path"
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

    pred_xstart = out["pred_xstart"]
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    eps = self._predict_eps_from_xstart(x, t, pred_xstart)
    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

    # Equation 12. reversed
    next_sample = (
        pred_xstart * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
    )
    return {"sample": next_sample, "pred_xstart": pred_xstart}
def ddim_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    progress_leave=True,
    eta=0.0,
):
    """
    Run the full DDIM sampling loop and return the final sample.

    Same usage as p_sample_loop().
    """
    steps = self.ddim_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
        progress_leave=progress_leave,
        eta=eta,
    )
    # Exhaust the generator; only the last yielded step is kept.
    last = None
    for last in steps:
        pass
    return last["sample"]
def ddim_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    progress_leave=True,
    eta=0.0,
):
    """
    Sample from the model with DDIM, yielding the intermediate result of
    every timestep.

    Same usage as p_sample_loop_progressive().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))

    img = noise if noise is not None else th.randn(*shape, device=device)

    # Walk the timesteps from the last one down to 0.
    timesteps = list(range(self.num_timesteps - 1, -1, -1))
    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        timesteps = tqdm(timesteps, leave=progress_leave)

    batch_size = shape[0]
    for step in timesteps:
        t = th.tensor([step] * batch_size, device=device)
        with th.no_grad():
            out = self.ddim_sample(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                eta=eta,
            )
            yield out
            img = out["sample"]
def _prior_bpd(self, x_start):
    """
    Compute the prior KL term of the variational lower-bound in
    bits-per-dim.

    The term depends only on the forward (encoder) process, so it cannot
    be optimized.

    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    n = x_start.shape[0]
    # Every batch element is evaluated at the last timestep.
    last_t = th.tensor([self.num_timesteps - 1] * n, device=x_start.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, last_t)
    # KL against a reference with zero mean and zero log-variance.
    kl_nats = normal_kl(
        mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
    )
    # Dividing by ln(2) converts the per-element mean from nats to bits.
    return mean_flat(kl_nats) / np.log(2.0)
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
    """
    Evaluate the full variational lower-bound (bits-per-dim) together
    with per-timestep error statistics.

    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param clip_denoised: if True, clip denoised samples.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict containing the following keys:
             - total_bpd: the total variational lower-bound, per batch element.
             - prior_bpd: the prior term in the lower-bound.
             - vb: an [N x T] tensor of terms in the lower-bound.
             - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
             - mse: an [N x T] tensor of epsilon MSEs for each timestep.
    """
    device = x_start.device
    batch_size = x_start.shape[0]
    vb_terms, xstart_errors, eps_errors = [], [], []
    # Walk the chain from the last timestep down to 0.
    for step in reversed(range(self.num_timesteps)):
        t_batch = th.tensor([step] * batch_size, device=device)
        noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
        # Calculate VLB term at the current timestep
        with th.no_grad():
            out = self._vb_terms_bpd(
                model,
                x_start=x_start,
                x_t=x_t,
                t=t_batch,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
            )
        pred_xstart = out["pred_xstart"]
        vb_terms.append(out["output"])
        xstart_errors.append(mean_flat((pred_xstart - x_start) ** 2))
        eps = self._predict_eps_from_xstart(x_t, t_batch, pred_xstart)
        eps_errors.append(mean_flat((eps - noise) ** 2))
    # Stack the per-timestep lists along a new dimension 1.
    vb = th.stack(vb_terms, dim=1)
    xstart_mse = th.stack(xstart_errors, dim=1)
    mse = th.stack(eps_errors, dim=1)
    prior_bpd = self._prior_bpd(x_start)
    return {
        "total_bpd": vb.sum(dim=1) + prior_bpd,
        "prior_bpd": prior_bpd,
        "vb": vb,
        "xstart_mse": xstart_mse,
        "mse": mse,
    }
def get_eps(
    self,
    model,
    x,
    t,
    model_kwargs,
    cond_fn=None,
):
    """
    Run the model at (x, t) and return its epsilon prediction.

    :param model: callable invoked as model(x, t, **model_kwargs); the
                  "x" entry of its result is used.
    :param x: the input tensor passed to the model.
    :param t: the timesteps passed to the model.
    :param model_kwargs: dict of extra keyword arguments forwarded to the
                         model and to cond_fn. None is treated as an empty
                         dict, like in the sampling methods of this class.
    :param cond_fn: if not None, a callable invoked as
                    cond_fn(x, t, **model_kwargs); its result, scaled by
                    sqrt(1 - alpha_bar), is subtracted from eps.
    :return: the first four channels of the model output, with the
             cond_fn correction applied when cond_fn is given.
    """
    if model_kwargs is None:
        # Avoid `**None` raising a TypeError for callers that pass None.
        model_kwargs = {}
    model_output = model(x, t, **model_kwargs)["x"]
    if isinstance(model_output, tuple):
        # Only the first element of a 2-tuple output is used.
        model_output, _ = model_output
    # Keep the first four channels only.
    # NOTE(review): presumably any remaining channels hold a learned
    # variance and four is the latent channel count -- confirm.
    eps = model_output[:, :4]
    if cond_fn is not None:
        alpha_bar = _extract_into_tensor_lerp(self.alphas_cumprod, t, x.shape)
        eps = eps - th.sqrt(1 - alpha_bar) * cond_fn(x, t, **model_kwargs)
    return eps
def eps_to_pred_xstart(
    self,
    x,
    eps,
    t,
):
    """
    Recover the x_0 estimate implied by a noise estimate.

    Computes (x - sqrt(1 - alpha_bar) * eps) / sqrt(alpha_bar), where
    alpha_bar is read from self.alphas_cumprod at t through the
    interpolating lookup helper.
    """
    alpha_bar = _extract_into_tensor_lerp(self.alphas_cumprod, t, x.shape)
    noise_part = th.sqrt(1 - alpha_bar) * eps
    return (x - noise_part) / th.sqrt(alpha_bar)
def pndm_transfer(
    self,
    x,
    eps,
    t_1,
    t_2,
):
    """
    Move x from timestep t_1 to timestep t_2 using the noise estimate eps.

    The x_0 estimate is computed at t_1 and then recombined with eps using
    the cumulative alpha looked up at t_2.
    """
    x0_estimate = self.eps_to_pred_xstart(x, eps, t_1)
    alpha_bar_target = _extract_into_tensor_lerp(self.alphas_cumprod, t_2, x.shape)
    signal_term = x0_estimate * th.sqrt(alpha_bar_target)
    noise_term = th.sqrt(1 - alpha_bar_target) * eps
    return signal_term + noise_term
def prk_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Run the whole PRK sampling chain and return only the last sample.

    Every argument is forwarded unchanged to
    prk_sample_loop_progressive(); the "sample" entry of the final dict
    that generator yields is returned.
    Same usage as p_sample_loop().
    """
    steps = self.prk_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
    )
    final = None
    # Drain the generator; `final` ends up bound to the last yielded dict.
    for final in steps:
        pass
    return final["sample"]
def prk_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Generator variant of PRK sampling: yields the dict returned by
    prk_sample() for each visited timestep.

    The visited timesteps run from num_timesteps - 2 down to 1; the first
    and last entries of the reversed schedule are skipped.
    Same usage as p_sample_loop_progressive().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    img = noise if noise is not None else th.randn(*shape, device=device)
    # Reversed schedule without its first and last entry.
    timesteps = list(range(self.num_timesteps - 2, 0, -1))
    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        timesteps = tqdm(timesteps, leave=False)
    batch = shape[0]
    for step in timesteps:
        t = th.tensor([step] * batch, device=device)
        with th.no_grad():
            out = self.prk_sample(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
            )
            yield out
            # The sample just produced is the input of the next step.
            img = out["sample"]
def prk_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Take one fourth-order Pseudo Runge-Kutta step from t to t - 1
    (https://openreview.net/forum?id=PlKWVd2yBkY).

    Four noise estimates are taken (at t, twice at t - 0.5, and at t - 1)
    and combined with weights 1, 2, 2, 1 over 6.

    :return: a dict with "sample" (x moved to t - 1), "pred_xstart" (the
             x_0 estimate after denoised_fn / clipping) and "eps" (the
             combined noise estimate).
    Same usage as p_sample().
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs

    def eps_at(sample, timestep):
        # Noise estimate of the model at an arbitrary (sample, timestep).
        return self.get_eps(model, sample, timestep, model_kwargs, cond_fn)

    t_mid = t - 0.5
    t_prev = t - 1

    eps_1 = eps_at(x, t)
    eps_2 = eps_at(self.pndm_transfer(x, eps_1, t, t_mid), t_mid)
    eps_3 = eps_at(self.pndm_transfer(x, eps_2, t, t_mid), t_mid)
    eps_4 = eps_at(self.pndm_transfer(x, eps_3, t, t_prev), t_prev)
    eps_prime = (eps_1 + 2 * eps_2 + 2 * eps_3 + eps_4) / 6

    sample = self.pndm_transfer(x, eps_prime, t, t_prev)
    pred_xstart = self.eps_to_pred_xstart(x, eps_prime, t)
    # Post-process the x_0 estimate: optional user hook, then optional clip.
    if denoised_fn is not None:
        pred_xstart = denoised_fn(pred_xstart)
    if clip_denoised:
        pred_xstart = pred_xstart.clamp(-1, 1)
    return {"sample": sample, "pred_xstart": pred_xstart, "eps": eps_prime}
def plms_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    progress_leave=True,
):
    """
    Run the whole PLMS sampling chain and return only the last sample.

    Requires self.model_mean_type to be ModelMeanType.EPSILON. Every
    argument is forwarded unchanged to plms_sample_loop_progressive();
    the "sample" entry of the final dict that generator yields is
    returned.
    Same usage as p_sample_loop().
    """
    assert (
        self.model_mean_type == ModelMeanType.EPSILON
    ), "plms_sample only support model_mean_type == ModelMeanType.EPSILON"
    steps = self.plms_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
        progress_leave=progress_leave,
    )
    final = None
    # Drain the generator; `final` ends up bound to the last yielded dict.
    for final in steps:
        pass
    return final["sample"]
def plms_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    progress_leave=True,
):
    """
    Generator variant of PLMS sampling: yields the per-step output dict
    for each visited timestep.

    The visited timesteps run from num_timesteps - 2 down to 1. While
    fewer than three past noise estimates are available the step is taken
    with prk_sample(); afterwards plms_sample() is used with a sliding
    window of the three most recent estimates.
    Same usage as p_sample_loop_progressive().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    img = noise if noise is not None else th.randn(*shape, device=device)
    # Reversed schedule without its first and last entry.
    timesteps = list(range(self.num_timesteps - 2, 0, -1))
    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        timesteps = tqdm(timesteps, leave=progress_leave)
    history = []
    batch = shape[0]
    for step in timesteps:
        t = th.tensor([step] * batch, device=device)
        with th.no_grad():
            shared = dict(
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
            )
            if len(history) >= 3:
                out = self.plms_sample(model, img, history, t, **shared)
                # Drop the oldest estimate so the window stays at three.
                history.pop(0)
            else:
                # Warm-up: not enough history for the multistep formula.
                out = self.prk_sample(model, img, t, **shared)
            history.append(out["eps"])
            yield out
            img = out["sample"]
def plms_sample(
    self,
    model,
    x,
    old_eps,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Take one fourth-order Pseudo Linear Multistep step from t to t - 1
    (https://openreview.net/forum?id=PlKWVd2yBkY).

    :param old_eps: sequence whose last three entries are the previous
                    noise estimates, most recent last.
    :return: a dict with "sample" (x moved to t - 1 using the combined
             estimate), "pred_xstart" (x_0 estimate from the current
             model estimate, after denoised_fn / clipping) and "eps" (the
             current, uncombined model estimate).
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs

    eps = self.get_eps(model, x, t, model_kwargs, cond_fn)
    newest, middle, oldest = old_eps[-1], old_eps[-2], old_eps[-3]
    # Multistep combination with coefficients 55, -59, 37, -9 over 24.
    eps_prime = (55 * eps - 59 * newest + 37 * middle - 9 * oldest) / 24

    sample = self.pndm_transfer(x, eps_prime, t, t - 1)
    pred_xstart = self.eps_to_pred_xstart(x, eps, t)
    # Post-process the x_0 estimate: optional user hook, then optional clip.
    if denoised_fn is not None:
        pred_xstart = denoised_fn(pred_xstart)
    if clip_denoised:
        pred_xstart = pred_xstart.clamp(-1, 1)
    return {"sample": sample, "pred_xstart": pred_xstart, "eps": eps}
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
def _extract_into_tensor_lerp(arr, timesteps, broadcast_shape):
    """
    Look up arr at possibly fractional timesteps.

    The values at floor(t) and ceil(t) are extracted and linearly
    interpolated with the fractional part of t as the weight.
    """
    t_float = timesteps.float()
    lower = _extract_into_tensor(arr, t_float.floor().long(), broadcast_shape)
    upper = _extract_into_tensor(arr, t_float.ceil().long(), broadcast_shape)
    weight = t_float.frac()
    # Append singleton dims so the weight broadcasts against the values.
    missing_dims = len(broadcast_shape) - weight.dim()
    weight = weight[(...,) + (None,) * missing_dims]
    return th.lerp(lower, upper, weight)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import (
FromSingleFileMixin,
LoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
deprecate,
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import BertModel, BertTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ..modules.models import HunYuanDiT
# Module-level logger obtained through the diffusers logging utilities.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Usage example passed to `replace_example_docstring` on the pipeline's `__call__`.
# NOTE(review): the snippet shows the stock diffusers `StableDiffusionPipeline`
# called as `pipe(prompt)`, while `__call__` in this file takes `height` and
# `width` as required arguments -- confirm whether the example should be updated.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt).images[0]
```
"""
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The guided prediction is scaled so that its per-sample standard deviation matches that of `noise_pred_text`,
    then blended with the unscaled prediction using `guidance_rescale` as the weight of the rescaled part.
    """

    def _per_sample_std(tensor):
        # Standard deviation over every dimension except the first.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    std_ratio = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    # rescale the results from guidance (fixes overexposure)
    rescaled = noise_cfg * std_ratio
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
class StableDiffusionPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]):
A `HunYuanDiT` or `UNet2DConditionModel` to denoise the encoded image latents.
Notice: Here we still keep the word `unet` for compatibility with the previous version of the pipeline.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: Union[BertModel, CLIPTextModel],
    tokenizer: Union[BertTokenizer, CLIPTokenizer],
    unet: Union[HunYuanDiT, UNet2DConditionModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
    progress_bar_config: Optional[Dict[str, Any]] = None,
    embedder_t5=None,
    infer_mode="torch",
):
    r"""
    Build the pipeline from its components.

    Args:
        vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor:
            Components registered on the pipeline through `register_modules`.
        requires_safety_checker (`bool`):
            Stored in the pipeline config. When `True` and `safety_checker` is `None`, a warning is logged.
        progress_bar_config (`dict`, *optional*):
            Entries merged into `self._progress_bar_config`.
        embedder_t5:
            Stored as `self.embedder_t5`; not registered as a module.
        infer_mode:
            Stored as `self.infer_mode`.

    Raises:
        ValueError: if a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    # ========================================================
    self.embedder_t5 = embedder_t5
    self.infer_mode = infer_mode

    # ========================================================
    if progress_bar_config is None:
        progress_bar_config = {}
    if not hasattr(self, "_progress_bar_config"):
        self._progress_bar_config = {}
    self._progress_bar_config.update(progress_bar_config)
    # ========================================================

    # Outdated scheduler configs are patched in place (with a deprecation
    # notice) instead of being rejected.
    if (
        hasattr(scheduler.config, "steps_offset")
        and scheduler.config.steps_offset != 1
    ):
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate(
            "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
        )
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if (
        hasattr(scheduler.config, "clip_sample")
        and scheduler.config.clip_sample is True
    ):
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate(
            "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
        )
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message needs the `f` prefix, otherwise `{self.__class__}`
        # is printed literally instead of the class.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_vae_slicing(self):
    r"""
    Turn on sliced decoding on the pipeline's VAE by calling its `enable_slicing()` method.

    With slicing enabled, the VAE splits the input tensor in slices and decodes them in several steps, which saves
    memory and allows larger batch sizes.
    """
    vae = self.vae
    vae.enable_slicing()
def disable_vae_slicing(self):
    r"""
    Turn off sliced decoding on the pipeline's VAE by calling its `disable_slicing()` method.

    If `enable_vae_slicing` was called before, decoding goes back to a single step.
    """
    vae = self.vae
    vae.disable_slicing()
def enable_vae_tiling(self):
    r"""
    Turn on tiled processing on the pipeline's VAE by calling its `enable_tiling()` method.

    With tiling enabled, the VAE splits the input tensor into tiles and encodes/decodes them in several steps, which
    saves a large amount of memory and allows processing larger images.
    """
    vae = self.vae
    vae.enable_tiling()
def disable_vae_tiling(self):
    r"""
    Turn off tiled processing on the pipeline's VAE by calling its `disable_tiling()` method.

    If `enable_vae_tiling` was called before, decoding goes back to a single step.
    """
    vae = self.vae
    vae.disable_tiling()
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Deprecated wrapper around `encode_prompt()` that returns a single tensor.

    Emits a deprecation notice, calls `encode_prompt()` with the same arguments and returns the negative and
    positive embeddings concatenated along the first dimension (negative first). The attention masks returned by
    `encode_prompt()` are discarded.

    When `encode_prompt()` returns no negative embeddings (classifier-free guidance disabled and no
    `negative_prompt_embeds` supplied) only the positive embeddings are returned, instead of failing inside
    `torch.cat` on a `None` entry.
    """
    deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
    deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

    prompt_embeds_tuple = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
    )
    positive_embeds = prompt_embeds_tuple[0]
    negative_embeds = prompt_embeds_tuple[1]

    if negative_embeds is None:
        # Nothing to concatenate with: torch.cat would raise on None.
        return positive_embeds

    # concatenate for backwards comp
    return torch.cat([negative_embeds, positive_embeds])
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    embedder=None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        embedder:
            T5 embedder (including text encoder and tokenizer). When given, its `model`, `tokenizer` and
            `max_length` are used instead of the pipeline's own text encoder and tokenizer.

    Returns:
        A 4-tuple `(prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask)`.
        Embeddings and masks are duplicated `num_images_per_prompt` times per prompt, in the same row order.
        `attention_mask` is `None` when `prompt_embeds` was passed in; `uncond_attention_mask` is `None` when the
        negative embeddings were not computed here.
    """
    if embedder is None:
        text_encoder = self.text_encoder
        tokenizer = self.tokenizer
        max_length = self.tokenizer.model_max_length
    else:
        text_encoder = embedder.model
        tokenizer = embedder.tokenizer
        max_length = embedder.max_length

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale
        # dynamically adjust the LoRA scale
        adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, tokenizer)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[
            -1
        ] and not torch.equal(text_input_ids, untruncated_ids):
            # Report against the length actually used for truncation above
            # (`max_length`), which differs from `tokenizer.model_max_length`
            # when an embedder with its own `max_length` is supplied.
            removed_text = tokenizer.batch_decode(
                untruncated_ids[:, max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )

        attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]
        # Duplicate each mask row consecutively so that row i of the mask
        # belongs to row i of the embeddings, which are laid out below as
        # (p0, p0, ..., p1, p1, ...). A plain `repeat(n, 1)` would tile the
        # batch (m0, m1, m0, m1, ...) and misalign masks and embeddings
        # whenever batch size and num_images_per_prompt are both > 1.
        attention_mask = attention_mask.repeat_interleave(
            num_images_per_prompt, dim=0
        )
    else:
        attention_mask = None

    if text_encoder is not None:
        prompt_embeds_dtype = text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(
        bs_embed * num_images_per_prompt, seq_len, -1
    )

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)

        # Pad the negative prompt to the sequence length of the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_attention_mask = uncond_input.attention_mask.to(device)
        negative_prompt_embeds = text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=uncond_attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]
        # Same per-prompt row order as the negative embeddings below.
        uncond_attention_mask = uncond_attention_mask.repeat_interleave(
            num_images_per_prompt, dim=0
        )
    else:
        uncond_attention_mask = None

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(
            dtype=prompt_embeds_dtype, device=device
        )
        negative_prompt_embeds = negative_prompt_embeds.repeat(
            1, num_images_per_prompt, 1
        )
        negative_prompt_embeds = negative_prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

    return (
        prompt_embeds,
        negative_prompt_embeds,
        attention_mask,
        uncond_attention_mask,
    )
def run_safety_checker(self, image, device, dtype):
    r"""
    Pass `image` through the safety checker, if one is configured.

    Returns `(image, has_nsfw_concept)`. Without a safety checker the image is returned untouched and
    `has_nsfw_concept` is `None`. Otherwise the image is converted to PIL for the feature extractor, and the
    extracted pixel values (moved to `device`, cast to `dtype`) are handed to the safety checker together with the
    original image.
    """
    if self.safety_checker is None:
        return image, None

    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    checker_inputs = self.feature_extractor(pil_images, return_tensors="pt").to(
        device
    )
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=checker_inputs.pixel_values.to(dtype)
    )
    return image, has_nsfw_concept
def decode_latents(self, latents):
    r"""
    Deprecated: decode latents with the VAE into a numpy image batch.

    The latents are divided by the VAE scaling factor, decoded, mapped through `x / 2 + 0.5` and clamped to
    `[0, 1]`, then returned as a float32 numpy array with the channel dimension moved last.
    """
    deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
    deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return decoded.cpu().permute(0, 2, 3, 1).float().numpy()
def prepare_extra_step_kwargs(self, generator, eta):
    r"""
    Build the optional keyword arguments for `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only included when the
    scheduler's `step` declares a parameter of that name.
    eta (η) corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    r"""
    Validate the arguments of a generation call; raises `ValueError` on the first violation, returns `None` otherwise.

    Checks, in order: `height`/`width` divisible by 8; `callback_steps` a positive `int`; exactly one of `prompt`
    and `prompt_embeds` given, with `prompt` a `str` or `list`; not both `negative_prompt` and
    `negative_prompt_embeds`; equal shapes when both embedding tensors are passed.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(
            f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
        )

    # None is not an int, so it is rejected by the isinstance test as well.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is None:
        if prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
    elif prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif not isinstance(prompt, (str, list)):
        raise ValueError(
            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
        )

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    both_embeds_given = prompt_embeds is not None and negative_prompt_embeds is not None
    if both_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    r"""
    Return the initial latents for sampling, scaled by the scheduler's `init_noise_sigma`.

    When `latents` is `None`, random latents of shape
    `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` are drawn with
    `randn_tensor`; otherwise the given latents are moved to `device`.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    scale = self.vae_scale_factor
    shape = (batch_size, num_channels_latents, height // scale, width // scale)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(
            shape, generator=generator, device=device, dtype=dtype
        )

    # scale the initial noise by the standard deviation required by the scheduler
    return latents * self.scheduler.init_noise_sigma
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    height: int,
    width: int,
    prompt: Union[str, List[str]] = None,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    prompt_embeds_t5: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[
        Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]
    ] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    image_meta_size: Optional[torch.LongTensor] = None,
    style: Optional[torch.LongTensor] = None,
    progress: bool = True,
    use_fp16: bool = False,
    freqs_cis_img: Optional[tuple] = None,
    learn_sigma: bool = True,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        height (`int`):
            The height in pixels of the generated image.
        width (`int`):
            The width in pixels of the generated image.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps handed to `self.scheduler.set_timesteps`.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Classifier-free guidance weight. Guidance is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. It is only
            forwarded to schedulers whose `step` accepts it (see `prepare_extra_step_kwargs`).
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents. If not provided, a latents tensor is sampled using `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings for the primary text encoder.
        prompt_embeds_t5 (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings for the `self.embedder_t5` encoder.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings for the primary text encoder.
        negative_prompt_embeds_t5 (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings for the `self.embedder_t5` encoder.
        output_type (`str`, *optional*, defaults to `"pil"`):
            Output format passed to `self.image_processor.postprocess`. With `"latent"` the VAE decode and the
            safety checker are skipped.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            Called as `callback(step, timestep, latents, pred_x0)`; `pred_x0` is `None` when the scheduler output
            has no `pred_original_sample` attribute.
        callback_steps (`int`, *optional*, defaults to 1):
            The callback fires on progress-bar updates whose step index is a multiple of this value.
        cross_attention_kwargs (`dict`, *optional*):
            Only the `"scale"` entry is read here; it is forwarded to `encode_prompt` as `lora_scale`.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            When greater than 0 and guidance is enabled, the guided prediction is passed through
            `rescale_noise_cfg`.
        image_meta_size (`torch.LongTensor`, *optional*):
            Forwarded to the denoiser as `image_meta_size` (cast to half precision when `use_fp16` is set).
        style (`torch.LongTensor`, *optional*):
            Forwarded unchanged to the denoiser as `style`.
        progress (`bool`, *optional*, defaults to `True`):
            Not referenced by this method; kept in the signature for compatibility.
        use_fp16 (`bool`, *optional*, defaults to `False`):
            Cast the latent input, the expanded timestep, `prompt_embeds` and `image_meta_size` to half precision
            before each denoiser call.
        freqs_cis_img (`tuple`, *optional*):
            Two-element sequence whose items are passed to the denoiser (`cos_cis_img` / `sin_cis_img` in the
            `"fa"`/`"torch"` modes, `freqs_cis_img0` / `freqs_cis_img1` in `"trt"` mode). It is indexed on every
            step, so it must be provided in practice despite the `None` default.
        learn_sigma (`bool`, *optional*, defaults to `True`):
            When set, the denoiser output is split in two along dim 1 and only the first half is used.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` `(image, has_nsfw_concept)` is returned.

    Raises:
        ValueError: If the inputs fail `check_inputs` or `self.infer_mode` is not one of `"fa"`, `"torch"`,
            `"trt"`.
    """
    # NOTE(review): the bare `Examples:` line in the docstring is presumably the placeholder that
    # `replace_example_docstring` substitutes; keep it on its own line.

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt (once with the primary encoder, once with the T5 embedder)
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = (
        self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
    )
    (
        prompt_embeds_t5,
        negative_prompt_embeds_t5,
        attention_mask_t5,
        uncond_attention_mask_t5,
    ) = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds_t5,
        negative_prompt_embeds=negative_prompt_embeds_t5,
        lora_scale=text_encoder_lora_scale,
        embedder=self.embedder_t5,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        attention_mask = torch.cat([uncond_attention_mask, attention_mask])
        prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5])
        attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5])

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t
            )
            # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
            t_expand = torch.tensor(
                [t] * latent_model_input.shape[0], device=latent_model_input.device
            )

            if use_fp16:
                latent_model_input = latent_model_input.half()
                t_expand = t_expand.half()
                prompt_embeds = prompt_embeds.half()
                ims = (
                    image_meta_size.half() if image_meta_size is not None else None
                )
            else:
                ims = image_meta_size

            # predict the noise residual
            if self.infer_mode in ["fa", "torch"]:
                noise_pred = self.unet(
                    latent_model_input,
                    t_expand,
                    encoder_hidden_states=prompt_embeds,
                    text_embedding_mask=attention_mask,
                    encoder_hidden_states_t5=prompt_embeds_t5,
                    text_embedding_mask_t5=attention_mask_t5,
                    image_meta_size=ims,
                    style=style,
                    cos_cis_img=freqs_cis_img[0],
                    sin_cis_img=freqs_cis_img[1],
                    return_dict=False,
                )
            elif self.infer_mode == "trt":
                noise_pred = self.unet(
                    x=latent_model_input.contiguous(),
                    t_emb=t_expand.contiguous(),
                    context=prompt_embeds.contiguous(),
                    image_meta_size=ims,
                    style=style,
                    freqs_cis_img0=freqs_cis_img[0].to(device).contiguous(),
                    freqs_cis_img1=freqs_cis_img[1].to(device).contiguous(),
                    text_embedding_mask=attention_mask.contiguous(),
                    encoder_hidden_states_t5=prompt_embeds_t5.contiguous(),
                    text_embedding_mask_t5=attention_mask_t5.contiguous(),
                )
            else:
                # f-string so the offending mode actually appears in the message
                raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

            if learn_sigma:
                # keep the first half along dim 1; the second half is discarded
                noise_pred, _ = noise_pred.chunk(2, dim=1)

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1
            results = self.scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs, return_dict=True
            )
            latents = results.prev_sample
            pred_x0 = (
                results.pred_original_sample
                if hasattr(results, "pred_original_sample")
                else None
            )

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents, pred_x0)

    # 8. Decode the latents (unless raw latents were requested) and run the safety checker
    if output_type != "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)
    return StableDiffusionPipelineOutput(
        images=image, nsfw_content_detected=has_nsfw_concept
    )
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import (
FromSingleFileMixin,
LoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
deprecate,
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import BertModel, BertTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ..modules.models import HunYuanDiT
# Module-level logger, created through the `logging` helper imported from `diffusers.utils`.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Example text handed to `replace_example_docstring` on the pipeline's `__call__`.
# NOTE(review): the snippet shows the stock `StableDiffusionPipeline`, not the pipeline
# defined in this file - presumably carried over from upstream; confirm whether it should be updated.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt).images[0]
```
"""
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy rescaled to the per-sample std of `noise_pred_text`.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is the blend weight:
    0 returns the guided prediction unchanged, 1 returns the fully rescaled one.
    """

    def _per_sample_std(tensor):
        # Standard deviation over every dimension except the leading (batch) one.
        reduce_dims = list(range(1, tensor.ndim))
        return tensor.std(dim=reduce_dims, keepdim=True)

    text_std = _per_sample_std(noise_pred_text)
    cfg_std = _per_sample_std(noise_cfg)

    # Rescaling the guided result counters overexposure.
    rescaled = noise_cfg * (text_std / cfg_std)

    # Mixing the original back in avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
class StableDiffusionControlNetPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]):
A `HunYuanDiT` or `UNet2DConditionModel` to denoise the encoded image latents.
Notice: Here we still keep the word `unet` for compatibility with the previous version of the pipeline.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: Union[BertModel, CLIPTextModel],
    tokenizer: Union[BertTokenizer, CLIPTokenizer],
    unet: Union[HunYuanDiT, UNet2DConditionModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
    progress_bar_config: Dict[str, Any] = None,
    embedder_t5=None,
    infer_mode="torch",
    controlnet=None,
):
    """
    Store the pipeline components and patch outdated scheduler configs.

    Args:
        vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, controlnet:
            Components registered on the pipeline through `register_modules`.
        requires_safety_checker (`bool`):
            Saved to the pipeline config; when `True` and `safety_checker` is `None` a warning is logged.
        progress_bar_config (`dict`, *optional*):
            Entries merged into `self._progress_bar_config`.
        embedder_t5:
            Stored as `self.embedder_t5`.
        infer_mode (`str`):
            Stored as `self.infer_mode`.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    # ========================================================
    self.embedder_t5 = embedder_t5
    self.infer_mode = infer_mode

    # ========================================================
    if progress_bar_config is None:
        progress_bar_config = {}
    if not hasattr(self, "_progress_bar_config"):
        self._progress_bar_config = {}
    self._progress_bar_config.update(progress_bar_config)

    # ========================================================
    # Legacy scheduler configs: warn, then force `steps_offset` to 1.
    if (
        hasattr(scheduler.config, "steps_offset")
        and scheduler.config.steps_offset != 1
    ):
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate(
            "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
        )
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    # Legacy scheduler configs: warn, then force `clip_sample` to False.
    if (
        hasattr(scheduler.config, "clip_sample")
        and scheduler.config.clip_sample is True
    ):
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate(
            "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
        )
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # f-string so `{self.__class__}` is interpolated instead of printed literally
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        controlnet=controlnet,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_vae_slicing(self):
    r"""
    Switch the VAE to sliced decoding by calling `self.vae.enable_slicing()`.

    With slicing on, the VAE splits the input tensor into slices and decodes them in several steps, which saves
    memory and allows larger batch sizes.
    """
    autoencoder = self.vae
    autoencoder.enable_slicing()
def disable_vae_slicing(self):
    r"""
    Switch sliced VAE decoding off by calling `self.vae.disable_slicing()`.

    After a previous `enable_vae_slicing`, decoding goes back to a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_slicing()
def enable_vae_tiling(self):
    r"""
    Switch the VAE to tiled processing by calling `self.vae.enable_tiling()`.

    With tiling on, the VAE splits the input tensor into tiles and decodes/encodes them in several steps, which
    saves a large amount of memory and allows larger images to be processed.
    """
    autoencoder = self.vae
    autoencoder.enable_tiling()
def disable_vae_tiling(self):
    r"""
    Switch tiled VAE decoding off by calling `self.vae.disable_tiling()`.

    After a previous `enable_vae_tiling`, decoding goes back to a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_tiling()
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    """
    Deprecated wrapper around `encode_prompt`.

    Emits a deprecation notice, delegates to `encode_prompt`, and returns the
    legacy layout: negative embeddings concatenated in front of the positive
    embeddings in a single tensor.
    """
    deprecate(
        "_encode_prompt()",
        "1.0.0",
        "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.",
        standard_warn=False,
    )

    encoded = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
    )

    # Backwards compatibility: entry 1 (negative) goes first, entry 0 (positive) second.
    positive = encoded[0]
    negative = encoded[1]
    return torch.cat([negative, positive])
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    embedder=None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is set.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they will be generated from
            `negative_prompt` (or from empty strings when that is `None`).
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        embedder:
            Optional alternative embedder. When given, its `model`, `tokenizer` and `max_length` attributes
            are used instead of `self.text_encoder` / `self.tokenizer`.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask)`. Embeddings
        are duplicated `num_images_per_prompt` times per prompt, and masks follow the same row order. A mask is
        `None` when the matching embeddings were passed in (the unconditional mask is also `None` when guidance
        is disabled).

    Raises:
        TypeError: If `prompt` and `negative_prompt` have different types.
        ValueError: If a `negative_prompt` list does not match the prompt batch size.
    """
    if embedder is None:
        text_encoder = self.text_encoder
        tokenizer = self.tokenizer
        max_length = self.tokenizer.model_max_length
    else:
        text_encoder = embedder.model
        tokenizer = embedder.tokenizer
        max_length = embedder.max_length

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale
        # dynamically adjust the LoRA scale
        adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, tokenizer)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[
            -1
        ] and not torch.equal(text_input_ids, untruncated_ids):
            # Report against the limit that was actually applied (`max_length`),
            # which differs from `tokenizer.model_max_length` on the embedder path.
            removed_text = tokenizer.batch_decode(
                untruncated_ids[:, max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )

        attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]
        # Duplicate each mask row consecutively so the order matches the
        # per-prompt duplication of `prompt_embeds` below (p0, p0, p1, p1, ...).
        attention_mask = attention_mask.repeat_interleave(
            num_images_per_prompt, dim=0
        )
    else:
        attention_mask = None

    if text_encoder is not None:
        prompt_embeds_dtype = text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(
        bs_embed * num_images_per_prompt, seq_len, -1
    )

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)

        # pad the unconditional input to the sequence length of the prompt embeddings
        uncond_max_length = prompt_embeds.shape[1]
        uncond_input = tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=uncond_max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_attention_mask = uncond_input.attention_mask.to(device)
        negative_prompt_embeds = text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=uncond_attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]
        # Same row order as the duplicated negative embeddings below.
        uncond_attention_mask = uncond_attention_mask.repeat_interleave(
            num_images_per_prompt, dim=0
        )
    else:
        uncond_attention_mask = None

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(
            dtype=prompt_embeds_dtype, device=device
        )
        negative_prompt_embeds = negative_prompt_embeds.repeat(
            1, num_images_per_prompt, 1
        )
        negative_prompt_embeds = negative_prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

    return (
        prompt_embeds,
        negative_prompt_embeds,
        attention_mask,
        uncond_attention_mask,
    )
def run_safety_checker(self, image, device, dtype):
    """
    Run `self.safety_checker` on `image` when one is configured.

    Returns `(image, has_nsfw_concept)`. Without a safety checker the image is
    returned untouched and `has_nsfw_concept` is `None`.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images; the conversion depends on the input type.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(
        device
    )
    checked_image, nsfw_flags = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return checked_image, nsfw_flags
def decode_latents(self, latents):
    """
    Deprecated: decode `latents` with the VAE into a float32 NumPy array.

    The decoded image is mapped with `x / 2 + 0.5`, clamped to `[0, 1]`, moved
    to the CPU and permuted with `(0, 2, 3, 1)`.
    """
    deprecate(
        "decode_latents",
        "1.0.0",
        "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead",
        standard_warn=False,
    )

    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    normalized = (decoded / 2 + 0.5).clamp(0, 1)

    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return normalized.cpu().permute(0, 2, 3, 1).float().numpy()
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share one signature, so `eta` and `generator` are only
    included when `step` declares a parameter of that name. `eta` corresponds
    to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be
    between [0, 1].
    """
    step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = (("eta", eta), ("generator", generator))
    return {
        name: value for name, value in candidates if name in step_parameters
    }
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    Raises:
        ValueError: If `height`/`width` are not multiples of 8, `callback_steps`
            is not a positive integer, the prompt and its embeddings are both
            given or both missing, `prompt` is neither `str` nor `list`, a
            negative prompt and its embeddings are both given, or the two
            embedding tensors differ in shape.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
        )

    # `None` is not an int, so the isinstance test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(
            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
        )

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if (
        has_embeds
        and has_negative_embeds
        and prompt_embeds.shape != negative_prompt_embeds.shape
    ):
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Return the initial latents for the denoising loop.

    When ``latents`` is ``None`` fresh noise is drawn with ``randn_tensor``
    using shape ``(batch_size, num_channels_latents,
    height // vae_scale_factor, width // vae_scale_factor)``; otherwise the
    given tensor is moved to ``device`` (its dtype is left as is). In both
    cases the result is multiplied by ``self.scheduler.init_noise_sigma``.

    Raises:
        ValueError: if ``generator`` is a list whose length differs from
            ``batch_size``.
    """
    downscale = self.vae_scale_factor
    latent_shape = (
        batch_size,
        num_channels_latents,
        height // downscale,
        width // downscale,
    )

    if isinstance(generator, list):
        if len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(
            latent_shape, generator=generator, device=device, dtype=dtype
        )

    # Scale the starting noise by the standard deviation the scheduler expects.
    return initial * self.scheduler.init_noise_sigma
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    height: int,
    width: int,
    prompt: Union[str, List[str]] = None,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    prompt_embeds_t5: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[
        Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]
    ] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    image_meta_size: Optional[torch.LongTensor] = None,
    style: Optional[torch.LongTensor] = None,
    progress: bool = True,
    use_fp16: bool = False,
    freqs_cis_img: Optional[tuple] = None,
    learn_sigma: bool = True,
    image=None,
    control_weight=1.0,
):
    r"""
    The call function to the pipeline for ControlNet-conditioned generation.

    Args:
        height (`int`):
            The height in pixels of the generated image.
        width (`int`):
            The width in pixels of the generated image.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps handed to the scheduler.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Classifier-free guidance weight. Guidance is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. Ignored when not using
            guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Forwarded to `scheduler.step` only when that method accepts an `eta` argument.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Generator(s) used for sampling the initial latents and the VAE posterior of `image`.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents. If not provided, they are sampled with `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings for the main text encoder.
        prompt_embeds_t5 (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings for `self.embedder_t5`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings for the main text encoder.
        negative_prompt_embeds_t5 (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings for `self.embedder_t5`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            Output format passed to the image post-processor. With `"latent"` the VAE decode and the safety
            checker are skipped and the latents are post-processed directly.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            Called as `callback(step, timestep, latents, pred_x0)` on progress-bar updates whose step index
            is a multiple of `callback_steps`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called.
        cross_attention_kwargs (`dict`, *optional*):
            Only its `"scale"` entry is read here; it is forwarded to `encode_prompt` as `lora_scale`.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            When positive (and guidance is enabled) the guided prediction is passed through
            `rescale_noise_cfg`.
        image_meta_size (`torch.LongTensor`, *optional*):
            Forwarded to the controlnet and the denoiser (cast to half precision when `use_fp16` is set).
        style (`torch.LongTensor`, *optional*):
            Forwarded unchanged to the controlnet and the denoiser.
        progress (`bool`, *optional*, defaults to `True`):
            Currently not read by this method.
        use_fp16 (`bool`, *optional*, defaults to `False`):
            Cast the latent input, timesteps, `prompt_embeds` and `image_meta_size` to half precision before
            each model call.
        freqs_cis_img (`tuple`, *optional*):
            Two-element sequence; element 0 is passed as the cosine table and element 1 as the sine table.
            Must be provided in practice because it is indexed unconditionally.
        learn_sigma (`bool`, *optional*, defaults to `True`):
            When set, the model output is split in two along dim 1 and only the first half is used as the
            noise prediction.
        image (tensor):
            Control image. It is cast to float, encoded by the VAE, scaled by the VAE scaling factor and fed
            to the controlnet as its condition. Required despite the `None` default.
        control_weight (`float` or `list`, *optional*, defaults to 1.0):
            Multiplier applied to every controlnet output, or one multiplier per output when a list.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is
            returned, otherwise the tuple `(images, has_nsfw_concept)`.

    Raises:
        ValueError: on invalid inputs (see `check_inputs`), a missing `image`, a `control_weight` list whose
            length differs from the number of controlnet outputs, or an unknown `self.infer_mode`.
    """
    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )
    # The control image is dereferenced unconditionally below; fail early with
    # a clear message instead of an AttributeError on `None`.
    if image is None:
        raise ValueError(
            "`image` (the control condition) has to be provided but is None."
        )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt (once per text encoder)
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = (
        self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
    )
    (
        prompt_embeds_t5,
        negative_prompt_embeds_t5,
        attention_mask_t5,
        uncond_attention_mask_t5,
    ) = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds_t5,
        negative_prompt_embeds=negative_prompt_embeds_t5,
        lora_scale=text_encoder_lora_scale,
        embedder=self.embedder_t5,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        attention_mask = torch.cat([uncond_attention_mask, attention_mask])
        prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5])
        attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5])

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Encode the control image into the latent space.
    # NOTE(review): the condition is always cast to fp16, independently of `use_fp16`.
    condition = (
        self.vae.encode(image.float())
        .latent_dist.sample(generator)
        .mul_(self.vae.config.scaling_factor)
        .half()
    )
    condition = (
        torch.cat([condition] * 2) if do_classifier_free_guidance else condition
    )

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t
            )
            # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
            t_expand = torch.tensor(
                [t] * latent_model_input.shape[0], device=latent_model_input.device
            )

            if use_fp16:
                latent_model_input = latent_model_input.half()
                t_expand = t_expand.half()
                prompt_embeds = prompt_embeds.half()
                ims = (
                    image_meta_size.half() if image_meta_size is not None else None
                )
            else:
                ims = image_meta_size

            # predict the noise residual
            if self.infer_mode in ["fa", "torch"]:
                controls = self.controlnet(
                    latent_model_input,
                    t_expand,
                    condition,
                    encoder_hidden_states=prompt_embeds,
                    text_embedding_mask=attention_mask,
                    encoder_hidden_states_t5=prompt_embeds_t5,
                    text_embedding_mask_t5=attention_mask_t5,
                    image_meta_size=ims,
                    style=style,
                    cos_cis_img=freqs_cis_img[0],
                    sin_cis_img=freqs_cis_img[1],
                    return_dict=False,
                )
                if isinstance(control_weight, list):
                    # Explicit check instead of `assert`, which is stripped under -O.
                    if len(control_weight) != len(controls):
                        raise ValueError(
                            f"`control_weight` has {len(control_weight)} entries but the controlnet"
                            f" returned {len(controls)} outputs."
                        )
                    controls = [
                        control * weight
                        for control, weight in zip(controls, control_weight)
                    ]
                else:
                    controls = [control * control_weight for control in controls]
                noise_pred = self.unet(
                    latent_model_input,
                    t_expand,
                    encoder_hidden_states=prompt_embeds,
                    text_embedding_mask=attention_mask,
                    encoder_hidden_states_t5=prompt_embeds_t5,
                    text_embedding_mask_t5=attention_mask_t5,
                    image_meta_size=ims,
                    style=style,
                    cos_cis_img=freqs_cis_img[0],
                    sin_cis_img=freqs_cis_img[1],
                    return_dict=False,
                    controls=controls,
                )
            elif self.infer_mode == "trt":
                # NOTE(review): this branch does not use the controlnet and
                # requires `image_meta_size` and `style` to be non-None.
                noise_pred = self.unet(
                    x=latent_model_input.contiguous(),
                    t_emb=t_expand.contiguous(),
                    context=prompt_embeds.contiguous(),
                    image_meta_size=ims.contiguous(),
                    style=style.contiguous(),
                    freqs_cis_img0=freqs_cis_img[0].to(device).contiguous(),
                    freqs_cis_img1=freqs_cis_img[1].to(device).contiguous(),
                    text_embedding_mask=attention_mask.contiguous(),
                    encoder_hidden_states_t5=prompt_embeds_t5.contiguous(),
                    text_embedding_mask_t5=attention_mask_t5.contiguous(),
                )
            else:
                raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

            if learn_sigma:
                # keep only the first half of dim 1 as the noise prediction
                noise_pred, _ = noise_pred.chunk(2, dim=1)

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1
            results = self.scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs, return_dict=True
            )
            latents = results.prev_sample
            pred_x0 = (
                results.pred_original_sample
                if hasattr(results, "pred_original_sample")
                else None
            )

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents, pred_x0)

    # 9. Decode latents and run the safety checker
    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)
    return StableDiffusionPipelineOutput(
        images=image, nsfw_content_detected=has_nsfw_concept
    )
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import (
FromSingleFileMixin,
LoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
deprecate,
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import BertModel, BertTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ..modules.models import HunYuanDiT
import torchvision.transforms as T
# Module-level logger obtained from the diffusers logging utilities.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Example snippet kept as a module constant; presumably substituted into the
# `Examples:` placeholder of a pipeline docstring via `replace_example_docstring`
# (imported above) - TODO confirm against the `__call__` of this module.
# NOTE(review): the snippet still demonstrates the stock `StableDiffusionPipeline`,
# not the `StableDiffusionIPAPipeline` defined in this module.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt).images[0]
```
"""
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend the guided prediction `noise_cfg` with a copy rescaled to the per-sample standard deviation of
    `noise_pred_text`, using `guidance_rescale` as the blend weight. Based on findings of [Common Diffusion
    Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    # Standard deviations are taken over every dimension except the first.
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    std_ratio = noise_pred_text.std(dim=text_dims, keepdim=True) / noise_cfg.std(
        dim=cfg_dims, keepdim=True
    )
    # rescaled guidance result (fixes overexposure)
    rescaled = noise_cfg * std_ratio
    # mixing with the unscaled guidance result avoids "plain looking" images
    keep_original = 1 - guidance_rescale
    return guidance_rescale * rescaled + keep_original * noise_cfg
class StableDiffusionIPAPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Text-to-image pipeline that, in addition to the usual text encoder, VAE and denoiser, carries an image
encoder (`img_encoder`) used by `encode_img` to embed a reference image.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
Frozen text encoder, either a `BertModel` or a `CLIPTextModel`.
img_encoder:
Image encoder called by `encode_img` on a 224x224, normalized image tensor.
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]):
A `HunYuanDiT` or `UNet2DConditionModel` to denoise the encoded image latents.
Notice: Here we still keep the word `unet` for compatibility with the previous version of the pipeline.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
requires_safety_checker (`bool`, defaults to `True`):
Stored in the pipeline config; a warning is logged when it is `True` but `safety_checker` is `None`.
progress_bar_config (`dict`, *optional*):
Entries merged into the pipeline's progress-bar configuration.
embedder_t5 (*optional*):
Second text embedder; `encode_prompt` reads its `model`, `tokenizer` and `max_length` attributes.
infer_mode (`str`, defaults to `"torch"`):
Stored as `self.infer_mode`; presumably selects the inference backend - TODO confirm against `__call__`.
"""
# NOTE(review): `img_encoder` is registered as a module but is not part of this offload sequence.
model_cpu_offload_seq = "text_encoder->unet->vae"
# Components that may be passed as `None`.
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: Union[BertModel, CLIPTextModel],
    img_encoder,
    tokenizer: Union[BertTokenizer, CLIPTokenizer],
    unet: Union[HunYuanDiT, UNet2DConditionModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
    progress_bar_config: Optional[Dict[str, Any]] = None,
    embedder_t5=None,
    infer_mode="torch",
):
    """Assemble the pipeline from its components.

    Besides registering the modules this constructor:

    * stores ``embedder_t5`` and ``infer_mode`` as plain attributes;
    * merges ``progress_bar_config`` into ``self._progress_bar_config``;
    * patches outdated scheduler configs (``steps_offset != 1`` and
      ``clip_sample is True``) after emitting a deprecation notice;
    * derives ``vae_scale_factor`` from the VAE config and builds the
      image post-processor;
    * builds ``self.image_transform``, the preprocessing applied to images
      before they are passed to ``img_encoder``.

    Raises:
        ValueError: if a ``safety_checker`` is given without a
            ``feature_extractor``.
    """
    super().__init__()

    # ========================================================
    self.embedder_t5 = embedder_t5
    self.infer_mode = infer_mode

    # ========================================================
    if progress_bar_config is None:
        progress_bar_config = {}
    if not hasattr(self, "_progress_bar_config"):
        self._progress_bar_config = {}
    self._progress_bar_config.update(progress_bar_config)
    # ========================================================

    if (
        hasattr(scheduler.config, "steps_offset")
        and scheduler.config.steps_offset != 1
    ):
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate(
            "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
        )
        # Overwrite the frozen config in place with the corrected value.
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if (
        hasattr(scheduler.config, "clip_sample")
        and scheduler.config.clip_sample is True
    ):
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate(
            "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
        )
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment needs the `f` prefix so that the class name is
        # interpolated instead of printing the literal `{self.__class__}`.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        img_encoder=img_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Preprocessing for `encode_img`: resize to 224x224, force RGB, convert
    # to a tensor and normalize with fixed per-channel mean/std.
    self.image_transform = T.Compose(
        [
            T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
            lambda x: x.convert("RGB"),
            T.ToTensor(),
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
def enable_vae_slicing(self):
    r"""
    Switch the VAE to sliced decoding by delegating to `self.vae.enable_slicing()`. In this mode the VAE
    decodes the input tensor slice by slice over several steps, which saves memory and allows larger
    batch sizes.
    """
    autoencoder = self.vae
    autoencoder.enable_slicing()
def disable_vae_slicing(self):
    r"""
    Switch sliced VAE decoding off again by delegating to `self.vae.disable_slicing()`. After a previous
    `enable_vae_slicing` call, decoding goes back to a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_slicing()
def enable_vae_tiling(self):
    r"""
    Switch the VAE to tiled processing by delegating to `self.vae.enable_tiling()`. In this mode the VAE
    splits the input tensor into tiles and encodes/decodes them over several steps, which saves a large
    amount of memory and allows larger images to be processed.
    """
    autoencoder = self.vae
    autoencoder.enable_tiling()
def disable_vae_tiling(self):
    r"""
    Switch tiled VAE processing off again by delegating to `self.vae.disable_tiling()`. After a previous
    `enable_vae_tiling` call, decoding goes back to a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_tiling()
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    """Deprecated wrapper around `encode_prompt`.

    Emits a deprecation notice, forwards every argument to `encode_prompt`
    and returns the legacy format: one tensor with the negative embeddings
    concatenated in front of the positive ones.
    """
    deprecate(
        "_encode_prompt()",
        "1.0.0",
        "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
        " instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.",
        standard_warn=False,
    )

    encoded = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
    )

    # Legacy layout: element 1 (negative) first, then element 0 (positive).
    negative_part = encoded[1]
    positive_part = encoded[0]
    return torch.cat([negative_part, positive_part])
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    embedder=None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Only used when guidance is enabled and
            `negative_prompt_embeds` is not given; an empty string per prompt is used when it is `None`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they are generated from
            `negative_prompt` when guidance is enabled.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        embedder:
            Alternative embedder (e.g. T5) exposing `model`, `tokenizer` and `max_length`. When `None`,
            `self.text_encoder` and `self.tokenizer` are used.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask)`. Both
        embedding tensors are expanded so that the `num_images_per_prompt` copies of each prompt are
        adjacent along the batch dimension, and the masks follow the same order. A mask is `None` when the
        corresponding embeddings were supplied by the caller (or, for the negative one, when guidance is
        disabled). `negative_prompt_embeds` is returned as passed in when guidance is disabled.

    Raises:
        TypeError: if `negative_prompt` and `prompt` have different types.
        ValueError: if `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    if embedder is None:
        text_encoder = self.text_encoder
        tokenizer = self.tokenizer
        max_length = self.tokenizer.model_max_length
    else:
        text_encoder = embedder.model
        tokenizer = embedder.tokenizer
        max_length = embedder.max_length

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale
        # dynamically adjust the LoRA scale
        adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, tokenizer)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[
            -1
        ] and not torch.equal(text_input_ids, untruncated_ids):
            # Report against the length actually used for truncation above,
            # which differs from `tokenizer.model_max_length` for a custom embedder.
            removed_text = tokenizer.batch_decode(
                untruncated_ids[:, max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )

        attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]
        # Keep the copies of each prompt adjacent ([a, a, b, b]) so the mask
        # rows line up with the embedding rows produced below; a plain
        # `.repeat(n, 1)` would yield [a, b, a, b].
        attention_mask = attention_mask.repeat_interleave(
            num_images_per_prompt, dim=0
        )
    else:
        # No tokenization happened, so there is no mask to return.
        attention_mask = None

    if text_encoder is not None:
        prompt_embeds_dtype = text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(
        bs_embed * num_images_per_prompt, seq_len, -1
    )

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)

        # pad the negative prompt to the sequence length of the positive embeddings
        max_length = prompt_embeds.shape[1]
        uncond_input = tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_attention_mask = uncond_input.attention_mask.to(device)
        negative_prompt_embeds = text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=uncond_attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]
        # Same ordering as the embeddings below: copies of a prompt stay adjacent.
        uncond_attention_mask = uncond_attention_mask.repeat_interleave(
            num_images_per_prompt, dim=0
        )
    else:
        uncond_attention_mask = None

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(
            dtype=prompt_embeds_dtype, device=device
        )
        negative_prompt_embeds = negative_prompt_embeds.repeat(
            1, num_images_per_prompt, 1
        )
        negative_prompt_embeds = negative_prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

    return (
        prompt_embeds,
        negative_prompt_embeds,
        attention_mask,
        uncond_attention_mask,
    )
def encode_img(self, img, device, do_classifier_free_guidance):
    """Embed a single reference image with `self.img_encoder`.

    The image is run through `self.image_transform`, given a leading batch
    dimension of 1, moved to `device`, encoded, and the result is cast to
    `torch.float16`.

    Args:
        img: a single image accepted by `self.image_transform`
            (batches are not supported yet).
        device: device on which the encoder input is placed.
        do_classifier_free_guidance (`bool`): whether a negative
            (all-zero) embedding is needed.

    Returns:
        `tuple`: `(img_clip_embedding, negative_img_clip_embedding)`. The
        second element is a zero tensor shaped like the first when guidance
        is enabled, and `None` otherwise.
    """
    # TODO: support batch processing
    img_for_clip = self.image_transform(img).unsqueeze(0)
    img_clip_embedding = self.img_encoder(img_for_clip.to(device)).to(
        dtype=torch.float16
    )
    # Always bind the negative embedding; it used to be left undefined
    # (UnboundLocalError on return) when guidance was disabled.
    negative_img_clip_embedding = (
        torch.zeros_like(img_clip_embedding)
        if do_classifier_free_guidance
        else None
    )
    return img_clip_embedding, negative_img_clip_embedding
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker on decoded images.

    Returns `(image, has_nsfw_concept)`. Without a safety checker the image
    is returned untouched and the flag is `None`. Otherwise the image is
    converted to PIL for the feature extractor, and the checker receives
    the original `image` plus the extracted pixel values cast to `dtype`.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, whatever the input type.
    processor = self.image_processor
    if torch.is_tensor(image):
        pil_images = processor.postprocess(image, output_type="pil")
    else:
        pil_images = processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(
        device
    )
    checked_image, nsfw_flags = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return checked_image, nsfw_flags
def decode_latents(self, latents):
    """Decode latents into an image array (deprecated).

    Emits a deprecation notice through ``deprecate`` and then undoes the
    VAE scaling, decodes with ``self.vae`` and maps the result through
    ``x / 2 + 0.5`` clamped to ``[0, 1]``.

    Args:
        latents: Latent tensor to decode.

    Returns:
        The decoded image moved to the CPU, permuted with axis order
        ``(0, 2, 3, 1)``, cast to float32 and converted to a NumPy array.
    """
    notice = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
    deprecate("decode_latents", "1.0.0", notice, standard_warn=False)

    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # Always cast to float32: the overhead is negligible and it stays
    # compatible with bfloat16 inputs.
    on_cpu = decoded.cpu().permute(0, 2, 3, 1)
    return on_cpu.float().numpy()
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments ``self.scheduler.step`` accepts.

    Schedulers do not share one ``step`` signature, so each optional
    argument is only forwarded when the signature names it.

    Args:
        generator: Random generator forwarded as ``generator`` when the
            scheduler's ``step`` accepts it.
        eta: The eta (η) of the DDIM paper
            (https://arxiv.org/abs/2010.02502), expected in ``[0, 1]``;
            forwarded as ``eta`` when the scheduler's ``step`` accepts it.

    Returns:
        A dict holding the subset of ``eta`` and ``generator`` that the
        scheduler's ``step`` method takes.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    optional_args = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_args if name in step_params}
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the user-facing arguments of the pipeline call.

    Args:
        prompt: Text prompt (``str`` or ``list``) or ``None``.
        height: Requested image height; must be divisible by 8.
        width: Requested image width; must be divisible by 8.
        callback_steps: Callback frequency; must be a positive ``int``.
        negative_prompt: Optional negative prompt.
        prompt_embeds: Optional precomputed prompt embeddings; mutually
            exclusive with ``prompt``.
        negative_prompt_embeds: Optional precomputed negative embeddings;
            mutually exclusive with ``negative_prompt``.

    Raises:
        ValueError: If any of the constraints above is violated, if neither
            ``prompt`` nor ``prompt_embeds`` is given, or if both embedding
            tensors are given with different shapes.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
        )

    # ``None`` fails the isinstance test, so it is rejected here as well.
    valid_callback_steps = isinstance(callback_steps, int) and callback_steps > 0
    if not valid_callback_steps:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(
            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
        )

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if (
        has_prompt_embeds
        and has_negative_embeds
        and prompt_embeds.shape != negative_prompt_embeds.shape
    ):
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Create or adopt the initial latents and scale them for the scheduler.

    Args:
        batch_size: Effective batch size of the latents.
        num_channels_latents: Number of latent channels.
        height: Image height; divided by ``self.vae_scale_factor`` to get
            the latent height.
        width: Image width; divided by ``self.vae_scale_factor`` to get the
            latent width.
        dtype: Dtype used when fresh latents are sampled.
        device: Device the latents are created on or moved to.
        generator: A generator, or a list of generators whose length must
            equal ``batch_size``.
        latents: Optional pre-made latents; when given they are only moved
            to ``device`` instead of being sampled.

    Returns:
        The latents multiplied by ``self.scheduler.init_noise_sigma``.

    Raises:
        ValueError: If ``generator`` is a list whose length differs from
            ``batch_size``.
    """
    latent_shape = (
        batch_size,
        num_channels_latents,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )

    if isinstance(generator, list):
        if len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(
            latent_shape, generator=generator, device=device, dtype=dtype
        )

    # Scale the initial noise by the standard deviation the scheduler expects.
    return initial * self.scheduler.init_noise_sigma
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    height: int,
    width: int,
    prompt: Union[str, List[str]] = None,
    image=None,
    t_scale=1,
    i_scale=1,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    prompt_embeds_t5: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[
        Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]
    ] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    image_meta_size: Optional[torch.LongTensor] = None,
    style: Optional[torch.LongTensor] = None,
    progress: bool = True,
    use_fp16: bool = False,
    freqs_cis_img: Optional[tuple] = None,
    learn_sigma: bool = True,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        height (`int`):
            The height in pixels of the generated image. Must be divisible by 8.
        width (`int`):
            The width in pixels of the generated image. Must be divisible by 8.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image:
            Reference image handed to `self.encode_img`; the resulting embedding is passed to the denoiser as
            `img_clip_embedding` in the `"fa"`/`"torch"` inference modes. Although it defaults to `None`, it is
            always forwarded to `self.encode_img`, so a real image is presumably required (confirm against
            `self.image_transform`).
        t_scale, i_scale:
            Forwarded unchanged to the denoiser as `t_scale` / `i_scale` in the `"fa"`/`"torch"` inference
            modes. Presumably weights for the text and image conditioning; confirm against the model.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only
            forwarded to schedulers whose `step` method accepts it.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        prompt_embeds_t5 (`torch.FloatTensor`, *optional*):
            Same as `prompt_embeds`, for the second `encode_prompt` call that uses `self.embedder_t5`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        negative_prompt_embeds_t5 (`torch.FloatTensor`, *optional*):
            Same as `negative_prompt_embeds`, for the second `encode_prompt` call that uses `self.embedder_t5`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. With `"latent"` the VAE decode and the safety checker are
            skipped and the latents are handed to the image processor as they are.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor,
            pred_x0: torch.FloatTensor)`. `pred_x0` is `None` when the scheduler output has no
            `pred_original_sample`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. Must be a positive integer.
        cross_attention_kwargs (`dict`, *optional*):
            Only its `"scale"` entry is read here; it is forwarded to `encode_prompt` as `lora_scale`.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            When guidance is active and this is greater than 0, the guided prediction is passed through
            `rescale_noise_cfg` (section 3.4 of https://arxiv.org/pdf/2305.08891.pdf).
        image_meta_size (`torch.LongTensor`, *optional*):
            Forwarded to the denoiser as `image_meta_size` (cast to half precision when `use_fp16` is set).
        style (`torch.LongTensor`, *optional*):
            Forwarded to the denoiser as `style`.
        progress (`bool`, *optional*, defaults to `True`):
            Not referenced inside this method.
        use_fp16 (`bool`, *optional*, defaults to `False`):
            When set, the latent model input, the expanded timestep tensor, `prompt_embeds` and
            `image_meta_size` are cast to half precision before every denoiser call.
        freqs_cis_img (`tuple`, *optional*):
            Two-element sequence; elements 0 and 1 are forwarded to the denoiser (as `cos_cis_img` /
            `sin_cis_img`, or `freqs_cis_img0` / `freqs_cis_img1` in `"trt"` mode). It is indexed
            unconditionally, so it must be provided despite the `None` default.
        learn_sigma (`bool`, *optional*, defaults to `True`):
            When set, the denoiser output is split in two along dim 1 and only the first half is used as the
            noise prediction.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.

    Raises:
        ValueError: If `check_inputs` rejects the arguments or `self.infer_mode` is not one of `"fa"`,
            `"torch"` or `"trt"`.
    """
    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    img_embeds, negative_img_embeds = self.encode_img(
        image, device, do_classifier_free_guidance
    )

    # 3. Encode input prompt (once with the default embedder, once with T5)
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None)
        if cross_attention_kwargs is not None
        else None
    )
    prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = (
        self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
    )
    (
        prompt_embeds_t5,
        negative_prompt_embeds_t5,
        attention_mask_t5,
        uncond_attention_mask_t5,
    ) = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds_t5,
        negative_prompt_embeds=negative_prompt_embeds_t5,
        lora_scale=text_encoder_lora_scale,
        embedder=self.embedder_t5,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        attention_mask = torch.cat([uncond_attention_mask, attention_mask])
        prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5])
        attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5])
        img_embeds = torch.cat([negative_img_embeds, img_embeds])

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t
            )
            # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
            t_expand = torch.tensor(
                [t] * latent_model_input.shape[0], device=latent_model_input.device
            )

            if use_fp16:
                latent_model_input = latent_model_input.half()
                t_expand = t_expand.half()
                prompt_embeds = prompt_embeds.half()
                ims = (
                    image_meta_size.half() if image_meta_size is not None else None
                )
            else:
                ims = image_meta_size if image_meta_size is not None else None

            # predict the noise residual
            if self.infer_mode in ["fa", "torch"]:
                noise_pred = self.unet(
                    latent_model_input,
                    t_expand,
                    t_scale=t_scale,
                    i_scale=i_scale,
                    encoder_hidden_states=prompt_embeds,
                    text_embedding_mask=attention_mask,
                    img_clip_embedding=img_embeds,
                    encoder_hidden_states_t5=prompt_embeds_t5,
                    text_embedding_mask_t5=attention_mask_t5,
                    image_meta_size=ims,
                    style=style,
                    cos_cis_img=freqs_cis_img[0],
                    sin_cis_img=freqs_cis_img[1],
                    return_dict=False,
                )
            elif self.infer_mode == "trt":
                noise_pred = self.unet(
                    x=latent_model_input.contiguous(),
                    t_emb=t_expand.contiguous(),
                    context=prompt_embeds.contiguous(),
                    image_meta_size=ims.contiguous(),
                    style=style.contiguous(),
                    freqs_cis_img0=freqs_cis_img[0].to(device).contiguous(),
                    freqs_cis_img1=freqs_cis_img[1].to(device).contiguous(),
                    text_embedding_mask=attention_mask.contiguous(),
                    encoder_hidden_states_t5=prompt_embeds_t5.contiguous(),
                    text_embedding_mask_t5=attention_mask_t5.contiguous(),
                )
            else:
                # The f-prefix was missing, so the message used to show the
                # placeholder text instead of the offending mode.
                raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

            # Keep only the first half of the channel axis when the model
            # also outputs a second (sigma) half.
            if learn_sigma:
                noise_pred, _ = noise_pred.chunk(2, dim=1)

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                )

            # compute the previous noisy sample x_t -> x_t-1
            results = self.scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs, return_dict=True
            )
            latents = results.prev_sample
            pred_x0 = (
                results.pred_original_sample
                if hasattr(results, "pred_original_sample")
                else None
            )

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents, pred_x0)

    # 8. Decode the latents (unless raw latents were requested)
    if not output_type == "latent":
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(
        images=image, nsfw_content_detected=has_nsfw_concept
    )
import numpy as np
import torch as th
from .gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or if a section holds fewer original steps
                        than the count requested from it.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            # Try every integer stride until one yields the requested count.
            for stride in range(1, num_timesteps):
                strided = range(0, num_timesteps, stride)
                if len(strided) == desired_count:
                    return set(strided)
            # Report the requested count; the message previously printed the
            # total number of timesteps, which is not what failed.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]

    # Sections are as equal as possible; the first `extra` ones get one more.
    size_per, extra = divmod(num_timesteps, len(section_counts))
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            # Fractional stride so the first and last step of the section
            # are both taken.
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        for _ in range(section_count):
            all_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        start_idx += size
    return set(all_steps)
class SpacedDiffusion(GaussianDiffusion):
    """
    Improved DDPM

    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        # Position k holds the original timestep index of retained step k.
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        # Build the full process once to read its cumulative alphas, then
        # derive betas for the shortened process from the retained steps.
        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        """Delegate to the base class with ``model`` wrapped for timestep mapping."""
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, controlnet=None, *args, **kwargs
    ):  # pylint: disable=signature-differs
        """Delegate to the base class with ``model`` (and ``controlnet``) wrapped.

        ``controlnet`` is only forwarded to the base class when it is given.
        """
        # Identity test: `!= None` would go through the object's own
        # comparison operator.
        if controlnet is not None:
            return super().training_losses(
                self._wrap_model(model),
                *args,
                controlnet=self._wrap_model(controlnet),
                **kwargs,
            )
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        """Delegate to the base class with ``cond_fn`` wrapped for timestep mapping."""
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        """Delegate to the base class with ``cond_fn`` wrapped for timestep mapping."""
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def get_eps(self, model, *args, **kwargs):
        """Delegate to the base class with ``model`` wrapped for timestep mapping."""
        return super().get_eps(self._wrap_model(model), *args, **kwargs)

    def _wrap_model(self, model):
        """Wrap ``model`` in ``_WrappedModel`` unless it already is one."""
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(model, self.timestep_map, self.original_num_steps)

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t
class _WrappedModel:
"""
Improved DDPM
When using a subsequent timesteps (e.g., 250), we must wrap the model
for mapping the timesteps 1-250 with step 1 to 1-1000 with step 4
"""
def __init__(self, model, timestep_map, original_num_steps):
self.model = model
self.timestep_map = timestep_map
# self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
"""
Here we must make a interpolation because `ts` maybe a float (e.g., 4.5)
in the PLMS/PNDM sampler.
"""
ts = ts.float()
frac = ts.frac()
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts_1 = map_tensor[ts.floor().long()]
new_ts_2 = map_tensor[ts.ceil().long()]
new_ts = th.lerp(new_ts_1, new_ts_2, frac)
return self.model(x, new_ts, **kwargs)
# -*- coding: utf-8 -*-
import os
def _adamw_optimizer_section(args):
    """Build the AdamW ``optimizer`` section shared by both ZeRO stages."""
    return {
        "type": "AdamW",
        "params": {
            "lr": args.lr,
            "betas": [0.9, 0.999],
            "eps": 1e-08,
            "weight_decay": args.weight_decay,
        },
    }


def _enable_cpu_offloading(zero_section):
    """Add CPU offload entries for optimizer state and parameters, in place."""
    # NOTE(review): upstream DeepSpeed documents the parameter key as
    # `offload_param`; confirm the targeted DeepSpeed build reads
    # `offload_parameter`. The key is kept as-is to preserve behavior.
    for key in ("offload_optimizer", "offload_parameter"):
        # A fresh dict per key so the two entries never alias each other.
        zero_section[key] = {"device": "cpu", "pin_memory": True}


def deepspeed_config_from_args(args, global_batch_size):
    """Build the DeepSpeed configuration dict for ZeRO stage 2 or 3.

    Args:
        args: Namespace providing ``use_zero_stage``, ``grad_accu_steps``,
            ``log_every``, ``lr``, ``weight_decay`` and ``cpu_offloading``.
            Stage 2 additionally reads ``batch_size`` and ``use_fp16``;
            stage 3 additionally reads ``global_batch_size``.
        global_batch_size: Value of ``train_batch_size`` for stage 2.
            NOTE(review): stage 3 takes ``args.global_batch_size`` instead
            and ignores this parameter; confirm that is intended.

    Returns:
        The configuration dict. When ``args.cpu_offloading`` is truthy the
        ``zero_optimization`` section also carries CPU offload entries.

    Raises:
        ValueError: If ``args.use_zero_stage`` is neither 2 nor 3.
    """
    stage = args.use_zero_stage
    if stage == 2:
        deepspeed_config = {
            "zero_allow_untested_optimizer": True,
            "train_batch_size": global_batch_size,
            "train_micro_batch_size_per_gpu": args.batch_size,
            "gradient_accumulation_steps": args.grad_accu_steps,
            "steps_per_print": args.log_every,
            "optimizer": _adamw_optimizer_section(args),
            "zero_optimization": {
                "stage": 2,
                "reduce_scatter": False,
                "reduce_bucket_size": 1e9,
            },
            "gradient_clipping": 1.0,
            "prescale_gradients": True,
            "fp16": {
                "enabled": args.use_fp16,
                "loss_scale": 0,
                "loss_scale_window": 500,
                "hysteresis": 2,
                "min_loss_scale": 1e-3,
                "initial_scale_power": 15,
            },
            "bf16": {"enabled": False},
            "wall_clock_breakdown": False,
        }
    elif stage == 3:
        deepspeed_config = {
            "train_batch_size": args.global_batch_size,
            # "train_micro_batch_size_per_gpu" is deliberately left unset here.
            "gradient_accumulation_steps": args.grad_accu_steps,
            "steps_per_print": args.log_every,
            "optimizer": _adamw_optimizer_section(args),
            "zero_optimization": {
                "stage": 3,
                "allgather_partitions": True,
                "overlap_comm": True,
                "reduce_scatter": True,
                "contiguous_gradients": True,
                "stage3_prefetch_bucket_size": 5e8,
                "stage3_max_live_parameters": 6e8,
                "reduce_bucket_size": 1.2e9,
                "sub_group_size": 1e9,
                "sub_group_buffer_num": 10,
                "pipeline_optimizer": True,
                "max_contigous_event_size": 0,
                "cache_sub_group_rate": 0.0,
                "prefetch_cache_sub_group_rate": 1.0,
                "max_contigous_params_size": -1,
                "max_param_reduce_events": 0,
                "stage3_param_persistence_threshold": 9e9,
                "is_communication_time_profiling": False,
                "save_large_model_multi_slice": True,
                "use_fused_op_with_grad_norm_overflow": False,
            },
            "gradient_clipping": 1.0,
            "prescale_gradients": False,
            # fp16 is always on for stage 3, independent of args.use_fp16.
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "loss_scale_window": 500,
                "hysteresis": 2,
                "min_loss_scale": 1,
                "initial_scale_power": 15,
            },
            "bf16": {"enabled": False},
            "wall_clock_breakdown": False,
            "mem_chunk": {
                "default_chunk_size": 536870911,
                "use_fake_dist": False,
                "client": {
                    "mem_tracer": {
                        "use_async_mem_monitor": True,
                        "warmup_gpu_chunk_mem_ratio": 0.8,
                        "overall_gpu_mem_ratio": 0.8,
                        "overall_cpu_mem_ratio": 1.0,
                        "margin_use_ratio": 0.8,
                        "use_fake_dist": False,
                    },
                    "opts": {"with_mem_cache": True, "with_async_move": True},
                },
            },
        }
    else:
        raise ValueError(
            f"Unsupported ZeRO stage: {stage!r} (expected 2 or 3)"
        )

    if args.cpu_offloading:
        _enable_cpu_offloading(deepspeed_config["zero_optimization"])
    return deepspeed_config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment