"vscode:/vscode.git/clone" did not exist on "f3c89cc696cb77a3d453d93f06b637b202316af1"
Commit 78bae405 authored by mashun1

open_sora_inference

from .stdit import STDiT
from .classes import ClassEncoder
from .clip import ClipEncoder
from .t5 import T5Encoder
import torch
from opensora.registry import MODELS
@MODELS.register_module("classes")
class ClassEncoder:
    def __init__(self, num_classes, model_max_length=None, device="cuda", dtype=torch.float):
        self.num_classes = num_classes
        self.y_embedder = None

        self.model_max_length = model_max_length
        self.output_dim = None
        self.device = device

    def encode(self, text):
        return dict(y=torch.tensor([int(t) for t in text]).to(self.device))

    def null(self, n):
        return torch.tensor([self.num_classes] * n).to(self.device)
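

# --------------------------------------------------------
# Usage sketch (illustrative addition, not part of the committed file).
# The 1000-class count and the CPU device below are assumptions chosen to
# keep the example self-contained.
# --------------------------------------------------------
if __name__ == "__main__":
    encoder = ClassEncoder(num_classes=1000, device="cpu")
    # encode() maps string labels to an integer tensor under the key "y".
    print(encoder.encode(["3", "17"])["y"])  # -> tensor of class ids [3, 17]
    # null() returns the unconditional class id (num_classes) n times,
    # which is used for classifier-free guidance.
    print(encoder.null(2))  # -> [1000, 1000]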
# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Latte
#
# This file is adapted from the Latte project.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
# --------------------------------------------------------
import torch
import torch.nn as nn
import transformers
from transformers import CLIPTextModel, CLIPTokenizer
from opensora.registry import MODELS
transformers.logging.set_verbosity_error()
class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)."""

    def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(path)
        self.transformer = CLIPTextModel.from_pretrained(path)
        self.device = device
        self.max_length = max_length
        self._freeze()

    def _freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        pooled_z = outputs.pooler_output
        return z, pooled_z

    def encode(self, text):
        return self(text)
@MODELS.register_module("clip")
class ClipEncoder:
    """
    Embeds text prompts into vector representations. Also handles text dropout for classifier-free guidance.
    """

    def __init__(
        self,
        from_pretrained,
        model_max_length=77,
        device="cuda",
        dtype=torch.float,
    ):
        super().__init__()
        assert from_pretrained is not None, "Please specify the path to the CLIP model"
        self.text_encoder = FrozenCLIPEmbedder(path=from_pretrained, max_length=model_max_length).to(device, dtype)

        self.y_embedder = None
        self.model_max_length = model_max_length
        self.output_dim = self.text_encoder.transformer.config.hidden_size

    def encode(self, text):
        _, pooled_embeddings = self.text_encoder.encode(text)
        # (B, hidden_size) -> (B, 1, 1, hidden_size)
        y = pooled_embeddings.unsqueeze(1).unsqueeze(1)
        return dict(y=y)

    def null(self, n):
        null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
        return null_y

    def to(self, dtype):
        self.text_encoder = self.text_encoder.to(dtype)
        return self
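

# --------------------------------------------------------
# Usage sketch (illustrative addition, not part of the committed file).
# The checkpoint name below is an assumption; any CLIP text model on the
# Hugging Face Hub with a matching tokenizer should work. Weights are
# downloaded on first run.
# --------------------------------------------------------
if __name__ == "__main__":
    encoder = ClipEncoder(
        from_pretrained="openai/clip-vit-base-patch32",
        model_max_length=77,
        device="cpu",
    )
    out = encoder.encode(["a panda eating bamboo", "a rocket launching at night"])
    # encode() returns the pooled CLIP embedding per prompt, reshaped to
    # (batch, 1, 1, hidden_size) so it broadcasts like a one-token sequence.
    print(out["y"].shape)      # (2, 1, 1, 512) for this checkpoint
    print(encoder.output_dim)  # 512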
from .vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder
import torch
import torch.nn as nn
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange
from opensora.registry import MODELS
@MODELS.register_module()
class VideoAutoencoderKL(nn.Module):
    def __init__(self, from_pretrained=None, micro_batch_size=None):
        super().__init__()
        self.module = AutoencoderKL.from_pretrained(from_pretrained)
        self.out_channels = self.module.config.latent_channels
        self.patch_size = (1, 8, 8)
        self.micro_batch_size = micro_batch_size

    def encode(self, x):
        # x: (B, C, T, H, W)
        B = x.shape[0]
        x = rearrange(x, "B C T H W -> (B T) C H W")

        if self.micro_batch_size is None:
            x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
        else:
            # Encode the flattened frames in chunks to bound peak memory.
            bs = self.micro_batch_size
            x_out = []
            for i in range(0, x.shape[0], bs):
                x_bs = x[i : i + bs]
                x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
                x_out.append(x_bs)
            x = torch.cat(x_out, dim=0)
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def decode(self, x):
        # x: (B, C, T, H, W)
        B = x.shape[0]
        x = rearrange(x, "B C T H W -> (B T) C H W")
        if self.micro_batch_size is None:
            x = self.module.decode(x / 0.18215).sample
        else:
            # Decode in chunks as well, mirroring encode().
            bs = self.micro_batch_size
            x_out = []
            for i in range(0, x.shape[0], bs):
                x_bs = x[i : i + bs]
                x_bs = self.module.decode(x_bs / 0.18215).sample
                x_out.append(x_bs)
            x = torch.cat(x_out, dim=0)
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def get_latent_size(self, input_size):
        # input_size: (T, H, W); time is not downsampled, H and W shrink by 8x.
        for i in range(3):
            assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
        input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
        return input_size
@MODELS.register_module()
class VideoAutoencoderKLTemporalDecoder(nn.Module):
    def __init__(self, from_pretrained=None):
        super().__init__()
        self.module = AutoencoderKLTemporalDecoder.from_pretrained(from_pretrained)
        self.out_channels = self.module.config.latent_channels
        self.patch_size = (1, 8, 8)

    def encode(self, x):
        raise NotImplementedError

    def decode(self, x):
        B, _, T = x.shape[:3]
        x = rearrange(x, "B C T H W -> (B T) C H W")
        x = self.module.decode(x / 0.18215, num_frames=T).sample
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def get_latent_size(self, input_size):
        for i in range(3):
            assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
        input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
        return input_size
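

# --------------------------------------------------------
# Usage sketch (illustrative addition, not part of the committed file).
# The checkpoint name, the 8-frame 64x64 clip, and micro_batch_size=4 are
# assumptions chosen to keep the round trip small enough to run on CPU.
# --------------------------------------------------------
if __name__ == "__main__":
    vae = VideoAutoencoderKL(from_pretrained="stabilityai/sd-vae-ft-ema", micro_batch_size=4)
    video = torch.randn(1, 3, 8, 64, 64)  # (B, C, T, H, W)
    # Time is not downsampled; height and width shrink by the 8x patch size.
    print(vae.get_latent_size([8, 64, 64]))  # [8, 8, 8]
    latents = vae.encode(video)
    print(latents.shape)  # torch.Size([1, 4, 8, 8, 8])
    recon = vae.decode(latents)
    print(recon.shape)  # torch.Size([1, 3, 8, 64, 64])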
from copy import deepcopy
import torch.nn as nn
from mmengine.registry import Registry
def build_module(module, builder, **kwargs):
    """Build module from config or return the module itself.

    Args:
        module (Union[dict, nn.Module]): The module to build.
        builder (Registry): The registry to build module.
        *args, **kwargs: Arguments passed to build function.

    Returns:
        Any: The built module.
    """
    if isinstance(module, dict):
        cfg = deepcopy(module)
        for k, v in kwargs.items():
            cfg[k] = v
        return builder.build(cfg)
    elif isinstance(module, nn.Module):
        return module
    elif module is None:
        return None
    else:
        raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.")
MODELS = Registry(
    "model",
    locations=["opensora.models"],
)

SCHEDULERS = Registry(
    "scheduler",
    locations=["opensora.schedulers"],
)
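

# --------------------------------------------------------
# Usage sketch (illustrative addition, not part of the committed file).
# "toy_proj" / ToyProj are hypothetical names registered here only to show
# the dict, nn.Module, and None branches of build_module; assumes the
# opensora package is importable so the registry locations resolve.
# --------------------------------------------------------
if __name__ == "__main__":

    @MODELS.register_module("toy_proj")
    class ToyProj(nn.Module):
        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.proj = nn.Linear(in_dim, out_dim)

    # dict branch: extra kwargs are merged into the config before building.
    model = build_module(dict(type="toy_proj", in_dim=4), MODELS, out_dim=8)
    print(type(model).__name__)  # ToyProj

    # nn.Module branch: an already-built module is returned unchanged.
    print(build_module(model, MODELS) is model)  # True

    # None branch: passes through as None.
    print(build_module(None, MODELS))  # None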
from .dpms import DPMS
from .iddpm import IDDPM