Commit e08c4f90 authored by sandy, committed by GitHub

Merge branch 'main' into audio_r2v

parents 12bfd120 6d07a72e
# Step Distillation
Step distillation is a key optimization in LightX2V: a distilled model reduces inference from the original 40-50 steps to **4 steps**, significantly speeding up inference while preserving video quality. LightX2V also applies CFG distillation alongside step distillation to push inference speed even further.
## 🔍 How It Works
Step distillation in LightX2V is built on [Self-Forcing](https://github.com/guandeh17/Self-Forcing), which applies step distillation and CFG distillation to a 1.3B autoregressive model. LightX2V extends this work in several ways:
1. **Larger models**: supports step-distillation training for 14B models;
2. **More model types**: supports step-distillation training for standard bidirectional models as well as I2V models.
For the concrete implementation, see [Self-Forcing-Plus](https://github.com/GoatWu/Self-Forcing-Plus).
## 🎯 Key Features
- **Inference speedup**: the step count drops from 40-50 to 4 with CFG disabled, for a speedup of roughly **20-24x** (see the rough estimate below)
- **Quality preservation**: distillation keeps the original video generation quality
- **Strong compatibility**: supports both T2V and I2V tasks
- **Flexible usage**: load a fully distilled model, or load a step-distillation LoRA on top of the base model
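The quoted speedup is consistent with simply counting denoiser forward passes. The sketch below is a back-of-the-envelope estimate under that assumption, not a measured benchmark; real gains also depend on the text encoder, VAE, and hardware.
```python
# Rough estimate of the saving in DiT forward passes (illustrative only).
baseline_steps = 50            # original sampler uses 40-50 steps
baseline_passes_per_step = 2   # CFG needs a conditional + an unconditional pass
distilled_steps = 4            # distilled sampler
distilled_passes_per_step = 1  # CFG has been distilled away

baseline_nfe = baseline_steps * baseline_passes_per_step      # 80-100 forward passes
distilled_nfe = distilled_steps * distilled_passes_per_step   # 4 forward passes

print(baseline_nfe / distilled_nfe)  # 25.0 (20.0 with 40 steps), in line with the quoted 20-24x
```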
## 🛠️ Configuration Files
### Base Configurations
The [configs/distill/](https://github.com/ModelTC/lightx2v/tree/main/configs/distill) directory provides several options:
| Config file | Purpose | Model link |
|----------|------|------------|
| [wan_t2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg.json) | Load the full T2V 4-step distilled model | TODO |
| [wan_i2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg.json) | Load the full I2V 4-step distilled model | TODO |
| [wan_t2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg_lora.json) | Load the Wan-T2V base model plus the step-distillation LoRA | TODO |
| [wan_i2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg_lora.json) | Load the Wan-I2V base model plus the step-distillation LoRA | TODO |
### Key Parameters
```json
{
    "infer_steps": 4,                              // number of inference steps
    "denoising_step_list": [999, 750, 500, 250],   // list of denoising timesteps
    "enable_cfg": false,                           // disable CFG for extra speed
    "lora_configs": [                              // LoRA weight paths (optional)
        {
            "path": "path/to/distill_lora.safetensors",
            "strength": 1.0
        }
    ]
}
```
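To make the roles of these fields concrete, here is a minimal sketch of a 4-step denoising loop driven by `denoising_step_list` with `enable_cfg` off. It is illustrative pseudocode only: `toy_denoiser`, `toy_step`, and the latent shape are placeholders, not LightX2V's actual model or scheduler interfaces.
```python
import torch

# Minimal, self-contained sketch of how the parameters above drive a 4-step sampler.
denoising_step_list = [999, 750, 500, 250]   # "denoising_step_list" from the config
enable_cfg = False                           # "enable_cfg": false

def toy_denoiser(x, t):
    # Placeholder for one forward pass of the (distilled) DiT.
    return torch.zeros_like(x)

def toy_step(noise_pred, t, x):
    # Placeholder for one scheduler update toward the clean latent.
    return x - noise_pred / len(denoising_step_list)

latents = torch.randn(1, 16, 8, 8)           # dummy latent, not a real video shape

for t in denoising_step_list:                # 4 iterations instead of 40-50
    noise_pred = toy_denoiser(latents, t)    # a single forward pass per step
    if enable_cfg:                           # skipped here: CFG was distilled away
        uncond_pred = toy_denoiser(latents, t)   # CFG would need a second pass
        noise_pred = uncond_pred + 5.0 * (noise_pred - uncond_pred)
    latents = toy_step(noise_pred, t, latents)
```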
## 📜 Usage
### Model Preparation
**Full model:**
Place the downloaded model (`distill_model.pt` or `distill_model.safetensors`) in the `distill_models/` folder under the Wan model root directory:
- For T2V: `Wan2.1-T2V-14B/distill_models/`
- For I2V-480P: `Wan2.1-I2V-14B-480P/distill_models/`
**LoRA:**
1. Put the downloaded LoRA anywhere you like
2. Point the LoRA path in the config file (the `path` field under `lora_configs`) at that location, for example as in the snippet below
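If you prefer to script that change, a snippet like the one below can rewrite the LoRA entry in a config file. It is only a sketch: the config filename and LoRA path are placeholders, and it assumes the config is plain JSON with a `lora_configs` list as in the example above.
```python
import json

config_path = "configs/distill/wan_i2v_distill_4step_cfg_lora.json"  # the config you plan to run
lora_path = "/path/to/distill_lora.safetensors"                      # where you saved the LoRA

with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# Point the step-distillation LoRA entry at the downloaded file.
config["lora_configs"] = [{"path": lora_path, "strength": 1.0}]

with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2, ensure_ascii=False)
```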
### Inference Scripts
**Full T2V model:**
```bash
bash scripts/wan/run_wan_t2v_distill_4step_cfg.sh
```
**Full I2V model:**
```bash
bash scripts/wan/run_wan_i2v_distill_4step_cfg.sh
```
### Step-Distillation LoRA Inference Scripts
**T2V LoRA:**
```bash
bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh
```
**I2V LoRA:**
```bash
bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh
```
## 🔧 Service Deployment
### Start the Service with a Distilled Model
Edit the launch command in [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh):
```bash
python -m lightx2v.api_server \
--model_cls wan2.1_distill \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg.json \
--port 8000 \
--nproc_per_node 1
```
Run the server startup script:
```bash
bash scripts/server/start_server.sh
```
See [Service Deployment](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_service.html) for more details.
### Using the Gradio Interface
See the [Gradio documentation](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html).
......@@ -36,6 +36,7 @@ def main():
choices=[
"wan2.1",
"hunyuan",
"wan2.1_distill",
"wan2.1_causvid",
"wan2.1_skyreels_v2_df",
"wan2.1_audio",
......@@ -48,6 +49,7 @@ def main():
parser.add_argument("--split", action="store_true")
parser.add_argument("--lora_path", type=str, required=False, default=None)
parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")
......@@ -55,7 +57,7 @@ def main():
args = parser.parse_args()
logger.info(f"args: {args}")
cache_dir = Path(__file__).parent.parent / ".cache"
cache_dir = Path(__file__).parent.parent / "server_cache"
inference_service = DistributedInferenceService()
api_server = ApiServer()
......
......@@ -121,7 +121,6 @@ if __name__ == "__main__":
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = DiTRunner(config)
......
......@@ -116,7 +116,6 @@ if __name__ == "__main__":
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = ImageEncoderRunner(config)
......
......@@ -119,7 +119,6 @@ if __name__ == "__main__":
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = TextEncoderRunner(config)
......
......@@ -168,7 +168,6 @@ if __name__ == "__main__":
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = VAERunner(config)
......
......@@ -15,6 +15,7 @@ class WeightAsyncStreamManager(object):
self.cuda_load_stream = torch.cuda.Stream(priority=0)
self.offload_block_num = int(offload_ratio * blocks_num)
self.phases_num = phases_num
self.block_nums = blocks_num
self.offload_phases_num = blocks_num * phases_num * offload_ratio
def prefetch_weights(self, block_idx, blocks_weights):
......@@ -121,12 +122,16 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
except Exception as e:
logger.error(f"Disk worker thread error: {e}")
def _async_prefetch_block(self, weights):
next_block_idx = self.pin_memory_buffer.get_max_block_index()
def _async_prefetch_block(self, blocks, next_block_idx=None):
if next_block_idx is None:
next_block_idx = self.pin_memory_buffer.get_max_block_index()
if next_block_idx < 0:
next_block_idx = 0
if next_block_idx == self.block_nums:
return
if self.offload_gra == "phase":
for phase_idx in range(self.phases_num):
obj_key = (next_block_idx, phase_idx)
......@@ -137,7 +142,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
with self.task_lock:
self.pending_tasks[obj_key] = True
phase = weights.blocks[next_block_idx].compute_phases[phase_idx]
phase = blocks[next_block_idx].compute_phases[phase_idx]
priority_key = (next_block_idx, phase_idx)
self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
......@@ -149,32 +154,34 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
with self.task_lock:
self.pending_tasks[obj_key] = True
block = weights.blocks[next_block_idx]
block = blocks[next_block_idx]
self.disk_task_queue.put((obj_key, (next_block_idx, block)))
def _sync_prefetch_block(self, weights):
def _sync_prefetch_block(self, blocks):
block_idx = 0
while not self.pin_memory_buffer.is_nearly_full():
if self.offload_gra == "phase":
for phase_idx in range(self.phases_num):
phase = weights.blocks[block_idx].compute_phases[phase_idx]
phase = blocks[block_idx].compute_phases[phase_idx]
logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
phase.load_from_disk()
self.pin_memory_buffer.push((block_idx, phase_idx), phase)
else:
block = weights.blocks[block_idx]
block = blocks[block_idx]
logger.info(f"Synchronous loading: block={block_idx}")
for phase in block.compute_phases:
phase.load_from_disk()
self.pin_memory_buffer.push(block_idx, block)
block_idx += 1
if block_idx == self.block_nums:
break
def prefetch_weights_from_disk(self, weights):
def prefetch_weights_from_disk(self, blocks):
if self.initial_prefetch_done:
return
self._sync_prefetch_block(weights)
self._sync_prefetch_block(blocks)
self.initial_prefetch_done = True
def prefetch_weights(self, block_idx, blocks):
......@@ -193,7 +200,15 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}")
else:
logger.info("Not find prefetch block={block_idx} task. This is a bug.")
logger.info("Not find prefetch block={block_idx} task.")
logger.info("Sync prefetch block={block_idx}.")
self._async_prefetch_block(blocks, block_idx)
start_time = time.time()
for phase_idx in range(self.phases_num):
while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
time.sleep(0.001)
if time.time() - start_time > 15:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
with torch.cuda.stream(self.cuda_load_stream):
block = self.pin_memory_buffer.get(obj_key)
......@@ -224,7 +239,14 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
else:
logger.info("Not find prefetch block={block_idx}, phase={phase_idx} task. This is a bug.")
logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
self._async_prefetch_block(blocks, block_idx)
start_time = time.time()
while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
time.sleep(0.001)
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
with torch.cuda.stream(self.cuda_load_stream):
phase = self.pin_memory_buffer.get(obj_key)
......
......@@ -23,7 +23,7 @@ except ImportError:
logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
flash_attn_varlen_func_v3 = None
if torch.cuda.get_device_capability(0)[0] <= 8 and torch.cuda.get_device_capability(0)[1] <= 9:
if torch.cuda.get_device_capability(0) == (8, 9):
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
......
......@@ -56,3 +56,10 @@ class Conv2dWeight(Conv2dWeightTemplate):
if self.bias is not None:
destination[self.bias_name] = self.bias.cpu().detach().clone()
return destination
def clear(self):
attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
......@@ -66,3 +66,10 @@ class Conv3dWeight(Conv3dWeightTemplate):
if self.bias is not None:
destination[self.bias_name] = self.bias.cpu().detach().clone()
return destination
def clear(self):
attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
......@@ -145,7 +145,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.pinned_weight = self.pinned_weight.t()
def clear(self):
attrs = ["weight", "weight_scale", "bias"]
attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
......
......@@ -34,9 +34,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
return self.weight.numel() * self.weight.element_size()
def clear(self):
del self.weight
if self.bias is not None:
del self.bias
attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
@abstractmethod
def apply(self, input_tensor):
......
......@@ -23,7 +23,11 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
def clear(self):
del self.weight
attrs = ["weight", "pinned_weight"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
@abstractmethod
def apply(self, input_tensor):
......
......@@ -22,7 +22,11 @@ class DefaultTensor:
self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)
def clear(self):
del self.tensor
attrs = ["tensor", "pinned_tensor"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
def _calculate_size(self):
return self.tensor.numel() * self.tensor.element_size()
......
import asyncio
import argparse
import torch
import torch.distributed as dist
......@@ -40,11 +39,12 @@ def init_runner(config):
return runner
async def main():
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan"
"--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="wan2.1"
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
......@@ -52,35 +52,27 @@ async def main():
parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--lora_path", type=str, default="", help="The lora file path")
parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task")
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file for audio-to-video (a2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
args = parser.parse_args()
if args.prompt_path:
try:
with open(args.prompt_path, "r", encoding="utf-8") as f:
args.prompt = f.read().strip()
logger.info(f"从文件 {args.prompt_path} 读取到prompt: {args.prompt}")
except FileNotFoundError:
logger.error(f"找不到prompt文件: {args.prompt_path}")
raise
except Exception as e:
logger.error(f"读取prompt文件时出错: {e}")
raise
logger.info(f"args: {args}")
with ProfilingContext("Total Cost"):
config = set_config(args)
config["mode"] = "infer"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
await runner.run_pipeline()
runner.run_pipeline()
# Clean up distributed process group
if dist.is_initialized():
dist.destroy_process_group()
logger.info("Distributed process group cleaned up")
if __name__ == "__main__":
asyncio.run(main())
main()
......@@ -151,12 +151,3 @@ class TextEncoderHFLlavaModel:
if config.cpu_offload:
self.to_cpu()
return last_hidden_state, attention_mask
if __name__ == "__main__":
model = TextEncoderHFLlavaModel("/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/i2v/text_encoder_i2v", torch.device("cuda"))
text = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
img_path = "/mtc/yongyang/projects/lightx2v/assets/inputs/imgs/img_1.jpg"
img = Image.open(img_path).convert("RGB")
outputs = model.infer(text, img, None)
logger.info(outputs)
......@@ -2,14 +2,9 @@ import torch
import torch.nn as nn
from vllm import _custom_ops as ops
try:
import q8_kernels.functional as Q8F
except ImportError:
Q8F = None
class QuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
......@@ -18,7 +13,7 @@ class QuantLinearInt8(nn.Module):
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
......@@ -44,18 +39,30 @@ class QuantLinearInt8(nn.Module):
)
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
class QuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
......@@ -65,7 +72,6 @@ class QuantLinearFp8(nn.Module):
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
self.weight = self.weight.to(torch.float8_e4m3fn)
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
......@@ -79,4 +85,19 @@ class QuantLinearFp8(nn.Module):
self.weight_scale.float(),
self.bias,
)
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
......@@ -27,6 +27,14 @@ def fp16_clamp(x):
return x
def optimize_memory_usage():
if torch.cuda.is_available():
torch.cuda.empty_cache()
import gc
gc.collect()
def init_weights(m):
if isinstance(m, T5LayerNorm):
nn.init.ones_(m.weight)
......@@ -51,11 +59,11 @@ class GELU(nn.Module):
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
def __init__(self, dim, eps=1e-6, dtype=torch.float16):
super(T5LayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
......@@ -65,7 +73,7 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
self.dim = dim
......@@ -82,10 +90,10 @@ class T5Attention(nn.Module):
linear_cls = nn.Linear
# layers
self.q = linear_cls(dim, dim_attn, bias=False)
self.k = linear_cls(dim, dim_attn, bias=False)
self.v = linear_cls(dim, dim_attn, bias=False)
self.o = linear_cls(dim_attn, dim, bias=False)
self.q = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
self.k = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
self.v = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
self.o = linear_cls(dim_attn, dim, bias=False, dtype=dtype)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
......@@ -114,10 +122,14 @@ class T5Attention(nn.Module):
# compute attention (T5 does not use scaling)
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
if hasattr(self, "cpu_offload") and self.cpu_offload:
del attn_bias
attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
x = torch.einsum("bnij,bjnc->binc", attn, v)
# output
if hasattr(self, "cpu_offload") and self.cpu_offload:
del attn
x = x.reshape(b, -1, n * c)
x = self.o(x)
x = self.dropout(x)
......@@ -125,7 +137,7 @@ class T5Attention(nn.Module):
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None):
def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
super(T5FeedForward, self).__init__()
self.dim = dim
self.dim_ffn = dim_ffn
......@@ -138,13 +150,20 @@ class T5FeedForward(nn.Module):
else:
linear_cls = nn.Linear
# layers
self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False), GELU())
self.fc1 = linear_cls(dim, dim_ffn, bias=False)
self.fc2 = linear_cls(dim_ffn, dim, bias=False)
self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False, dtype=dtype), GELU())
self.fc1 = linear_cls(dim, dim_ffn, bias=False, dtype=dtype)
self.fc2 = linear_cls(dim_ffn, dim, bias=False, dtype=dtype)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x) * self.gate(x)
if hasattr(self, "cpu_offload") and self.cpu_offload:
gate_out = self.gate(x)
fc1_out = self.fc1(x)
x = fc1_out * gate_out
del gate_out, fc1_out
else:
x = self.fc1(x) * self.gate(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
......@@ -152,7 +171,7 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
super(T5SelfAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
......@@ -162,16 +181,27 @@ class T5SelfAttention(nn.Module):
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
self.norm1 = T5LayerNorm(dim, dtype=dtype)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme, dtype)
self.norm2 = T5LayerNorm(dim, dtype=dtype)
self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme, dtype=dtype)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
if hasattr(self, "cpu_offload") and self.cpu_offload:
attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=e)
x = fp16_clamp(x + attn_out)
del attn_out
ffn_out = self.ffn(self.norm2(x))
x = fp16_clamp(x + ffn_out)
del ffn_out
else:
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
......@@ -212,7 +242,7 @@ class T5CrossAttention(nn.Module):
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
def __init__(self, num_buckets, num_heads, bidirectional, dtype=torch.bfloat16, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
......@@ -220,7 +250,7 @@ class T5RelativeEmbedding(nn.Module):
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
self.embedding = nn.Embedding(num_buckets, num_heads, dtype=dtype)
def forward(self, lq, lk):
device = self.embedding.weight.device
......@@ -252,7 +282,7 @@ class T5RelativeEmbedding(nn.Module):
class T5Encoder(nn.Module):
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
def __init__(self, dtype, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
super(T5Encoder, self).__init__()
self.cpu_offload = cpu_offload
......@@ -266,11 +296,17 @@ class T5Encoder(nn.Module):
self.quant_scheme = quant_scheme
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.token_embedding = vocab.to(dtype) if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, dtype=dtype)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
if cpu_offload:
for block in self.blocks:
block.cpu_offload = cpu_offload
block.attn.cpu_offload = cpu_offload
block.ffn.cpu_offload = cpu_offload
self.norm = T5LayerNorm(dim, dtype=dtype)
# initialize weights
# self.apply(init_weights)
......@@ -281,23 +317,32 @@ class T5Encoder(nn.Module):
x = self.token_embedding(ids)
if self.cpu_offload:
self.token_embedding = self.token_embedding.cpu()
optimize_memory_usage()
x = self.dropout(x)
if self.cpu_offload and self.pos_embedding is not None:
self.pos_embedding = self.pos_embedding.cuda()
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
if self.cpu_offload and self.pos_embedding is not None:
self.pos_embedding = self.pos_embedding.cpu()
for block in self.blocks:
optimize_memory_usage()
for i, block in enumerate(self.blocks):
if self.cpu_offload:
block = block.cuda()
x = block(x, mask, pos_bias=e)
if self.cpu_offload:
block = block.cpu()
del block
optimize_memory_usage()
if self.cpu_offload:
self.norm = self.norm.cuda()
x = self.norm(x)
if self.cpu_offload:
self.norm = self.norm.cpu()
optimize_memory_usage()
x = self.dropout(x)
return x.to(torch.bfloat16)
......@@ -443,10 +488,10 @@ def _t5(
# init model
with torch.device(device):
model = model_cls(**kwargs)
model = model_cls(dtype=dtype, **kwargs)
# set device
model = model.to(dtype=dtype, device=device)
model = model.to(device=device)
return model
......@@ -511,9 +556,10 @@ class T5EncoderModel:
.requires_grad_(False)
)
logger.info(f"Loading weights from {self.checkpoint_path}")
logger.info(f"Start Loading weights from {self.checkpoint_path}")
model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
logger.info(f"End Loading weights from {self.checkpoint_path}")
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
......@@ -528,6 +574,10 @@ class T5EncoderModel:
def to_cuda(self):
self.model = self.model.to("cuda")
def optimize_memory(self):
"""优化内存使用"""
optimize_memory_usage()
def infer(self, texts):
if self.cpu_offload and self.offload_granularity == "model":
self.to_cuda()
......@@ -536,10 +586,17 @@ class T5EncoderModel:
ids = ids.cuda()
mask = mask.cuda()
seq_lens = mask.gt(0).sum(dim=1).long()
context = self.model(ids, mask)
with torch.no_grad():
context = self.model(ids, mask)
if self.cpu_offload and self.offload_granularity == "model":
self.to_cpu()
optimize_memory_usage()
del ids, mask
if self.cpu_offload:
optimize_memory_usage()
return [u[:v] for u, v in zip(context, seq_lens)]
......
......@@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None):
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None, dtype=None):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
......@@ -69,8 +69,8 @@ class SelfAttention(nn.Module):
else:
linear_cls = nn.Linear
self.to_qkv = linear_cls(dim, dim * 3)
self.proj = linear_cls(dim, dim)
self.to_qkv = linear_cls(dim, dim * 3, dtype=dtype)
self.proj = linear_cls(dim, dim, dtype=dtype)
def forward(self, x):
"""
......@@ -108,7 +108,21 @@ class SwiGLU(nn.Module):
class AttentionBlock(nn.Module):
def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5, quantized=False, quant_scheme=None):
def __init__(
self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation="quick_gelu",
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5,
quantized=False,
quant_scheme=None,
dtype=torch.float16,
):
assert activation in ["quick_gelu", "gelu", "swi_glu"]
super().__init__()
self.dim = dim
......@@ -127,13 +141,18 @@ class AttentionBlock(nn.Module):
else:
linear_cls = nn.Linear
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme)
self.norm2 = LayerNorm(dim, eps=norm_eps)
self.norm1 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme, dtype)
self.norm2 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
if activation == "swi_glu":
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
self.mlp = SwiGLU(dim, int(dim * mlp_ratio), dtype=dtype)
else:
self.mlp = nn.Sequential(linear_cls(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), linear_cls(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
self.mlp = nn.Sequential(
linear_cls(dim, int(dim * mlp_ratio), dtype=dtype),
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
linear_cls(int(dim * mlp_ratio), dim, dtype=dtype),
nn.Dropout(proj_dropout),
)
def forward(self, x):
if self.post_norm:
......@@ -146,7 +165,7 @@ class AttentionBlock(nn.Module):
class AttentionPool(nn.Module):
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5, dtype=torch.float16):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
......@@ -159,11 +178,13 @@ class AttentionPool(nn.Module):
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
self.to_q = nn.Linear(dim, dim, dtype=dtype)
self.to_kv = nn.Linear(dim, dim * 2, dtype=dtype)
self.proj = nn.Linear(dim, dim, dtype=dtype)
self.norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype), nn.Dropout(proj_dropout)
)
def forward(self, x):
"""
......@@ -191,6 +212,7 @@ class AttentionPool(nn.Module):
class VisionTransformer(nn.Module):
def __init__(
self,
dtype=torch.float16,
image_size=224,
patch_size=16,
dim=768,
......@@ -228,26 +250,26 @@ class VisionTransformer(nn.Module):
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm, dtype=dtype)
if pool_type in ("token", "token_fc"):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim))
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim, dtype=dtype))
self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim, dtype=dtype))
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.pre_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) if pre_norm else None
self.transformer = nn.Sequential(
*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme) for _ in range(num_layers)]
*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme, dtype) for _ in range(num_layers)]
)
self.post_norm = LayerNorm(dim, eps=norm_eps)
self.post_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
# head
if pool_type == "token":
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
self.head = nn.Parameter(gain * torch.randn(dim, out_dim, dtype=dtype))
elif pool_type == "token_fc":
self.head = nn.Linear(dim, out_dim)
self.head = nn.Linear(dim, out_dim, dtype=dtype)
elif pool_type == "attn_pool":
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps, dtype=dtype)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
......@@ -276,6 +298,7 @@ class VisionTransformer(nn.Module):
class XLMRobertaCLIP(nn.Module):
def __init__(
self,
dtype=torch.float16,
embed_dim=1024,
image_size=224,
patch_size=14,
......@@ -317,6 +340,7 @@ class XLMRobertaCLIP(nn.Module):
# models
self.visual = VisionTransformer(
dtype=dtype,
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
......@@ -341,12 +365,11 @@ class XLMRobertaCLIP(nn.Module):
def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
model = model_cls(dtype=dtype, **kwargs)
# set device
model = model.to(dtype=dtype, device=device)
output = (model,)
model = model.to(device=device)
output = (model,)
# init transforms
if return_transforms:
# mean and std
......@@ -395,23 +418,23 @@ class CLIPModel:
else:
self.checkpoint_path = checkpoint_path
logger.info(f"Loading weights from {self.checkpoint_path}")
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
)
self.model = self.model.eval().requires_grad_(False)
weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
keys = list(weight_dict.keys())
for key in keys:
if "textual" in key:
weight_dict.pop(key)
logger.info(f"Start Loading weights from {self.checkpoint_path}")
self.model.load_state_dict(weight_dict)
logger.info(f"End Loading weights from {self.checkpoint_path}")
def visual(self, videos, args):
if args.cpu_offload:
if hasattr(args, "cpu_offload") and args.cpu_offload:
self.to_cuda()
# preprocess
size = (self.model.image_size,) * 2
......@@ -422,7 +445,7 @@ class CLIPModel:
with torch.amp.autocast("cuda", dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
if args.cpu_offload:
if hasattr(args, "cpu_offload") and args.cpu_offload:
self.to_cpu()
return out
......
import flash_attn
try:
import flash_attn
except ModuleNotFoundError:
flash_attn = None
import math
import torch
import torch.nn as nn
......