Commit 1da75ff3 authored by mashun1

hyi2v
[English](./README.md)
# Hunyuan Video Feature Extraction Tool
This project provides an efficient tool for extracting latent features from videos, preparing them for downstream video generation and processing tasks.
## Features
- Supports a wide range of video formats and resolutions
- Multi-GPU parallel processing for higher throughput
- Multiple aspect ratios
- High-performance VAE model for feature extraction
- Automatically skips already-processed videos, so interrupted runs can be resumed
## Usage
### 1. Configuration
#### Input Dataset Format
The input video metadata file (meta_file.list) is a list of JSON file paths, one per line; each JSON file contains the fields shown below.
The format of meta_file.list (for example, ./assets/demo/i2v_lora/train_dataset/meta_file.list) is:
```
/path/to/0.json
/path/to/1.json
/path/to/2.json
...
```
The format of each JSON file such as /path/to/0.json (for example, ./assets/demo/i2v_lora/train_dataset/meta_data.json) is:
```json
{
"video_path": "/path/to/video.mp4",
"raw_caption": {
"long caption": "视频的详细描述文本"
}
}
```
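For reference, here is a minimal sketch (all paths and captions below are placeholders) that writes one metadata JSON per video and collects the paths into a meta_file.list:
```python
import json
from pathlib import Path

# Placeholder inputs: map each video file to its long caption.
videos = {
    "/path/to/video_0.mp4": "A detailed text description of video 0",
    "/path/to/video_1.mp4": "A detailed text description of video 1",
}

meta_dir = Path("/path/to/train_dataset")  # placeholder output location
meta_dir.mkdir(parents=True, exist_ok=True)

json_paths = []
for i, (video_path, caption) in enumerate(videos.items()):
    item = {"video_path": video_path, "raw_caption": {"long caption": caption}}
    json_path = meta_dir / f"{i}.json"
    json_path.write_text(json.dumps(item, ensure_ascii=False), encoding="utf-8")
    json_paths.append(str(json_path))

# meta_file.list: one JSON path per line, as expected by video_url_files.
(meta_dir / "meta_file.list").write_text("\n".join(json_paths) + "\n", encoding="utf-8")
```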
Configure the parameters in `hyvideo/hyvae_extract/vae.yaml`:
```yaml
vae_path: "./ckpts/hunyuan-video-i2v-720p/vae" # VAE模型路径
video_url_files: "/path/to/meta_file.list" # 视频元数据文件列表
output_base_dir: "/path/to/output/directory" # 输出目录
sample_n_frames: 129 # 采样帧数
target_size: # 目标尺寸
- bucket_size
- bucket_size
enable_multi_aspect_ratio: True # 启用多种宽高比
use_stride: True # 使用步长采样
```
#### Resolution Bucket Size Reference
The `target_size` parameter defines the resolution bucket size. Recommended values for different quality levels:
| Quality | Bucket size | Typical resolution |
|---------|-------------|--------------------|
| 720p | 960 | 1280×720 or similar |
| 540p | 720 | 960×540 or similar |
| 360p | 480 | 640×360 or similar |
When `enable_multi_aspect_ratio` is set to `True`, these bucket sizes are used as the base for generating buckets with multiple aspect ratios, as sketched below. For best performance, choose a bucket size that balances quality and memory usage on your hardware.
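For intuition about the buckets, the sketch below reproduces the `generate_crop_size_list` helper from `hyvideo/hyvae_extract/dataset.py` and prints the (width, height) buckets built for a 360p base size of 480 (base sizes of 540 and above use `patch_size=32` in the dataset code):
```python
def generate_crop_size_list(base_size=480, patch_size=16, max_ratio=4.0):
    # Walk along wp * hp ≈ (base_size / patch_size) ** 2, keeping every pair whose
    # aspect ratio stays within max_ratio; each bucket has roughly base_size**2 pixels.
    num_patches = round((base_size / patch_size) ** 2)
    crop_size_list, wp, hp = [], num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))  # (width, height)
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list

buckets = generate_crop_size_list(base_size=480)
print(f"{len(buckets)} buckets, e.g. {buckets[0]} ... {buckets[len(buckets) // 2]}")
```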
### 2. Run the Extraction
```bash
# Set environment variables
export HOST_GPU_NUM=8 # Number of GPUs to use
# Run the extraction script
cd HunyuanVideo-I2V
bash hyvideo/hyvae_extract/start.sh
```
### 3. Single-GPU Run
```bash
cd HunyuanVideo-I2V
export PYTHONPATH=${PYTHONPATH}:`pwd`
export HOST_GPU_NUM=1
CUDA_VISIBLE_DEVICES=0 python3 -u hyvideo/hyvae_extract/run.py --local_rank 0 --config 'hyvideo/hyvae_extract/vae.yaml'
```
## Output Files
The program generates the following files in the specified output directory:
1. `{video_id}.npy` - The latent feature array of the video
2. `json_path/{video_id}.json` - A JSON file with the video's metadata, including:
   - video_id: the video ID
   - latent_shape: the shape of the latent features
   - video_path: the original video path
   - prompt: the video description / prompt
   - npy_save_path: the path where the latent features are saved
```
output_base_dir/
├── {video_id_1}.npy # Latent feature array of video 1
├── {video_id_2}.npy # Latent feature array of video 2
├── {video_id_3}.npy # Latent feature array of video 3
│ ...
├── {video_id_n}.npy # Latent feature array of video n
└── json_path/ # Directory containing the metadata JSON files
    ├── {video_id_1}.json # Metadata of video 1
    ├── {video_id_2}.json # Metadata of video 2
    ├── {video_id_3}.json # Metadata of video 3
    │ ...
    └── {video_id_n}.json # Metadata of video n
```
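A minimal sketch (the directory and video ID below are placeholders) for loading one extracted latent together with its metadata:
```python
import json
from pathlib import Path

import numpy as np

output_base_dir = Path("/path/to/output/directory")  # same value as output_base_dir in vae.yaml
video_id = "video_id_1"                               # placeholder video ID

meta = json.loads((output_base_dir / "json_path" / f"{video_id}.json").read_text(encoding="utf-8"))
latent = np.load(meta["npy_save_path"])

print(meta["video_path"], meta["prompt"][:80])
print(latent.shape, meta["latent_shape"])  # the saved array shape should match latent_shape
```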
## Advanced Configuration
### Multi-Aspect-Ratio Processing
When `enable_multi_aspect_ratio` is set to `True`, the system selects the target size whose aspect ratio is closest to the video's original aspect ratio instead of force-cropping to a fixed size. This helps preserve the integrity of the video content.
### Stride Sampling
When `use_stride` is set to `True`, the sampling stride is adjusted automatically based on the video's frame rate (see the sketch after this list):
- stride 2 when the frame rate is >= 50 fps
- stride 1 when the frame rate is < 50 fps
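A minimal sketch of this rule (the frame rate is a placeholder), mirroring the frame-index selection in `hyvideo/hyvae_extract/dataset.py`:
```python
fps = 59.94                    # placeholder frame rate of the source video
sample_n_frames = 129          # frames kept after striding, as in vae.yaml

stride = 2 if int(fps) >= 50 else 1
frame_indices = list(range(0, sample_n_frames * stride, stride))

print(stride, len(frame_indices), frame_indices[:5])  # 2 129 [0, 2, 4, 6, 8]
```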
from typing import Tuple, List
from decord import VideoReader
import urllib
import io
import os
import csv
import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset
import torchvision.transforms as transforms
from torchvision.transforms.functional import crop
from pathlib import Path
import sys
import json
def split_video_urls(meta_files: str, global_rank: int, world_size: int):
with open(meta_files, 'r') as f:
    meta_paths = [line.strip() for line in f if line.strip()]
num_videos = len(meta_paths)
num_videos_per_rank = num_videos // world_size
remainder = num_videos % world_size
# Calculate start and end indices
start = num_videos_per_rank * global_rank + min(global_rank, remainder)
end = start + num_videos_per_rank + (1 if global_rank < remainder else 0)
return start, end, meta_paths[start:end]
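# Example (hypothetical sizes): with 100 meta paths and world_size=8, ranks 0-3 each take
# 13 consecutive paths and ranks 4-7 take 12, so every path is assigned exactly once.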
class MultiBucketDataset(IterableDataset):
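    """Wraps a map-style dataset and yields lists of samples grouped by identical (T, H, W) shape.

    Samples are buffered per shape bucket; a bucket is yielded once it reaches batch_size,
    oversized buffers are flushed early, and partially filled buckets are emitted at the end.
    """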
def __init__(self, source: Dataset, batch_size: int, max_buf = 64):
super().__init__()
self.source = source
self.batch_size = batch_size
self.buffer = {}
self.max_buf = max_buf
self.size = 0
@staticmethod
def collate_fn(samples):
pixel_values = torch.stack([sample["pixel_values"] for sample in samples]).contiguous()
batch = {
    "pixel_values": pixel_values,
    "videoid": [sample["videoid"] for sample in samples],
    # video_path and prompt are needed downstream when the metadata JSON files are written
    "video_path": [sample["video_path"] for sample in samples],
    "prompt": [sample["prompt"] for sample in samples],
    "valid": [sample["valid"] for sample in samples],
}
return batch
def __iter__(self):
# split dataset
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
iter_start = 0
iter_end = len(self.source)
else:
worker_id = int(worker_info.id)
per_worker = len(self.source) // int(worker_info.num_workers)
per_worker += int(worker_id < len(self.source) % int(worker_info.num_workers))
if worker_id >= len(self.source) % int(worker_info.num_workers):
iter_start = worker_id * per_worker + len(self.source) % int(worker_info.num_workers)
else:
iter_start = worker_id * per_worker
iter_end = iter_start + per_worker
# bucketing
for i in range(iter_start, iter_end):
sample = self.source[i]
if sample["valid"] is False:
continue
T, C, H, W = sample["pixel_values"].shape
if (T, H, W) not in self.buffer:
self.buffer[(T, H, W)] = []
self.buffer[(T, H, W)].append(sample)
self.size += 1
if len(self.buffer[(T, H, W)]) == self.batch_size:
yield self.buffer[(T, H, W)]
self.size -= self.batch_size
self.buffer[(T, H, W)] = []
if self.size > self.max_buf and (len(self.buffer[(T, H, W)]) > 0):
self.size -= len(self.buffer[(T, H, W)])
yield self.buffer[(T, H, W)]
self.buffer[(T, H, W)] = []
# yield the remaining batch
for bucket, samples in self.buffer.items():
if len(samples) > 0:
yield samples
class VideoDataset(Dataset):
def __init__(
self,
meta_files: List[str],
latent_cache_dir: str,
sample_size: Tuple[int, int],
sample_n_frames: int,
is_center_crop: bool = True,
enable_multi_aspect_ratio: bool = False,
vae_time_compression_ratio: int = 4,
use_stride: bool = False,
):
if not Path(latent_cache_dir).exists():
Path(latent_cache_dir).mkdir(parents=True, exist_ok=True)
self.latent_cache_dir = latent_cache_dir
self.sample_n_frames = sample_n_frames
self.sample_size = tuple(sample_size)
self.is_center_crop = is_center_crop
self.vae_time_compression_ratio = vae_time_compression_ratio
self.enable_multi_aspect_ratio = enable_multi_aspect_ratio
self.dataset = meta_files
self.length = len(self.dataset)
self.use_stride = use_stride
# multi-aspect-ratio buckets
if enable_multi_aspect_ratio:
assert self.sample_size[0] == self.sample_size[1]
if self.sample_size[0] < 540:
self.buckets = self.generate_crop_size_list(base_size=self.sample_size[0])
else:
self.buckets = self.generate_crop_size_list(base_size=self.sample_size[0], patch_size=32)
self.aspect_ratios = np.array([float(w) / float(h) for w, h in self.buckets])
print(f"Multi-aspect-ratio bucket num: {len(self.buckets)}")
# image preprocess
if not enable_multi_aspect_ratio:
self.train_crop = transforms.CenterCrop(self.sample_size) if self.is_center_crop else transforms.RandomCrop(self.sample_size)
def request_ceph_data(self, path):
try:
video_reader = VideoReader(path)
except Exception as e:
print(f"Error: {e}")
raise
return video_reader
def preprocess_url(self, data_json_path):
with open(data_json_path, "r") as f:
data_dict = json.load(f)
video_path = data_dict['video_path']
video_id = video_path.split('/')[-1].split('.')[0]
prompt = data_dict['raw_caption']["long caption"]
item = {"video_path": video_path, "videoid": video_id, "prompt": prompt}
return item
def get_item(self, idx):
# Create Video Reader
data_json_path = self.dataset[idx]
video_item = self.preprocess_url(data_json_path)
# Skip if exists
latent_save_path = Path(self.latent_cache_dir) / f"{video_item['videoid']}.npy"
if latent_save_path.exists():
return None, None, None, None, False  # keep the 5-tuple shape expected by __getitem__
video_reader = self.request_ceph_data(video_item["video_path"])
fps = video_reader.get_avg_fps()
# Sample every other frame for high-frame-rate (>= 50 fps) videos when stride sampling is enabled
stride = 2 if self.use_stride and int(fps) >= 50 else 1
video_length = len(video_reader)
if video_length < self.sample_n_frames*stride:
sample_n_frames = video_length - (video_length - 1) % (self.vae_time_compression_ratio * stride)  # largest span of the form (ratio*stride*n + 1), so 4n+1 frames remain after striding
else:
sample_n_frames = self.sample_n_frames*stride
start_idx = 0
batch_index = list(range(start_idx, start_idx + sample_n_frames, stride))
if len(batch_index) == 0:
print("get video len=0, skip")
return None, None, None, None, False
# Read frames
try:
video_images = video_reader.get_batch(batch_index).asnumpy()
except Exception as e:
print(f'Error: {e}, video_path: {video_item["video_path"]}')
raise
pixel_values = torch.from_numpy(video_images).permute(0, 3, 1, 2).contiguous()
del video_reader
return pixel_values, video_item["videoid"], video_item["video_path"], video_item["prompt"], True
def preprocess_train(self, frames):
height, width = frames.shape[-2:]
# Resize & Crop
if self.enable_multi_aspect_ratio:
bw, bh = self.get_closest_ratio(width=width, height=height, ratios=self.aspect_ratios, buckets=self.buckets)
sample_size = bh, bw
target_size = self.get_target_size(frames, sample_size)
train_crop = transforms.CenterCrop(sample_size) if self.is_center_crop else transforms.RandomCrop(sample_size)
else:
sample_size = self.sample_size
target_size = self.get_target_size(frames, sample_size)
train_crop = self.train_crop
frames = transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(frames)
if self.is_center_crop:
    frames = train_crop(frames)
else:
    y1, x1, h, w = train_crop.get_params(frames, sample_size)
    frames = crop(frames, y1, x1, h, w)
return frames
@staticmethod
def get_closest_ratio(width: float, height: float, ratios: list, buckets: list):
aspect_ratio = float(width) / float(height)
closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
return buckets[closest_ratio_id]
@staticmethod
def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
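# The walk above follows wp * hp ≈ num_patches, so every generated crop keeps roughly
# base_size**2 pixels while the aspect ratio varies up to max_ratio (for base_size=256,
# patch_size=16 this includes e.g. (512, 128) and (256, 256)).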
def get_target_size(self, frames, target_size):
T, C, H, W = frames.shape
th, tw = target_size
r = max(th / H, tw / W)
target_size = int(H * r), int(W * r)
return target_size
def __len__(self):
return self.length
def __getitem__(self, idx):
try:
pixel, videoid, video_path, prompt, valid = self.get_item(idx)
if pixel is not None and valid:
pixel = self.preprocess_train(pixel)
sample = dict(pixel_values=pixel, videoid=videoid, video_path=video_path, prompt=prompt, valid=valid)
return sample
except Exception as e:
print(e)
return dict(pixel_values=None, videoid=None, video_path=None, prompt=None, valid=False)
from typing import Tuple, List, Dict
import sys
from pathlib import Path
import argparse
import time
import os
import traceback
import random
import numpy as np
from einops import rearrange
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import VideoDataset, MultiBucketDataset, split_video_urls
import json
import glob
from omegaconf import OmegaConf
from hyvideo.vae import load_vae
DEVICE = "cuda"
DTYPE = torch.float16
def seed_everything(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
@torch.no_grad()
def extract(
vae: torch.nn.Module,
meta_files: List[str],
output_base_dir: str,
sample_n_frames: int,
target_size: Tuple[int, int],
enable_multi_aspect_ratio: bool = False,
use_stride: bool = False,
batch_size=None,
):
dataset = VideoDataset(
meta_files=meta_files,
latent_cache_dir=output_base_dir,
sample_size=target_size,
sample_n_frames=sample_n_frames,
is_center_crop=True,
enable_multi_aspect_ratio=enable_multi_aspect_ratio,
vae_time_compression_ratio=vae.time_compression_ratio,
use_stride=use_stride
)
if batch_size is not None:
dataset = MultiBucketDataset(dataset, batch_size=batch_size)
dataloader = DataLoader(
dataset,
batch_size=None,
collate_fn=dataset.collate_fn if batch_size is not None else None,
shuffle=False,
num_workers=8,
prefetch_factor=4,
pin_memory=False,
)
normalize_fn = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
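# Together with the /255. scaling below, this maps uint8 pixels to [-1, 1] before VAE encoding.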
save_json_path = Path(output_base_dir) / "json_path"
if not os.path.exists(save_json_path):
os.makedirs(save_json_path, exist_ok=True)
for i, item in enumerate(dataloader):
print(f"processing video latent extraction {i}")
if batch_size is None:
if item.get("valid", True) is False:
continue
item["videoid"] = [item["videoid"]]
item["valid"] = [item["valid"]]
item["prompt"] = [item["prompt"]]
try:
pixel_values = item["pixel_values"]
pixel_values = pixel_values.to(device=vae.device, dtype=vae.dtype)
pixel_values = pixel_values / 255.
pixel_values = normalize_fn(pixel_values)
if pixel_values.ndim == 4:
pixel_values = pixel_values.unsqueeze(0)
pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
z = vae.encode(pixel_values).latent_dist.mode()
z = z.detach().to(DTYPE).cpu().numpy()
assert z.shape[0] == len(item["videoid"])
for k in range(z.shape[0]):
save_path = Path(output_base_dir) / f"{item['videoid'][k]}.npy"
np.save(save_path, z[k][None, ...])
data = {"video_id": item["videoid"][k],
"latent_shape": z[k][None,...].shape,
"video_path": item["video_path"][k],
"prompt": item["prompt"][k],
"npy_save_path": str(save_path)}
with open(save_json_path / f"{item['videoid'][k]}.json", "w", encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False)
except Exception as e:
traceback.print_exc()
def main(
local_rank: int,
vae_path: str,
meta_files: str,
output_base_dir: str,
sample_n_frames: int,
target_size: Tuple[int, int],
enable_multi_aspect_ratio: bool = False,
use_stride: bool = False,
seed: int = 42,
):
seed_everything(seed)
global_rank = local_rank
world_size = int(os.environ["HOST_GPU_NUM"])
print(f"split video urls")
start, end, meta_files = split_video_urls(meta_files, global_rank, world_size)
print(f"Load VAE")
vae, vae_path, spatial_compression_ratio, time_compression_ratio = load_vae(
vae_type="884-16c-hy",
vae_precision='fp16',
vae_path=vae_path,
device=DEVICE,
)
# vae.enable_temporal_tiling()
vae.enable_spatial_tiling()
vae.eval()
print(f"processing video latent extraction")
extract(vae, meta_files, output_base_dir, sample_n_frames, target_size, enable_multi_aspect_ratio, use_stride)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, required=True)
parser.add_argument("--config", default='./vae.yaml', type=str)
args = parser.parse_args()
config = OmegaConf.load(args.config)
vae_path = config.vae_path
sample_n_frames = config.sample_n_frames
target_size = [config.target_size[0], config.target_size[1]]
enable_multi_aspect_ratio = config.enable_multi_aspect_ratio
output_base_dir = config.output_base_dir
use_stride = config.use_stride
meta_files = config.video_url_files
main(args.local_rank, vae_path, meta_files, output_base_dir, sample_n_frames, target_size, enable_multi_aspect_ratio, use_stride)
export PYTHONPATH=${PYTHONPATH}:`pwd`
for ((i=0;i<$HOST_GPU_NUM;++i)); do
CUDA_VISIBLE_DEVICES=$i python3 -u hyvideo/hyvae_extract/run.py --local_rank $i --config 'hyvideo/hyvae_extract/vae.yaml'&
done
# CUDA_VISIBLE_DEVICES=0 python3 -u hyvideo/hyvae_extract/run.py --local_rank 0 --config 'hyvideo/hyvae_extract/vae.yaml'&
wait
echo "Finished."
vae_path: "./ckpts/hunyuan-video-i2v-720p/vae"
video_url_files: "/path/to/meta_file.list"
output_base_dir: "/path/to/output/directory"
sample_n_frames: 129
target_size:
- 480
- 480
enable_multi_aspect_ratio: True
use_stride: True
from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
def load_model(args, in_channels, out_channels, factor_kwargs):
"""load hunyuan video model
Args:
args (dict): model args
in_channels (int): input channels number
out_channels (int): output channels number
factor_kwargs (dict): factor kwargs
Returns:
model (nn.Module): The hunyuan video model
"""
if args.model in HUNYUAN_VIDEO_CONFIG.keys():
model = HYVideoDiffusionTransformer(
args,
in_channels=in_channels,
out_channels=out_channels,
**HUNYUAN_VIDEO_CONFIG[args.model],
**factor_kwargs,
)
return model
else:
raise NotImplementedError()
import torch.nn as nn
def get_activation_layer(act_type):
"""get activation layer
Args:
act_type (str): the activation type
Returns:
Callable[[], nn.Module]: a factory that builds the activation module
"""
if act_type == "gelu":
return lambda: nn.GELU()
elif act_type == "gelu_tanh":
# Approximate `tanh` requires torch >= 1.13
return lambda: nn.GELU(approximate="tanh")
elif act_type == "relu":
return nn.ReLU
elif act_type == "silu":
return nn.SiLU
else:
raise ValueError(f"Unknown activation type: {act_type}")
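# Usage sketch: act_layer = get_activation_layer("gelu_tanh"); act = act_layer()
# builds an nn.GELU(approximate="tanh") module that can be dropped into an MLP.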
import importlib.metadata
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
flash_attn = None
flash_attn_varlen_func = None
_flash_attn_forward = None
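# Per-mode (pre, post) layout adapters: "flash" flattens (b, s, a, d) into ((b*s), a, d)
# for the varlen flash-attention kernel, while "torch" and "vanilla" transpose to
# (b, a, s, d) for SDPA-style attention and transpose back afterwards.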
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def get_cu_seqlens(text_mask, img_len):
"""Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
Args:
text_mask (torch.Tensor): the mask of text
img_len (int): the number of image (video) tokens per sample
Returns:
torch.Tensor: the calculated cu_seqlens for flash attention
"""
batch_size = text_mask.shape[0]
text_len = text_mask.sum(dim=1)
max_len = text_mask.shape[1] + img_len
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
for i in range(batch_size):
s = text_len[i] + img_len
s1 = i * max_len + s
s2 = (i + 1) * max_len
cu_seqlens[2 * i + 1] = s1
cu_seqlens[2 * i + 2] = s2
return cu_seqlens
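# Example (hypothetical sizes): batch_size=1, a text mask with 20 valid tokens and img_len=100
# gives cu_seqlens = [0, 120, text_mask.shape[1] + 100]; the middle entry ends the valid
# image+text tokens and the last entry ends the padded sequence.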
def attention(
q,
k,
v,
mode="flash",
drop_rate=0,
attn_mask=None,
causal=False,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
batch_size=1,
):
"""
Perform QKV self attention.
Args:
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
drop_rate (float): Dropout rate in attention map. (default: 0)
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
(default: None)
causal (bool): Whether to use causal attention. (default: False)
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into q.
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into kv.
max_seqlen_q (int): The maximum sequence length in the batch of q.
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
Returns:
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
"""
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
q = pre_attn_layout(q)
k = pre_attn_layout(k)
v = pre_attn_layout(v)
if mode == "torch":
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
)
elif mode == "flash":
x = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
)
# x with shape [(bxs), a, d]
x = x.view(
batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
) # reshape x to [b, s, a, d]
elif mode == "vanilla":
scale_factor = 1 / math.sqrt(q.size(-1))
b, a, s, _ = q.shape
s1 = k.size(2)
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
if causal:
# Only applied to self attention
assert (
attn_mask is None
), "Causal mask and attn_mask cannot be used together"
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
diagonal=0
)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(q.dtype)  # assign the result; Tensor.to is not in-place
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
# TODO: Maybe force q and k to be float32 to avoid numerical overflow
attn = (q @ k.transpose(-2, -1)) * scale_factor
attn += attn_bias
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
x = attn @ v
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
x = post_attn_layout(x)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
def parallel_attention(
hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len,
img_kv_len,
cu_seqlens_q,
cu_seqlens_kv
):
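    """Attention over the joint image + text sequence with hybrid sequence parallelism.

    attn1: the sequence-parallel kernel attends over the image tokens, appending the valid
    text tokens (up to cu_seqlens_*[1]) as a joint "rear" segment.
    attn2: the remaining padded text positions go through a plain flash-attention forward,
    and the two outputs are concatenated along the sequence dimension.
    """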
attn1 = hybrid_seq_parallel_attn(
None,
q[:, :img_q_len, :, :],
k[:, :img_kv_len, :, :],
v[:, :img_kv_len, :, :],
dropout_p=0.0,
causal=False,
joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
joint_strategy="rear",
)
if tuple(int(p) for p in flash_attn.__version__.split(".")[:2]) >= (2, 7):  # numeric compare; a plain string compare misorders e.g. "2.10" vs "2.7"
attn2, *_ = _flash_attn_forward(
q[:,cu_seqlens_q[1]:],
k[:,cu_seqlens_kv[1]:],
v[:,cu_seqlens_kv[1]:],
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
else:
attn2, *_ = _flash_attn_forward(
q[:,cu_seqlens_q[1]:],
k[:,cu_seqlens_kv[1]:],
v[:,cu_seqlens_kv[1]:],
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
attn = torch.cat([attn1, attn2], dim=1)
b, s, a, d = attn.shape
attn = attn.reshape(b, s, -1)
return attn