Commit 9ff47a7e authored by mashun1's avatar mashun1
Browse files

latte

parents
Pipeline #792 canceled with stages
import os
import json
import torch
import decord
import torchvision
import numpy as np
import random
from PIL import Image
from einops import rearrange
from typing import Dict, List, Tuple
from torchvision import transforms
import traceback
# Module-level cache shared by get_class_labels(); populated on first use.
class_labels_map = None
cls_sample_cnt = None
def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """Uniformly pick `num_samples` frames between `start_idx` and `end_idx`.

    Args:
        frames (tensor): video frames, shape
            `num video frames` x `channel` x `height` x `width`.
        start_idx (int): index of the first frame to consider.
        end_idx (int): index of the last frame to consider.
        num_samples (int): number of frames to sample.

    Returns:
        tensor: temporally sampled frames, shape
            `num clip frames` x `channel` x `height` x `width`.
    """
    positions = torch.linspace(start_idx, end_idx, num_samples)
    # clamp so out-of-range endpoints still map to valid frame indices
    positions = positions.clamp(0, frames.shape[0] - 1).long()
    return frames.index_select(0, positions)
def numpy2tensor(x):
    """Wrap a NumPy array as a torch tensor (shares the underlying memory)."""
    tensor = torch.from_numpy(x)
    return tensor
def get_filelist(file_path):
    """Recursively collect the full path of every file under `file_path`."""
    collected = []
    for root, _dirs, names in os.walk(file_path):
        # full paths, one entry per file
        collected.extend(os.path.join(root, name) for name in names)
    return collected
def load_annotation_data(data_file_path):
    """Read `data_file_path` and return its parsed JSON content."""
    with open(data_file_path, 'r') as fp:
        content = json.load(fp)
    return content
def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
    """Return the (class -> index) map and per-class sample counters.

    Both values are cached in module globals, so the annotation file at
    `anno_pth` is parsed at most once per process. `num_class` is accepted
    for API compatibility but unused here.
    """
    global class_labels_map, cls_sample_cnt
    if class_labels_map is None:
        cls_sample_cnt = {}
        class_labels_map = load_annotation_data(anno_pth)
        for name in class_labels_map:
            cls_sample_cnt[name] = 0
    return class_labels_map, cls_sample_cnt
def load_annotations(ann_file, num_class, num_samples_per_cls):
    """Parse a tab-separated annotation file into a list of sample dicts.

    Each kept sample is ``{'video': frame_dir, 'label': class_index}``.
    Only classes with index < `num_class` are kept, and at most
    `num_samples_per_cls` samples per class.
    """
    dataset = []
    class_to_idx, cls_sample_cnt = get_class_labels(num_class)
    with open(ann_file, 'r') as fin:
        for line in fin:
            fields = line.strip().split('\t')
            # first column is the frame directory, the rest are label(s)
            frame_dir, labels = fields[0], fields[1:]
            assert labels, f'missing label in line: {line}'
            assert len(labels) == 1
            class_name = labels[0]
            class_index = int(class_to_idx[class_name])
            # choose a class subset of the whole dataset
            if class_index < num_class:
                sample = {'video': frame_dir, 'label': class_index}
                if cls_sample_cnt[class_name] < num_samples_per_cls:
                    dataset.append(sample)
                    cls_sample_cnt[class_name] += 1
    return dataset
class DecordInit(object):
    """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""

    def __init__(self, num_threads=1, **kwargs):
        # number of decoding threads handed to decord
        self.num_threads = num_threads
        # decode on CPU device 0
        self.ctx = decord.cpu(0)
        self.kwargs = kwargs

    def __call__(self, filename):
        """Return a decord.VideoReader for `filename`.

        Args:
            filename (str): path of the video file to open.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        # Bug fix: the original interpolated `self.sr`, an attribute that is
        # never set anywhere in this class, so repr() raised AttributeError.
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str
class FaceForensicsImages(torch.utils.data.Dataset):
    """Load the FaceForensics video files
    Args:
        target_video_len (int): the number of video frames will be load.
        align_transform (callable): Align different videos in a specified size.
        temporal_sample (callable): Sample the target length of a video.
    """

    def __init__(self,
                 configs,
                 transform=None,
                 temporal_sample=None):
        self.configs = configs
        self.data_path = configs.data_path
        # every file found under data_path is treated as a video
        self.video_lists = get_filelist(configs.data_path)
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.v_decoder = DecordInit()
        self.video_length = len(self.video_lists)

        # ffs video frames
        self.video_frame_path = configs.frame_data_path
        self.video_frame_txt = configs.frame_data_txt
        # one frame path (relative to frame_data_path) per line of the txt file
        self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
        random.shuffle(self.video_frame_files)
        self.use_image_num = configs.use_image_num
        # NOTE(review): attribute name keeps the original "tranform" typo;
        # renaming would touch __getitem__ as well.
        self.image_tranform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])

    def __getitem__(self, index):
        # index ranges over frame files (see __len__); wrap it onto the videos
        video_index = index % self.video_length
        path = self.video_lists[video_index]
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.target_video_len, dtype=int)
        video = vframes[frame_indice]
        # videotransformer data proprecess
        video = self.transform(video)  # T C H W

        # get video frames; on any read/decode failure resample a fresh
        # starting index and retry until use_image_num frames are collected
        images = []
        for i in range(self.use_image_num):
            while True:
                try:
                    image = Image.open(os.path.join(self.video_frame_path, self.video_frame_files[index + i])).convert("RGB")
                    image = self.image_tranform(image).unsqueeze(0)
                    images.append(image)
                    break
                except Exception as e:
                    traceback.print_exc()
                    index = random.randint(0, len(self.video_frame_files) - self.use_image_num)
        images = torch.cat(images, dim=0)
        assert len(images) == self.use_image_num

        # clip frames followed by the independently sampled single images
        video_cat = torch.cat([video, images], dim=0)
        return {'video': video_cat, 'video_name': 1}

    def __len__(self):
        # length is defined by the frame list, not by the number of videos
        return len(self.video_frame_files)
if __name__ == '__main__':
    import argparse
    import torchvision
    import video_transforms
    import torch.utils.data as Data
    import torchvision.transforms as transform
    from PIL import Image

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--use-image-num", type=int, default=5)
    parser.add_argument("--frame_interval", type=int, default=3)
    parser.add_argument("--dataset", type=str, default='webvideo10m')
    # NOTE(review): argparse `type=bool` is unreliable — any non-empty string
    # parses as True; the default '' evaluates to False.
    parser.add_argument("--test-run", type=bool, default='')
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/videos/")
    parser.add_argument("--frame-data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/images/")
    parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/faceForensics_v1/train_list.txt")
    config = parser.parse_args()

    # crop a window of num_frames * frame_interval frames, then subsample
    temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
    transform_webvideo = transform.Compose([
        video_transforms.ToTensorVideo(),
        transform.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    dataset = FaceForensicsImages(config, transform=transform_webvideo, temporal_sample=temporal_sample)
    dataloader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=4)

    # smoke test: iterate the loader and print tensor shapes
    for i, video_data in enumerate(dataloader):
        video, video_label = video_data['video'], video_data['video_name']
        # print(video_label)
        # print(image_label)
        print(video.shape)
        print(video_label)
        # video_ = ((video[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        # print(video_.shape)
        # try:
        #     torchvision.io.write_video(f'./test/{i:03d}_{video_label}.mp4', video_[:16], fps=8)
        # except:
        #     pass
        # if i % 100 == 0 and i != 0:
        #     break
    print('Done!')
\ No newline at end of file
import os
import torch
import random
import torch.utils.data as data
import numpy as np
from PIL import Image
# Recognised image-file suffixes (lower- and upper-case variants).
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']


def is_image_file(filename):
    """Return True if `filename` ends with a known image extension."""
    # str.endswith accepts a tuple of suffixes — one C-level call instead of
    # a Python-level any()/generator loop.
    return filename.endswith(tuple(IMG_EXTENSIONS))
class Sky(data.Dataset):
    """Sky timelapse dataset: each sample is a clip of frames read from disk.

    Args:
        configs: namespace providing `data_path`, `num_frames`, `frame_interval`.
        transform (callable): applied to the stacked (T, C, H, W) uint8 clip.
        temporal_sample (callable): maps a video length to a (start, end) range.
        train (bool): unused; kept for interface compatibility.
    """

    def __init__(self, configs, transform, temporal_sample=None, train=True):
        self.configs = configs
        self.data_path = configs.data_path
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.frame_interval = self.configs.frame_interval
        # list of per-video frame-path lists; also sets self.video_num
        self.data_all = self.load_video_frames(self.data_path)

    def __getitem__(self, index):
        vframes = self.data_all[index]
        total_frames = len(vframes)

        # Sample a temporal window, then take every `frame_interval`-th frame.
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1,
                                   num=self.target_video_len, dtype=int)
        select_video_frames = vframes[frame_indice[0]: frame_indice[-1] + 1: self.frame_interval]

        video_frames = []
        for path in select_video_frames:
            video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
            video_frames.append(video_frame)
        video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)  # T H W C -> T C H W
        video_clip = self.transform(video_clip)

        return {'video': video_clip, 'video_name': 1}

    def __len__(self):
        return self.video_num

    def load_video_frames(self, dataroot):
        """Walk `dataroot` and return a list of frame-path lists, one per video."""
        data_all = []
        for root, _, files in os.walk(dataroot):
            try:
                # frames are named like xxx_<index>.ext; sort numerically
                frames = sorted(files, key=lambda item: int(item.split('.')[0].split('_')[-1]))
            except ValueError:
                # Bug fix: the original bare `except:` only printed and then
                # reused the `frames` list left over from the previous
                # directory (or hit a NameError on the first one). Fall back
                # to lexicographic order instead.
                print(root)
                print(files)
                frames = sorted(files)
            frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
            # keep only videos long enough for num_frames * frame_interval
            if len(frames) > max(0, self.target_video_len * self.frame_interval):
                data_all.append(frames)
        self.video_num = len(data_all)
        return data_all
if __name__ == '__main__':
    import argparse
    import torchvision
    import video_transforms
    import torch.utils.data as data
    from torchvision import transforms
    from torchvision.utils import save_image

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=4)
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/sky_timelapse/sky_train/")
    config = parser.parse_args()

    target_video_len = config.num_frames

    # crop a window of num_frames * frame_interval frames, then subsample
    temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
    trans = transforms.Compose([
        video_transforms.ToTensorVideo(),
        # video_transforms.CenterCropVideo(256),
        video_transforms.CenterCropResizeVideo(256),
        # video_transforms.RandomHorizontalFlipVideo(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    taichi_dataset = Sky(config, transform=trans, temporal_sample=temporal_sample)
    print(len(taichi_dataset))
    taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)

    # smoke test: iterate the loader and print tensor shapes
    for i, video_data in enumerate(taichi_dataloader):
        print(video_data['video'].shape)
        # print(video_data.dtype)
        # for i in range(target_video_len):
        #     save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
        # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
        # exit()
\ No newline at end of file
import os
import torch
import random
import torch.utils.data as data
import numpy as np
import copy
from PIL import Image
# Recognised image-file suffixes (lower- and upper-case variants).
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']


def is_image_file(filename):
    """Return True if `filename` ends with a known image extension."""
    # str.endswith accepts a tuple of suffixes — one C-level call instead of
    # a Python-level any()/generator loop.
    return filename.endswith(tuple(IMG_EXTENSIONS))
class SkyImages(data.Dataset):
    """Sky dataset returning a clip plus `use_image_num` extra single frames.

    Args:
        configs: namespace providing `data_path`, `num_frames`,
            `frame_interval` and `use_image_num`.
        transform (callable): applied to (T, C, H, W) uint8 tensors.
        temporal_sample (callable): maps a video length to a (start, end) range.
        train (bool): unused; kept for interface compatibility.
    """

    def __init__(self, configs, transform, temporal_sample=None, train=True):
        self.configs = configs
        self.data_path = configs.data_path
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.frame_interval = self.configs.frame_interval
        # per-video frame lists and a flat list of every frame path;
        # also sets self.video_num and self.video_frame_num
        self.data_all, self.video_frame_all = self.load_video_frames(self.data_path)

        # sky video frames
        random.shuffle(self.video_frame_all)
        self.use_image_num = configs.use_image_num

    def __getitem__(self, index):
        # index ranges over all frames (see __len__); wrap it onto the videos
        video_index = index % self.video_num
        vframes = self.data_all[video_index]
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1,
                                   num=self.target_video_len, dtype=int)
        select_video_frames = vframes[frame_indice[0]: frame_indice[-1] + 1: self.frame_interval]

        video_frames = []
        for path in select_video_frames:
            video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
            video_frames.append(video_frame)
        video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)  # T H W C -> T C H W
        video_clip = self.transform(video_clip)

        # get extra single frames, resampling the index on unreadable files
        images = []
        for i in range(self.use_image_num):
            while True:
                try:
                    video_frame_path = self.video_frame_all[index + i]
                    image = torch.as_tensor(np.array(Image.open(video_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
                    images.append(image)
                    break
                except Exception:
                    index = random.randint(0, self.video_frame_num - self.use_image_num)
        images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
        images = self.transform(images)
        assert len(images) == self.use_image_num

        video_cat = torch.cat([video_clip, images], dim=0)
        return {'video': video_cat, 'video_name': 1}

    def __len__(self):
        return self.video_frame_num

    def load_video_frames(self, dataroot):
        """Walk `dataroot`; return (per-video frame lists, flat frame list)."""
        data_all = []
        frames_all = []
        for root, _, files in os.walk(dataroot):
            try:
                # frames are named like xxx_<index>.ext; sort numerically
                frames = sorted(files, key=lambda item: int(item.split('.')[0].split('_')[-1]))
            except ValueError:
                # Bug fix: the original bare `except:` printed and fell through
                # with `frames` unset (or stale from the previous directory);
                # fall back to lexicographic order instead.
                print(root)
                print(files)
                frames = sorted(files)
            frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
            # keep only videos long enough for num_frames * frame_interval
            if len(frames) > max(0, self.target_video_len * self.frame_interval):
                data_all.append(frames)
                frames_all.extend(frames)
        self.video_num = len(data_all)
        self.video_frame_num = len(frames_all)
        return data_all, frames_all
if __name__ == '__main__':
    import argparse
    import torchvision
    import video_transforms
    import torch.utils.data as data
    from torchvision import transforms
    from torchvision.utils import save_image

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=3)
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/sky_timelapse/sky_train/")
    parser.add_argument("--use-image-num", type=int, default=5)
    config = parser.parse_args()

    target_video_len = config.num_frames

    # crop a window of num_frames * frame_interval frames, then subsample
    temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
    trans = transforms.Compose([
        video_transforms.ToTensorVideo(),
        # video_transforms.CenterCropVideo(256),
        video_transforms.CenterCropResizeVideo(256),
        # video_transforms.RandomHorizontalFlipVideo(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    taichi_dataset = SkyImages(config, transform=trans, temporal_sample=temporal_sample)
    print(len(taichi_dataset))
    taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)

    # smoke test: iterate the loader and print tensor shapes
    for i, video_data in enumerate(taichi_dataloader):
        print(video_data['video'].shape)
        # print(video_data.dtype)
        # for i in range(target_video_len):
        #     save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
        # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
        # exit()
\ No newline at end of file
import os
import torch
import random
import torch.utils.data as data
import numpy as np
import io
import json
from PIL import Image
# Recognised image-file suffixes (lower- and upper-case variants).
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']


def is_image_file(filename):
    """Return True if `filename` ends with a known image extension."""
    # str.endswith accepts a tuple of suffixes — one C-level call instead of
    # a Python-level any()/generator loop.
    return filename.endswith(tuple(IMG_EXTENSIONS))
class Taichi(data.Dataset):
    """Taichi dataset: each sample is a clip of RGB frames read from disk.

    Args:
        configs: namespace providing `data_path`, `num_frames`, `frame_interval`.
        transform (callable): applied to the stacked (T, C, H, W) uint8 clip.
        temporal_sample (callable): maps a video length to a (start, end) range.
        train (bool): unused; kept for interface compatibility.
    """

    def __init__(self, configs, transform, temporal_sample=None, train=True):
        self.configs = configs
        self.data_path = configs.data_path
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.frame_interval = self.configs.frame_interval
        self.data_all = self.load_video_frames(self.data_path)
        self.video_num = len(self.data_all)

    def __getitem__(self, index):
        vframes = self.data_all[index]
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.target_video_len, dtype=int)
        select_video_frames = vframes[frame_indice[0]: frame_indice[-1] + 1: self.frame_interval]

        video_frames = []
        for path in select_video_frames:
            image = Image.open(path).convert('RGB')
            video_frame = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
            video_frames.append(video_frame)
        video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)  # T H W C -> T C H W
        video_clip = self.transform(video_clip)

        return {'video': video_clip, 'video_name': 1}

    def __len__(self):
        return self.video_num

    def load_video_frames(self, dataroot):
        """Walk `dataroot` and return a list of frame-path lists, one per video."""
        data_all = []
        for root, _, files in os.walk(dataroot):
            try:
                # frames are named like xxx_<index>.ext; sort numerically
                frames = sorted(files, key=lambda item: int(item.split('.')[0].split('_')[-1]))
            except ValueError:
                # Bug fix: the original bare `except:` only printed and then
                # reused the stale `frames` from the previous directory (or
                # raised NameError); fall back to lexicographic order instead.
                print(root, files)
                frames = sorted(files)
            frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
            if len(frames) != 0:
                data_all.append(frames)
        return data_all
if __name__ == '__main__':
    import argparse
    import torchvision
    import video_transforms
    import torch.utils.data as data
    from torchvision import transforms
    from torchvision.utils import save_image

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=4)
    # NOTE(review): flag name looks like a typo for --load_from_ceph (the
    # sibling script spells it that way); also argparse `type=bool` is
    # unreliable — any non-empty string parses as True.
    parser.add_argument("--load_fron_ceph", type=bool, default=True)
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/taichi/taichi-256/frames/train")
    config = parser.parse_args()

    target_video_len = config.num_frames

    # crop a window of num_frames * frame_interval frames, then subsample
    temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
    trans = transforms.Compose([
        video_transforms.ToTensorVideo(),
        video_transforms.RandomHorizontalFlipVideo(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    taichi_dataset = Taichi(config, transform=trans, temporal_sample=temporal_sample)
    taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)

    # smoke test: iterate the loader and print tensor shapes
    for i, video_data in enumerate(taichi_dataloader):
        print(video_data['video'].shape)
        # print(video_data.dtype)
        # for i in range(target_video_len):
        #     save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
        # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
        # exit()
\ No newline at end of file
import os
import torch
import random
import torch.utils.data as data
import numpy as np
import io
import json
from PIL import Image
# Recognised image-file suffixes (lower- and upper-case variants).
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']


def is_image_file(filename):
    """Return True if `filename` ends with a known image extension."""
    # str.endswith accepts a tuple of suffixes — one C-level call instead of
    # a Python-level any()/generator loop.
    return filename.endswith(tuple(IMG_EXTENSIONS))
class TaichiImages(data.Dataset):
    """Taichi dataset returning a clip plus `use_image_num` extra single frames.

    Args:
        configs: namespace providing `data_path`, `num_frames`,
            `frame_interval` and `use_image_num`.
        transform (callable): applied to (T, C, H, W) uint8 tensors.
        temporal_sample (callable): maps a video length to a (start, end) range.
        train (bool): unused; kept for interface compatibility.
    """

    def __init__(self, configs, transform, temporal_sample=None, train=True):
        self.configs = configs
        self.data_path = configs.data_path
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.frame_interval = self.configs.frame_interval
        # per-video frame lists and a flat list of every frame path
        self.data_all, self.video_frame_all = self.load_video_frames(self.data_path)
        self.video_num = len(self.data_all)
        self.video_frame_num = len(self.video_frame_all)

        # sky video frames
        random.shuffle(self.video_frame_all)
        self.use_image_num = configs.use_image_num

    def __getitem__(self, index):
        # index ranges over all frames (see __len__); wrap it onto the videos
        video_index = index % self.video_num
        vframes = self.data_all[video_index]
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.target_video_len, dtype=int)
        select_video_frames = vframes[frame_indice[0]: frame_indice[-1] + 1: self.frame_interval]

        video_frames = []
        for path in select_video_frames:
            image = Image.open(path).convert('RGB')
            video_frame = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
            video_frames.append(video_frame)
        video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)  # T H W C -> T C H W
        video_clip = self.transform(video_clip)

        # get extra single frames, resampling the index on unreadable files
        images = []
        for i in range(self.use_image_num):
            while True:
                try:
                    video_frame_path = self.video_frame_all[index + i]
                    # paths from load_video_frames are already joined with the
                    # walk root; os.path.join keeps an absolute path unchanged
                    image_path = os.path.join(self.data_path, video_frame_path)
                    image = Image.open(image_path).convert('RGB')
                    image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
                    images.append(image)
                    break
                except Exception:
                    index = random.randint(0, self.video_frame_num - self.use_image_num)
        images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
        images = self.transform(images)
        assert len(images) == self.use_image_num

        video_cat = torch.cat([video_clip, images], dim=0)
        return {'video': video_cat, 'video_name': 1}

    def __len__(self):
        return self.video_frame_num

    def load_video_frames(self, dataroot):
        """Walk `dataroot`; return (per-video frame lists, flat frame list)."""
        data_all = []
        frames_all = []
        for root, _, files in os.walk(dataroot):
            try:
                # frames are named like xxx_<index>.ext; sort numerically
                frames = sorted(files, key=lambda item: int(item.split('.')[0].split('_')[-1]))
            except ValueError:
                # Bug fix: the original bare `except:` only printed and then
                # reused the stale `frames` from the previous directory (or
                # raised NameError); fall back to lexicographic order instead.
                print(root, files)
                frames = sorted(files)
            frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
            if len(frames) != 0:
                data_all.append(frames)
                frames_all.extend(frames)
        return data_all, frames_all
if __name__ == '__main__':
    import argparse
    import torchvision
    import video_transforms
    import torch.utils.data as data
    from torchvision import transforms
    from torchvision.utils import save_image

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=4)
    # NOTE(review): argparse `type=bool` is unreliable — any non-empty string
    # parses as True.
    parser.add_argument("--load_from_ceph", type=bool, default=True)
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/taichi/taichi-256/frames/train")
    parser.add_argument("--use-image-num", type=int, default=5)
    config = parser.parse_args()

    target_video_len = config.num_frames

    # crop a window of num_frames * frame_interval frames, then subsample
    temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
    trans = transforms.Compose([
        video_transforms.ToTensorVideo(),
        video_transforms.RandomHorizontalFlipVideo(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    taichi_dataset = TaichiImages(config, transform=trans, temporal_sample=temporal_sample)
    print(len(taichi_dataset))
    taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)

    for i, video_data in enumerate(taichi_dataloader):
        print(video_data['video'].shape)
        # print(video_data.dtype)
        # for i in range(target_video_len):
        #     save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
        # NOTE(review): unlike the sibling scripts the three lines below are
        # live, not commented out. video_data is a dict here, so
        # video_data[0] presumably raises KeyError, and the output path lacks
        # a separator ('./test_datatest.mp4') — confirm this debug block is
        # intended.
        video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
        exit()
\ No newline at end of file
import os
import re
import json
import torch
import decord
import torchvision
import numpy as np
from PIL import Image
from einops import rearrange
from typing import Dict, List, Tuple
# Module-level cache shared by get_class_labels(); populated on first use.
# (The original assigned this pair twice back-to-back; the redundant
# duplicate assignments were removed.)
class_labels_map = None
cls_sample_cnt = None
def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """Uniformly pick `num_samples` frames between `start_idx` and `end_idx`.

    Args:
        frames (tensor): video frames, shape
            `num video frames` x `channel` x `height` x `width`.
        start_idx (int): index of the first frame to consider.
        end_idx (int): index of the last frame to consider.
        num_samples (int): number of frames to sample.

    Returns:
        tensor: temporally sampled frames, shape
            `num clip frames` x `channel` x `height` x `width`.
    """
    positions = torch.linspace(start_idx, end_idx, num_samples)
    # clamp so out-of-range endpoints still map to valid frame indices
    positions = positions.clamp(0, frames.shape[0] - 1).long()
    return frames.index_select(0, positions)
def get_filelist(file_path):
    """Recursively collect the full path of every file under `file_path`."""
    collected = []
    for root, _dirs, names in os.walk(file_path):
        # full paths, one entry per file
        collected.extend(os.path.join(root, name) for name in names)
    return collected
def load_annotation_data(data_file_path):
    """Read `data_file_path` and return its parsed JSON content."""
    with open(data_file_path, 'r') as fp:
        content = json.load(fp)
    return content
def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
    """Return the (class -> index) map and per-class sample counters.

    Both values are cached in module globals, so the annotation file at
    `anno_pth` is parsed at most once per process. `num_class` is accepted
    for API compatibility but unused here.
    """
    global class_labels_map, cls_sample_cnt
    if class_labels_map is None:
        cls_sample_cnt = {}
        class_labels_map = load_annotation_data(anno_pth)
        for name in class_labels_map:
            cls_sample_cnt[name] = 0
    return class_labels_map, cls_sample_cnt
def load_annotations(ann_file, num_class, num_samples_per_cls):
    """Parse a tab-separated annotation file into a list of sample dicts.

    Each kept sample is ``{'video': frame_dir, 'label': class_index}``.
    Only classes with index < `num_class` are kept, and at most
    `num_samples_per_cls` samples per class.
    """
    dataset = []
    class_to_idx, cls_sample_cnt = get_class_labels(num_class)
    with open(ann_file, 'r') as fin:
        for line in fin:
            fields = line.strip().split('\t')
            # first column is the frame directory, the rest are label(s)
            frame_dir, labels = fields[0], fields[1:]
            assert labels, f'missing label in line: {line}'
            assert len(labels) == 1
            class_name = labels[0]
            class_index = int(class_to_idx[class_name])
            # choose a class subset of the whole dataset
            if class_index < num_class:
                sample = {'video': frame_dir, 'label': class_index}
                if cls_sample_cnt[class_name] < num_samples_per_cls:
                    dataset.append(sample)
                    cls_sample_cnt[class_name] += 1
    return dataset
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    class_names = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    class_names.sort()
    if not class_names:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    idx_by_name = {name: pos for pos, name in enumerate(class_names)}
    return class_names, idx_by_name
class DecordInit(object):
    """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""

    def __init__(self, num_threads=1):
        # number of decoding threads handed to decord
        self.num_threads = num_threads
        # decode on CPU device 0
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Return a decord.VideoReader for `filename`.

        Args:
            filename (str): path of the video file to open.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        # Bug fix: the original formatted `self.sr`, an attribute never set
        # anywhere in this class, so repr() raised AttributeError.
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str
class UCF101(torch.utils.data.Dataset):
    """Load the UCF101 video files

    Args:
        target_video_len (int): the number of video frames will be load.
        align_transform (callable): Align different videos in a specified size.
        temporal_sample (callable): Sample the target length of a video.
    """

    def __init__(self,
                 configs,
                 transform=None,
                 temporal_sample=None):
        self.configs = configs
        self.data_path = configs.data_path
        self.video_lists = get_filelist(configs.data_path)
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.v_decoder = DecordInit()
        # class folders directly under data_path define the label set
        self.classes, self.class_to_idx = find_classes(self.data_path)

    def __getitem__(self, index):
        path = self.video_lists[index]
        # Bug fix: derive the class from the parent directory with os.path
        # instead of path.split('/')[-2], which breaks on Windows separators.
        class_name = os.path.basename(os.path.dirname(path))
        class_index = self.class_to_idx[class_name]

        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.target_video_len, dtype=int)
        video = vframes[frame_indice]
        video = self.transform(video)  # T C H W

        return {'video': video, 'video_name': class_index}

    def __len__(self):
        return len(self.video_lists)
if __name__ == '__main__':
    import argparse
    import video_transforms
    import torch.utils.data as Data
    import torchvision.transforms as transforms
    from PIL import Image

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=1)
    # parser.add_argument("--data-path", type=str, default="/nvme/share_data/datasets/UCF101/videos")
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/UCF101/videos/")
    config = parser.parse_args()

    # crop a window of num_frames * frame_interval frames, then subsample
    temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
    transform_ucf101 = transforms.Compose([
        video_transforms.ToTensorVideo(),  # TCHW
        video_transforms.RandomHorizontalFlipVideo(),
        video_transforms.UCFCenterCropVideo(256),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    ffs_dataset = UCF101(config, transform=transform_ucf101, temporal_sample=temporal_sample)
    ffs_dataloader = Data.DataLoader(dataset=ffs_dataset, batch_size=6, shuffle=False, num_workers=1)

    # smoke test: print the first batch's shapes, then stop via exit()
    # for i, video_data in enumerate(ffs_dataloader):
    for video_data in ffs_dataloader:
        print(type(video_data))
        video = video_data['video']
        video_name = video_data['video_name']
        print(video.shape)
        print(video_name)
        # print(video_data[2])
        # for i in range(16):
        #     img0 = rearrange(video_data[0][0][i], 'c h w -> h w c')
        #     print('Label: {}'.format(video_data[1]))
        #     print(img0.shape)
        #     img0 = Image.fromarray(np.uint8(img0 * 255))
        #     img0.save('./img{}.jpg'.format(i))
        exit()
\ No newline at end of file
import os, io
import re
import json
import torch
import decord
import torchvision
import numpy as np
from PIL import Image
from einops import rearrange
from typing import Dict, List, Tuple
from torchvision import transforms
import random
# Module-level cache shared by get_class_labels(); populated on first use.
# (The original assigned this pair twice back-to-back; the redundant
# duplicate assignments were removed.)
class_labels_map = None
cls_sample_cnt = None
def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """Uniformly pick `num_samples` frames between `start_idx` and `end_idx`.

    Args:
        frames (tensor): video frames, shape
            `num video frames` x `channel` x `height` x `width`.
        start_idx (int): index of the first frame to consider.
        end_idx (int): index of the last frame to consider.
        num_samples (int): number of frames to sample.

    Returns:
        tensor: temporally sampled frames, shape
            `num clip frames` x `channel` x `height` x `width`.
    """
    positions = torch.linspace(start_idx, end_idx, num_samples)
    # clamp so out-of-range endpoints still map to valid frame indices
    positions = positions.clamp(0, frames.shape[0] - 1).long()
    return frames.index_select(0, positions)
def get_filelist(file_path):
    """Recursively collect the full path of every file under `file_path`."""
    collected = []
    for root, _dirs, names in os.walk(file_path):
        # full paths, one entry per file
        collected.extend(os.path.join(root, name) for name in names)
    return collected
def load_annotation_data(data_file_path):
    """Read `data_file_path` and return its parsed JSON content."""
    with open(data_file_path, 'r') as fp:
        content = json.load(fp)
    return content
def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
    """Return the (class -> index) map and per-class sample counters.

    Both values are cached in module globals, so the annotation file at
    `anno_pth` is parsed at most once per process. `num_class` is accepted
    for API compatibility but unused here.
    """
    global class_labels_map, cls_sample_cnt
    if class_labels_map is None:
        cls_sample_cnt = {}
        class_labels_map = load_annotation_data(anno_pth)
        for name in class_labels_map:
            cls_sample_cnt[name] = 0
    return class_labels_map, cls_sample_cnt
def load_annotations(ann_file, num_class, num_samples_per_cls):
    """Parse a tab-separated annotation file into a list of sample dicts.

    Each line is "<frame_dir>\t<class_name>". A sample is kept only when its
    class index is below `num_class` and that class has fewer than
    `num_samples_per_cls` samples so far.

    Returns:
        list of {'video': frame_dir, 'label': class_index} dicts.
    """
    dataset = []
    class_to_idx, cls_sample_cnt = get_class_labels(num_class)
    with open(ann_file, 'r') as fin:
        for line in fin:
            parts = line.strip().split('\t')
            frame_dir = parts[0]
            labels = parts[1:]
            assert labels, f'missing label in line: {line}'
            assert len(labels) == 1
            class_name = labels[0]
            class_index = int(class_to_idx[class_name])
            # choose a class subset of the whole dataset
            if class_index < num_class:
                if cls_sample_cnt[class_name] < num_samples_per_cls:
                    dataset.append({'video': frame_dir, 'label': class_index})
                    cls_sample_cnt[class_name] += 1
    return dataset
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.
    See :class:`DatasetFolder` for details.
    """
    classes = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    classes.sort()
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    return classes, {name: index for index, name in enumerate(classes)}
class DecordInit(object):
    """Using Decord (https://github.com/dmlc/decord) to initialize the video reader."""

    def __init__(self, num_threads=1):
        # Number of decoding threads handed to decord.
        self.num_threads = num_threads
        # CPU decoding context.
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Create and return a decord.VideoReader for `filename`."""
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        # Bug fix: the original formatted `self.sr`, an attribute that is
        # never set anywhere, so repr() raised AttributeError.
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str
class UCF101Images(torch.utils.data.Dataset):
    """Load UCF101 video files together with extra class-labelled still frames.

    Each item is a dict with:
        'video': `num_frames` sampled video frames concatenated with
                 `use_image_num` independently sampled still frames (T, C, H, W),
        'video_name': class index of the video,
        'image_name': '====='-joined class indices of the still frames.

    Args:
        configs: namespace providing data_path, frame_data_path,
            frame_data_txt, num_frames and use_image_num.
        transform (callable): video transform pipeline applied to (T, C, H, W).
        temporal_sample (callable): maps a frame count to a (start, end) range.
    """
    def __init__(self,
                 configs,
                 transform=None,
                 temporal_sample=None):
        self.configs = configs
        self.data_path = configs.data_path
        # Bug fix: __getitem__ joins still-frame paths against
        # self.frame_data_path, which the original __init__ never assigned.
        # The resulting AttributeError was silently swallowed by the broad
        # `except` below, causing endless index resampling.
        self.frame_data_path = configs.frame_data_path
        self.video_lists = get_filelist(configs.data_path)
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.v_decoder = DecordInit()
        self.classes, self.class_to_idx = find_classes(self.data_path)
        self.video_num = len(self.video_lists)
        # ucf101 video frames
        self.video_frame_txt = configs.frame_data_txt
        self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
        random.shuffle(self.video_frame_files)
        self.use_image_num = configs.use_image_num
        self.image_tranform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        self.video_frame_num = len(self.video_frame_files)

    def __getitem__(self, index):
        video_index = index % self.video_num
        path = self.video_lists[video_index]
        # The class name is the parent directory of the video file.
        class_name = path.split('/')[-2]
        class_index = self.class_to_idx[class_name]
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)
        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
        # NOTE: frames are sampled evenly over the window; frame_interval
        # striding is not applied here.
        video = vframes[frame_indice]
        video = self.transform(video)  # T C H W
        images = []
        image_names = []
        for i in range(self.use_image_num):
            while True:
                try:
                    video_frame_path = self.video_frame_files[index+i]
                    # Frame files look like "<prefix>_<ClassName>_...";
                    # field 1 is the class name.
                    image_class_name = video_frame_path.split('_')[1]
                    image_class_index = self.class_to_idx[image_class_name]
                    video_frame_path = os.path.join(self.frame_data_path, video_frame_path)
                    image = Image.open(video_frame_path).convert('RGB')
                    image = self.image_tranform(image).unsqueeze(0)
                    images.append(image)
                    image_names.append(str(image_class_index))
                    break
                except Exception as e:
                    # Best-effort: on any failure, resample a fresh index and retry.
                    index = random.randint(0, self.video_frame_num - self.use_image_num)
        images = torch.cat(images, dim=0)
        assert len(images) == self.use_image_num
        assert len(image_names) == self.use_image_num
        image_names = '====='.join(image_names)
        video_cat = torch.cat([video, images], dim=0)
        return {'video': video_cat,
                'video_name': class_index,
                'image_name': image_names}

    def __len__(self):
        return self.video_frame_num
if __name__ == '__main__':
    # Smoke-test driver: builds the dataset from CLI paths and prints batches.
    import argparse
    import video_transforms
    import torch.utils.data as Data
    import torchvision.transforms as transforms
    from PIL import Image
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_interval", type=int, default=3)
    parser.add_argument("--use-image-num", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/UCF101/videos/")
    parser.add_argument("--frame-data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/images/")
    parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/UCF101/train_256_list.txt")
    config = parser.parse_args()
    # Window length covers num_frames at the requested stride.
    temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
    transform_ucf101 = transforms.Compose([
        video_transforms.ToTensorVideo(), # TCHW
        video_transforms.RandomHorizontalFlipVideo(),
        video_transforms.UCFCenterCropVideo(256),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    ffs_dataset = UCF101Images(config, transform=transform_ucf101, temporal_sample=temporal_sample)
    ffs_dataloader = Data.DataLoader(dataset=ffs_dataset, batch_size=6, shuffle=False, num_workers=1)
    # for i, video_data in enumerate(ffs_dataloader):
    for video_data in ffs_dataloader:
        # print(type(video_data))
        video = video_data['video']
        # video_name = video_data['video_name']
        print(video.shape)
        print(video_data['image_name'])
        image_name = video_data['image_name']
        image_names = []
        # Decode the '====='-joined per-frame class indices back into tensors.
        for caption in image_name:
            single_caption = [int(item) for item in caption.split('=====')]
            image_names.append(torch.as_tensor(single_caption))
        print(image_names)
        # print(video_name)
        # print(video_data[2])
        # for i in range(16):
        #     img0 = rearrange(video_data[0][0][i], 'c h w -> h w c')
        #     print('Label: {}'.format(video_data[1]))
        #     print(img0.shape)
        #     img0 = Image.fromarray(np.uint8(img0 * 255))
        #     img0.save('./img{}.jpg'.format(i))
\ No newline at end of file
import torch
import random
import numbers
from torchvision.transforms import RandomCrop, RandomResizedCrop
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # NOTE(review): this module does not import PIL.Image or numpy at its top,
    # so `Image` / `np` below would raise NameError if called — confirm the
    # intended imports for this file.
    # Repeatedly halve with a box filter while the short side is still at
    # least twice the target, approximating antialiased downscaling.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Final bicubic resize so the short side equals image_size exactly.
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    # Crop the centered image_size x image_size window.
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
def crop(clip, i, j, h, w):
    """
    Crop a spatial window from a video clip.

    Args:
        clip (torch.tensor): video clip of shape (T, C, H, W).
        i, j: top-left corner of the window.
        h, w: window height and width.
    """
    if clip.dim() != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i:i + h, j:j + w]
def resize(clip, target_size, interpolation_mode):
    """Resize a (T, C, H, W) clip to target_size=(height, width)."""
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    resized = torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode, align_corners=False
    )
    return resized
def resize_scale(clip, target_size, interpolation_mode):
    """Scale a clip so its shorter spatial side equals target_size[0], keeping aspect ratio."""
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    height, width = clip.size(-2), clip.size(-1)
    # Scale factor is driven by the short edge and target height only.
    factor = target_size[0] / min(height, width)
    return torch.nn.functional.interpolate(clip, scale_factor=factor, mode=interpolation_mode, align_corners=False)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Spatially crop a video clip, then resize the crop.

    Args:
        clip (torch.tensor): video clip of shape (T, C, H, W).
        i (int): row of the upper-left corner of the crop.
        j (int): column of the upper-left corner of the crop.
        h (int): crop height.
        w (int): crop width.
        size (tuple(int, int)): (height, width) of the resized output.
    Returns:
        clip (torch.tensor): resized and cropped clip, shape (T, C, *size).
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    return resize(crop(clip, i, j, h, w), size, interpolation_mode)
def center_crop(clip, crop_size):
    """Crop the spatial center of a (T, C, H, W) clip to crop_size=(th, tw)."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")
    top = int(round((h - th) / 2.0))
    left = int(round((w - tw) / 2.0))
    return crop(clip, top, left, th, tw)
def center_crop_using_short_edge(clip):
    """Center crop a square whose side is the clip's shorter spatial edge."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    side = min(h, w)
    top = 0 if h <= w else int(round((h - side) / 2.0))
    left = 0 if w <= h else int(round((w - side) / 2.0))
    return crop(clip, top, left, side, side)
def random_shift_crop(clip):
    '''
    Slide along the long edge, with the short edge as crop size
    '''
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    short_edge = min(h, w)
    # Square crop of the short edge, positioned uniformly at random.
    top = torch.randint(0, h - short_edge + 1, size=(1,)).item()
    left = torch.randint(0, w - short_edge + 1, size=(1,)).item()
    return crop(clip, top, left, short_edge, short_edge)
def to_tensor(clip):
    """
    Convert a uint8 clip to float and scale values into [0, 1].

    Args:
        clip (torch.tensor, dtype=torch.uint8): shape (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): shape (T, C, H, W)
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    return clip.float().div(255.0)
def normalize(clip, mean, std, inplace=False):
    """
    Channel-wise normalize a video clip.

    Args:
        clip (torch.tensor): video clip to be normalized, (T, C, H, W).
        mean (tuple): pixel RGB mean, size (3).
        std (tuple): pixel standard deviation, size (3).
        inplace (bool): mutate `clip` instead of cloning it.
    Returns:
        normalized clip (torch.tensor): (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    target = clip if inplace else clip.clone()
    mean_t = torch.as_tensor(mean, dtype=target.dtype, device=target.device)
    std_t = torch.as_tensor(std, dtype=target.dtype, device=target.device)
    # NOTE(review): the [:, None, None, None] indexing broadcasts mean/std
    # over dim 0; for a (T, C, H, W) clip that is the time axis, which looks
    # like it actually expects (C, T, H, W) — confirm the caller's layout.
    target.sub_(mean_t[:, None, None, None]).div_(std_t[:, None, None, None])
    return target
def hflip(clip):
    """
    Flip a video clip along its width axis.

    Args:
        clip (torch.tensor): video clip, (T, C, H, W)
    Returns:
        flipped clip (torch.tensor): (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    return torch.flip(clip, (-1,))
class RandomCropVideo:
    """Randomly crop a (T, C, H, W) video clip to a fixed spatial size."""

    def __init__(self, size):
        # Accept an int for square crops or an explicit (h, w) pair.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip, (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped clip, (T, C, OH, OW)
        """
        top, left, height, width = self.get_params(clip)
        return crop(clip, top, left, height, width)

    def get_params(self, clip):
        """Pick a uniformly random (i, j, th, tw) crop box inside the clip."""
        h, w = clip.shape[-2:]
        th, tw = self.size
        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
        if w == tw and h == th:
            return 0, 0, h, w
        top = torch.randint(0, h - th + 1, size=(1,)).item()
        left = torch.randint(0, w - tw + 1, size=(1,)).item()
        return top, left, th, tw

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
class CenterCropResizeVideo:
    """Center crop on the short edge, then resize to the target size."""

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
        else:
            size = (size, size)
        self.size = size
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip, (T, C, H, W)
        Returns:
            torch.tensor: cropped and resized clip, (T, C, *self.size)
        """
        cropped = center_crop_using_short_edge(clip)
        return resize(cropped, target_size=self.size, interpolation_mode=self.interpolation_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class UCFCenterCropVideo:
    """Scale the short edge to the target size, then take a center crop."""

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
        else:
            size = (size, size)
        self.size = size
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip, (T, C, H, W)
        Returns:
            torch.tensor: scaled and center-cropped clip, (T, C, *self.size)
        """
        scaled = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return center_crop(scaled, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class KineticsRandomCropResizeVideo:
    """Random square crop of the short edge slid along the long edge, then resize to the target size."""

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
        else:
            size = (size, size)
        self.size = size
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """Randomly crop the short edge from the clip, then resize to self.size."""
        return resize(random_shift_crop(clip), self.size, self.interpolation_mode)
class CenterCropVideo:
    """Center crop a (T, C, H, W) clip to a fixed spatial size (no resizing)."""

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
        else:
            size = (size, size)
        self.size = size
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip, (T, C, H, W)
        Returns:
            torch.tensor: center-cropped clip, (T, C, *self.size)
        """
        return center_crop(clip, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class NormalizeVideo:
    """
    Normalize a video clip by mean subtraction and division by standard deviation.

    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """Normalize the given clip using the stored mean/std."""
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """Delegate to the module-level to_tensor helper."""
        return to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__
class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability.

    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): (T, C, H, W)
        Return:
            clip (torch.tensor): (T, C, H, W), flipped with probability self.p
        """
        flip_now = random.random() < self.p
        return hflip(clip) if flip_now else clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        # Latest start index that still (mostly) fits a full window.
        last_start = max(0, total_frames - self.size - 1)
        start = random.randint(0, last_start)
        stop = min(start + self.size, total_frames)
        return start, stop
if __name__ == '__main__':
    # Manual smoke test: read a sample UCF101 clip, run the transform
    # pipeline, and write out a video plus per-frame PNGs for inspection.
    from torchvision import transforms
    import torchvision.io as io
    import numpy as np
    from torchvision.utils import save_image
    import os
    vframes, aframes, info = io.read_video(
        filename='./v_Archery_g01_c03.avi',
        pts_unit='sec',
        output_format='TCHW'
    )
    trans = transforms.Compose([
        ToTensorVideo(),
        RandomHorizontalFlipVideo(),
        UCFCenterCropVideo(512),
        # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    target_video_len = 32
    frame_interval = 1
    total_frames = len(vframes)
    print(total_frames)
    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
    # Sampling video frames
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    # print(start_frame_ind)
    # print(end_frame_ind)
    assert end_frame_ind - start_frame_ind >= target_video_len
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
    print(frame_indice)
    select_vframes = vframes[frame_indice]
    print(select_vframes.shape)
    print(select_vframes.dtype)
    select_vframes_trans = trans(select_vframes)
    print(select_vframes_trans.shape)
    print(select_vframes_trans.dtype)
    # Undo the [-1, 1] normalization back to uint8 for writing.
    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
    print(select_vframes_trans_int.dtype)
    print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
    io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
    for i in range(target_video_len):
        save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))
\ No newline at end of file
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000
):
    """Build a SpacedDiffusion from high-level configuration flags.

    Defaults give an epsilon-predicting, learned-sigma, MSE-loss diffusion
    of 1000 steps, following the OpenAI ADM/IDDPM conventions.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
    # Loss selection: the KL variants take precedence over plain MSE.
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    if predict_xstart:
        mean_type = gd.ModelMeanType.START_X
    else:
        mean_type = gd.ModelMeanType.EPSILON
    # Variance: a learned range wins; otherwise pick a fixed schedule.
    if learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE
    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type
    )
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
import torch as th
import numpy as np
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    reference = next(
        (obj for obj in (mean1, logvar1, mean2, logvar2) if isinstance(obj, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # Promote scalar log-variances to tensors so th.exp works on them;
    # broadcasting alone does not help with plain Python floats.
    logvar1, logvar2 = (
        x if isinstance(x, th.Tensor) else th.tensor(x).to(reference)
        for x in (logvar1, logvar2)
    )

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )
def approx_standard_normal_cdf(x):
    """
    A fast tanh-based approximation of the cumulative distribution function
    of the standard normal.
    """
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))
    return 0.5 * (1.0 + th.tanh(inner))
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.
    :param x: the targets
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    # Standardize, then evaluate under a unit normal.
    z = (x - means) * th.exp(-log_scales)
    std_normal = th.distributions.Normal(th.zeros_like(x), th.ones_like(x))
    return std_normal.log_prob(z)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    inv_stdv = th.exp(-log_scales)
    centered = x - means
    # CDF evaluated half a pixel (1/255) above and below the target value.
    cdf_plus = approx_standard_normal_cdf(inv_stdv * (centered + 1.0 / 255.0))
    cdf_min = approx_standard_normal_cdf(inv_stdv * (centered - 1.0 / 255.0))
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    log_cdf_delta = th.log((cdf_plus - cdf_min).clamp(min=1e-12))
    # The extreme bins absorb the full tail mass on each side.
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, log_cdf_delta),
    )
    assert log_probs.shape == x.shape
    return log_probs
# [diff collapsed in the review UI: gaussian_diffusion.py omitted from this view]
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
import torch
import numpy as np
import torch as th
from .gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            # Search for an integer stride that yields exactly desired_count steps.
            for stride in range(1, num_timesteps):
                if len(range(0, num_timesteps, stride)) == desired_count:
                    return set(range(0, num_timesteps, stride))
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per, extra = divmod(num_timesteps, len(section_counts))
    all_steps = []
    start_idx = 0
    for i, section_count in enumerate(section_counts):
        # The first `extra` sections take one additional original step.
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # Fractional stride distributes the picks evenly across the section.
        frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
        cur = 0.0
        for _ in range(section_count):
            all_steps.append(start_idx + round(cur))
            cur += frac_stride
        start_idx += size
    return set(all_steps)
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """
    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        # Maps spaced-chain index -> original-chain timestep.
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])
        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                # Recover the per-step beta of the shortened chain: products
                # of (1 - beta) telescope, so the ratio of consecutive
                # retained alpha_cumprod values gives the new (1 - beta).
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)
    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        # Wrap the model so it sees original-chain timesteps.
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
    # @torch.compile
    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)
    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
    def _wrap_model(self, model):
        # Idempotent: an already-wrapped model is returned unchanged.
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.original_num_steps
        )
    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t
class _WrappedModel:
def __init__(self, model, timestep_map, original_num_steps):
self.model = model
self.timestep_map = timestep_map
# self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
# if self.rescale_timesteps:
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, new_ts, **kwargs)
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from abc import ABC, abstractmethod
import numpy as np
import torch as th
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
elif name == "loss-second-moment":
return LossSecondMomentResampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.
    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.
        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.
        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        raw = self.weights()
        probs = raw / np.sum(raw)
        drawn = np.random.choice(len(probs), size=(batch_size,), p=probs)
        # 1 / (N * p) keeps the importance-sampled loss estimate unbiased.
        inv_weights = 1 / (len(probs) * probs[drawn])
        indices = th.from_numpy(drawn).long().to(device)
        weights = th.from_numpy(inv_weights).float().to(device)
        return indices, weights
class UniformSampler(ScheduleSampler):
    """Uniform distribution over all diffusion timesteps."""

    def __init__(self, diffusion):
        self.diffusion = diffusion
        # Constant weights: every timestep is equally likely.
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights
class LossAwareSampler(ScheduleSampler):
    # Schedule sampler whose weights are updated from observed training losses,
    # synchronized across distributed ranks.
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.
        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.
        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        # Gather every rank's batch size first, because batches may differ
        # in size and all_gather requires equally-shaped tensors.
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )
        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)
        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        # Strip the padding using each rank's true batch size.
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)
    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.
        Sub-classes should override this method to update the reweighting
        using losses from the model.
        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.
        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """
class LossSecondMomentResampler(LossAwareSampler):
    """Importance-samples timesteps proportionally to the RMS of recent losses.

    Keeps a rolling history of the last ``history_per_term`` losses for each
    timestep. Until every timestep has a full history, sampling stays uniform;
    afterwards, a small ``uniform_prob`` mass is mixed in so no timestep ever
    has zero probability of being drawn.
    """

    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # use an explicit integer dtype instead.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)

    def weights(self):
        """Return per-timestep sampling weights (uniform until warmed up)."""
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        # Second moment (RMS) of each timestep's recent losses.
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        # Blend in a little uniform probability so every timestep keeps
        # receiving fresh loss samples.
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        """Record one loss per timestep, evicting the oldest entry when full."""
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        # Warmed up once every timestep has accumulated a full loss history.
        return (self._loss_counts == self.history_per_term).all()
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import os
def find_model(model_name):
    """Load a Latte checkpoint state dict from a local file path.

    Prefers the EMA weights (``"ema"`` key, written by train.py) when present,
    otherwise falls back to the raw ``"model"`` weights.

    :param model_name: path to a checkpoint file on disk.
    :return: the selected state dict.
    """
    assert os.path.isfile(model_name), f'Could not find Latte checkpoint at {model_name}'
    # map_location keeps tensors on CPU regardless of where they were saved.
    state = torch.load(model_name, map_location=lambda storage, loc: storage)
    has_ema = "ema" in state
    print('Using Ema!' if has_ema else 'Using model!')
    return state["ema"] if has_ema else state['model']
\ No newline at end of file
name: latte
channels:
- pytorch
- nvidia
dependencies:
- python >= 3.10
- pytorch > 2.0.0
- torchvision
- pytorch-cuda=11.7
- pip:
- timm
- diffusers[torch]
- accelerate
- tensorboard
- einops
- transformers
- av
- scikit-image
- decord
- pandas
import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
from .latte import Latte_models
from .latte_img import LatteIMG_models
from .latte_t2v import LatteT2V
from torch.optim.lr_scheduler import LambdaLR
def customized_lr_scheduler(optimizer, warmup_steps=5000):  # 5000 from u-vit
    """Linear warmup schedule: ramps the LR factor from 0 to 1, then holds.

    :param optimizer: the optimizer whose learning rate is scheduled.
    :param warmup_steps: number of steps over which to ramp; <= 0 disables
        warmup (the factor is a constant 1).
    :return: a ``LambdaLR`` scheduler applying the warmup factor.
    """
    # NOTE: the redundant local re-import of LambdaLR was removed; the
    # module-level import already provides it.
    def fn(step):
        # Factor grows linearly to 1 over warmup_steps, then stays at 1.
        if warmup_steps > 0:
            return min(step / warmup_steps, 1)
        return 1

    return LambdaLR(optimizer, fn)
def get_lr_scheduler(optimizer, name, **kwargs):
    """Build a learning-rate scheduler by name.

    :param optimizer: optimizer to schedule.
    :param name: ``'warmup'`` (linear warmup) or ``'cosine'``
        (``CosineAnnealingLR``).
    :param kwargs: forwarded to the chosen scheduler's constructor.
    :raises NotImplementedError: for any other name.
    """
    if name == 'warmup':
        return customized_lr_scheduler(optimizer, **kwargs)
    if name == 'cosine':
        from torch.optim.lr_scheduler import CosineAnnealingLR
        return CosineAnnealingLR(optimizer, **kwargs)
    raise NotImplementedError(name)
def get_models(args):
    """Instantiate the Latte variant named by ``args.model``.

    Dispatches on substring match: 'LatteIMG' -> image-conditioned models,
    'LatteT2V' -> text-to-video transformer loaded from a pretrained path,
    plain 'Latte' -> the base video models.

    :param args: namespace with ``model`` plus the fields the chosen branch
        reads (``latent_size``, ``num_classes``, ``num_frames``,
        ``learn_sigma``, ``extras``, or ``pretrained_model_path``).
    :raises NotImplementedError: if ``args.model`` matches no known family.
    """
    # Order matters: check the more specific names before the bare 'Latte'.
    if 'LatteIMG' in args.model:
        return LatteIMG_models[args.model](
            input_size=args.latent_size,
            num_classes=args.num_classes,
            num_frames=args.num_frames,
            learn_sigma=args.learn_sigma,
            extras=args.extras
        )
    elif 'LatteT2V' in args.model:
        pretrained_model_path = args.pretrained_model_path
        return LatteT2V.from_pretrained_2d(pretrained_model_path, subfolder="transformer")
    elif 'Latte' in args.model:
        return Latte_models[args.model](
            input_size=args.latent_size,
            num_classes=args.num_classes,
            num_frames=args.num_frames,
            learn_sigma=args.learn_sigma,
            extras=args.extras
        )
    else:
        # BUG FIX: `raise '<str>'` is a TypeError in Python 3 (exceptions must
        # derive from BaseException). Raise a real exception instead, matching
        # get_lr_scheduler's error style.
        raise NotImplementedError('{} Model Not Supported!'.format(args.model))
\ No newline at end of file
import numpy
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
import transformers
transformers.logging.set_verbosity_error()
"""
Will encounter following warning:
- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
https://github.com/CompVis/stable-diffusion/issues/97
according to this issue, this warning is safe.
This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
You can safely ignore the warning, it is not an error.
This clip usage is from U-ViT and same with Stable Diffusion.
"""
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders.

    Subclasses must override :meth:`encode`.
    """

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        # Abstract: concrete encoders implement this.
        raise NotImplementedError
class FrozenCLIPEmbedder(AbstractEncoder):
    """Text encoder wrapping Hugging Face's CLIP text transformer, with all
    weights frozen at construction time."""

    def __init__(self, path, device="cuda", max_length=77):
        super().__init__()
        # Tokenizer and text encoder come from the same pretrained directory.
        self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
        self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        """Switch the transformer to eval mode and disable all gradients."""
        self.transformer = self.transformer.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        # Tokenize with fixed-length padding/truncation so batches stack.
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        out = self.transformer(input_ids=input_ids)
        # Per-token hidden states plus the pooled (EOS-token) embedding.
        return out.last_hidden_state, out.pooler_output

    def encode(self, text):
        return self(text)
class TextEmbedder(nn.Module):
    """
    Embeds text prompt into vector representations. Also handles text dropout
    for classifier-free guidance.
    """

    def __init__(self, path, dropout_prob=0.1):
        super().__init__()
        # NOTE: attribute name keeps the original spelling ("encodder") for
        # compatibility with existing checkpoints and callers.
        self.text_encodder = FrozenCLIPEmbedder(path=path)
        self.dropout_prob = dropout_prob

    def token_drop(self, text_prompts, force_drop_ids=None):
        """
        Drops text to enable classifier-free guidance.

        Replaces a subset of prompts with the empty string — randomly with
        probability ``dropout_prob``, or exactly where ``force_drop_ids == 1``.
        """
        if force_drop_ids is not None:
            # TODO
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
        return list(numpy.where(drop_mask, "", text_prompts))

    def forward(self, text_prompts, train, force_drop_ids=None):
        # Apply prompt dropout during training (or whenever drops are forced).
        if (train and self.dropout_prob > 0) or (force_drop_ids is not None):
            text_prompts = self.token_drop(text_prompts, force_drop_ids)
        embeddings, pooled_embeddings = self.text_encodder(text_prompts)
        # Only the pooled embedding is returned to callers.
        return pooled_embeddings
if __name__ == '__main__':
    r"""
    Returns:
        Examples from CLIPTextModel:

    ```python
    >>> from transformers import AutoTokenizer, CLIPTextModel
    >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
    >>> outputs = model(**inputs)
    >>> last_hidden_state = outputs.last_hidden_state
    >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
    ```"""
    # Manual smoke test: build a TextEmbedder and print the embedding shapes.
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): hardcoded local checkpoint path — only runs where this
    # stable-diffusion directory exists.
    text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
                                dropout_prob=0.00001).to(device)

    text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
    # text_prompt = ('None', 'None', 'None')
    # NOTE(review): TextEmbedder.forward returns only the pooled embeddings,
    # yet this unpacks a 2-tuple — looks stale; confirm against forward().
    output, pooled_output = text_encoder(text_prompts=text_prompt, train=False)
    # print(output)
    print(output.shape)
    print(pooled_output.shape)
    # print(output.shape)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment