Commit d66b98de authored by gushiqiao


Enable 720p model inference on low-spec GPUs/CPUs and accelerate T5/CLIP quantized models with vLLM operators
parent d5b0e0a2
@@ -7,30 +7,23 @@
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": false,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/Wan2.1-I2V-480P-cfg-blocks-fp8/",
"dit_quantized_ckpt": "/path/to/dit_quant_model",
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm",
"weight_auto_quant": false
},
"t5_quantized": true,
"t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-int8.pth",
"t5_quant_scheme": "int8",
"t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-fp8.pth",
"t5_quant_scheme": "fp8",
"clip_quantized": true,
"clip_quantized_ckpt": "/path/to/clip-int8.pth",
"clip_quant_scheme": "int8",
"clip_quantized_ckpt": "/path/to/clip-fp8.pth",
"clip_quant_scheme": "fp8",
"use_tiling_vae": true,
"tiny_vae": true,
"tiny_vae_path": "/path/to/taew2_1.pth",
"lazy_load": true,
"feature_caching": "Tea",
"coefficients": [
[2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01],
[-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01]
],
"use_ret_steps": true,
"teacache_thresh": 0.26
"lazy_load": true
}
@@ -11,28 +11,21 @@
"cpu_offload": true,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/Wan2.1-I2V-480P-cfg-blocks-fp8/",
"dit_quantized_ckpt": "/path/to/dit_quant_model",
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm",
"weight_auto_quant": false
},
"t5_quantized": true,
"t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-int8.pth",
"t5_quant_scheme": "int8",
"t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-fp8.pth",
"t5_quant_scheme": "fp8",
"clip_quantized": true,
"clip_quantized_ckpt": "/path/to/clip-int8.pth",
"clip_quant_scheme": "int8",
"clip_quantized_ckpt": "/path/to/clip-fp8.pth",
"clip_quant_scheme": "fp8",
"use_tiling_vae": true,
"tiny_vae": true,
"tiny_vae_path": "/path/to/taew2_1.pth",
"lazy_load": true,
"feature_caching": "Tea",
"coefficients": [
[8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02],
[-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
],
"use_ret_steps": true,
"teacache_thresh": 0.26,
"rotary_chunk": true,
"clean_cuda_cache": true
}
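The two offload configs above are consumed through the --config_json flag of lightx2v.infer (see the launch script at the end of this commit). A minimal sketch, using hypothetical file names, of editing such a config programmatically; the keys mirror the JSON shown above:

```python
import json

# hypothetical input/output paths; the keys below come from the configs shown above
with open("wan_i2v_offload_base.json") as f:
    cfg = json.load(f)

cfg["t5_quant_scheme"] = "fp8"      # switch the T5 encoder to the new fp8 path
cfg["clip_quant_scheme"] = "fp8"    # same for CLIP
cfg["lazy_load"] = True             # stream DiT weights from disk on low-RAM machines

with open("wan_i2v_offload_fp8.json", "w") as f:
    json.dump(cfg, f, indent=2)
```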
@@ -49,6 +49,8 @@ def run_inference(
if torch_compile:
os.environ["ENABLE_GRAPH_MODE"] = "true"
os.environ["DTYPE"] = "BF16"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:true"
config = {
"infer_steps": infer_steps,
"target_video_length": num_frames,
@@ -87,11 +89,11 @@ def run_inference(
"use_ret_steps": True,
"teacache_thresh": 0.26,
"t5_quantized": True,
"t5_quantized_ckpt": os.path.join(model_path, "models_t5_umt5-xxl-enc-int8.pth"),
"t5_quant_scheme": "int8",
"t5_quantized_ckpt": os.path.join(model_path, "models_t5_umt5-xxl-enc-fp8.pth"),
"t5_quant_scheme": "fp8",
"clip_quantized": True,
"clip_quantized_ckpt": os.path.join(model_path, "clip-int8.pth"),
"clip_quant_scheme": "int8",
"clip_quantized_ckpt": os.path.join(model_path, "clip-fp8.pth"),
"clip_quant_scheme": "fp8",
"use_tiling_vae": True,
"tiny_vae": use_tiny_vae,
"tiny_vae_path": tiny_vae_path if use_tiny_vae else None,
@@ -348,4 +350,4 @@ with gr.Blocks(
)
if __name__ == "__main__":
demo.launch(share=False, server_port=7860, server_name="0.0.0.0")
demo.launch(share=True, server_port=7862, server_name="0.0.0.0")
@@ -3,36 +3,37 @@ import torch.nn as nn
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
import torch.nn.functional as F
from loguru import logger
try:
from spas_sage_attn.autotune import SparseAttentionMeansim
except ImportError:
print("SparseAttentionMeansim not found, please install sparge first")
logger.info("SparseAttentionMeansim not found, please install sparge first")
SparseAttentionMeansim = None
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
print("flash_attn_varlen_func not found, please install flash_attn2 first")
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
print("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
flash_attn_varlen_func_v3 = None
if torch.cuda.get_device_capability(0) == (8, 9):
if torch.cuda.get_device_capability(0)[0] <= 8 and torch.cuda.get_device_capability(0)[1] <= 9:
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
print("sageattn not found, please install sageattention first")
logger.info("sageattn not found, please install sageattention first")
sageattn = None
else:
try:
from sageattention import sageattn
except ImportError:
print("sageattn not found, please install sageattention first")
logger.info("sageattn not found, please install sageattention first")
sageattn = None
@@ -26,15 +26,57 @@ class QuantLinearInt8(nn.Module):
input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
return input_tensor_quant, input_tensor_scale
def forward(self, x):
input_tensor_quant, input_tensor_scale = self.act_quant_func(x)
output_tensor = Q8F.linear.q8_linear(
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
torch.ops._C.cutlass_scaled_mm(
output_tensor,
input_tensor_quant,
self.weight.t(),
input_tensor_scale,
self.weight_scale.float(),
self.bias,
)
return output_tensor.unsqueeze(0)
class QuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
else:
self.register_buffer("bias", None)
def act_quant_func(self, x):
input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
return input_tensor_quant, input_tensor_scale
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
self.weight = self.weight.to(torch.float8_e4m3fn)
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
torch.ops._C.cutlass_scaled_mm(
output_tensor,
input_tensor_quant,
self.weight.t(),
input_tensor_scale,
self.weight_scale.float(),
self.bias,
)
return output_tensor.unsqueeze(0)
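For reference, the computation that QuantLinearFp8.forward delegates to vLLM's ops.scaled_fp8_quant and torch.ops._C.cutlass_scaled_mm can be sketched in plain PyTorch. This is an illustrative, dequantized reference only (not part of the commit; it assumes a PyTorch build with float8_e4m3fn support, roughly 2.1+): per-token dynamic activation scales combined with per-output-channel weight scales.

```python
import torch

def fp8_linear_reference(x: torch.Tensor, weight_fp8: torch.Tensor,
                         weight_scale: torch.Tensor) -> torch.Tensor:
    """Dequantized reference for the fused fp8 path: y ~= (x_q * x_scale) @ (w_q * w_scale).T

    x:            (tokens, in_features), bf16/fp16
    weight_fp8:   (out_features, in_features), float8_e4m3fn
    weight_scale: (out_features, 1), float32 per-output-channel scales
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # dynamic per-token scale, mirroring ops.scaled_fp8_quant(..., use_per_token_if_dynamic=True)
    x_scale = (x.abs().amax(dim=-1, keepdim=True).float() / finfo.max).clamp_min(1e-12)
    x_q = (x.float() / x_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # the CUDA path fuses this matmul and rescale into torch.ops._C.cutlass_scaled_mm
    y = (x_q.float() * x_scale) @ (weight_fp8.float() * weight_scale).t()
    return y.to(x.dtype)

# toy usage: quantize a random weight per output channel, then run the reference
w = torch.randn(1024, 4096)
w_scale = w.abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)
y = fp8_linear_reference(torch.randn(16, 4096, dtype=torch.bfloat16), w_fp8, w_scale)
```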
@@ -9,7 +9,7 @@ import torch.nn.functional as F
from .tokenizer import HuggingfaceTokenizer
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
__all__ = [
@@ -76,6 +76,8 @@ class T5Attention(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
else:
linear_cls = nn.Linear
@@ -131,6 +133,8 @@ class T5FeedForward(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
else:
linear_cls = nn.Linear
# layers
@@ -10,7 +10,7 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
__all__ = [
@@ -61,6 +61,8 @@ class SelfAttention(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
else:
linear_cls = nn.Linear
@@ -117,6 +119,8 @@ class AttentionBlock(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
else:
linear_cls = nn.Linear
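The same scheme-to-class dispatch is added in four places (T5Attention, T5FeedForward, SelfAttention, AttentionBlock). A hypothetical helper, not part of the commit, summarizing that logic; quant_scheme comes from the t5_quant_scheme / clip_quant_scheme config keys shown earlier:

```python
import torch.nn as nn
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8

def select_linear_cls(quantized: bool, quant_scheme: str):
    # mirrors the branches added above; unknown schemes fall back to a plain nn.Linear
    if quantized and quant_scheme == "int8":
        return QuantLinearInt8
    if quantized and quant_scheme == "fp8":
        return QuantLinearFp8
    return nn.Linear
```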
@@ -93,13 +93,18 @@ class WanPreInfer:
if self.task == "i2v":
context_clip = weights.proj_0.apply(clip_fea)
if self.clean_cuda_cache:
del clip_fea
torch.cuda.empty_cache()
context_clip = weights.proj_1.apply(context_clip)
context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
if self.clean_cuda_cache:
torch.cuda.empty_cache()
context_clip = weights.proj_3.apply(context_clip)
context_clip = weights.proj_4.apply(context_clip)
context = torch.concat([context_clip, context], dim=0)
if self.clean_cuda_cache:
del context_clip, clip_fea
del context_clip
torch.cuda.empty_cache()
return (
embed,
@@ -274,6 +274,10 @@ class WanTransformerInfer(BaseTransformerInfer):
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
if self.clean_cuda_cache:
del freqs_i, norm1_out, norm1_weight, norm1_bias
torch.cuda.empty_cache()
if not self.parallel_attention:
attn_out = weights.self_attn_1.apply(
q=q,
@@ -298,7 +302,7 @@ class WanTransformerInfer(BaseTransformerInfer):
y = weights.self_attn_o.apply(attn_out)
if self.clean_cuda_cache:
del q, k, v, attn_out, freqs_i, norm1_out, norm1_weight, norm1_bias
del q, k, v, attn_out
torch.cuda.empty_cache()
return y
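The clean_cuda_cache changes above follow a single pattern: drop intermediates as soon as they are consumed and flush the caching allocator, trading some speed for a lower peak-VRAM footprint. A standalone sketch of that pattern (illustrative only; assumes a CUDA device and is not lightx2v code):

```python
import torch

clean_cuda_cache = True  # mirrors the config flag of the same name

x = torch.randn(1, 4096, 1024, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

y = x @ w                        # intermediate activation
z = torch.nn.functional.gelu(y)  # y is no longer needed after this point
if clean_cuda_cache:
    del y                        # drop the last reference so the allocator can reuse the block
    torch.cuda.empty_cache()     # return cached blocks to the driver, lowering peak reserved VRAM
```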
#!/bin/bash
# set paths first
lightx2v_path=/data/video_gen/lightx2v_latest/lightx2v
model_path=/data/video_gen/x2v_models/Wan2.1-I2V-14B-720P
# check required settings
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=true
export DTYPE=BF16 # remove this to get higher-quality video
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json /data/video_gen/lightx2v_latest/lightx2v/configs/offload/disk/wan_i2v_phase_lazy_load_480p.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4