Commit 3e4fe79b authored by GoatWu

Merge branch 'main' of github.com:ModelTC/lightx2v into main

parents 8ddd33a5 d013cac7
import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm
-from lightx2v_kernel.gemm import scaled_fp6_quant, scaled_fp8_quant
+from lightx2v_kernel.gemm import scaled_mxfp6_quant, scaled_mxfp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark
@@ -22,10 +22,10 @@ class TestQuantBF162MXFP6(unittest.TestCase):
        for n in self.channels:
            with self.subTest(shape=[m, k, n]):
                activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
-                activation_quant_pred, activation_scale_pred = scaled_fp8_quant(activation)
+                activation_quant_pred, activation_scale_pred = scaled_mxfp8_quant(activation)
                weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
-                weight_quant_pred, weight_scale_pred = scaled_fp6_quant(weight)
+                weight_quant_pred, weight_scale_pred = scaled_mxfp6_quant(weight)
                bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
@@ -44,7 +44,7 @@ class TestQuantBF162MXFP6(unittest.TestCase):
        input = torch.randn(m, k, dtype=self.dtype, device=self.device)
        shape = [m, k]
        tflops = 2 * (m * k / 1024**4)
-        benchmark(scaled_fp6_quant, shape, tflops, 100, input)
+        benchmark(scaled_mxfp6_quant, shape, tflops, 100, input)
if __name__ == "__main__":
......
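A minimal usage sketch (not part of the diff) tying the renamed kernels above together; it assumes, as the test implies, that each scaled_mx*_quant call returns a (quantized_tensor, scale) pair:

import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm
from lightx2v_kernel.gemm import scaled_mxfp6_quant, scaled_mxfp8_quant

m, k, n = 1024, 4096, 4096
activation = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

# Activations quantize to MXFP8, weights to MXFP6 (the mixed-precision GEMM).
a_q, a_scale = scaled_mxfp8_quant(activation)
w_q, w_scale = scaled_mxfp6_quant(weight)
# Argument order is an assumption based on the test's naming; check the
# kernel signature before relying on it.
out = cutlass_scaled_mxfp6_mxfp8_mm(a_q, w_q, a_scale, w_scale)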
import torch
-from lightx2v_kernel.gemm import scaled_fp6_quant
+from lightx2v_kernel.gemm import scaled_mxfp6_quant
def quantize_fp6(x):
-    return scaled_fp6_quant(x)
+    return scaled_mxfp6_quant(x)
def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
......
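The body of test_memory_bandwidth is collapsed in this diff; below is a sketch of what such a harness typically looks like, assuming the quantization op is bandwidth-bound and timing it with CUDA events:

import torch

def measure_bandwidth_gbps(func, x, num_warmup=10, num_runs=100):
    # Warm up so compilation and allocator effects do not skew the timing.
    for _ in range(num_warmup):
        func(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_runs):
        func(x)
    end.record()
    torch.cuda.synchronize()
    ms_per_run = start.elapsed_time(end) / num_runs
    # Count only the bf16 read as a lower bound; output bytes depend on the
    # packed FP6 layout, which this sketch does not assume.
    bytes_moved = x.numel() * x.element_size()
    return bytes_moved / (ms_per_run * 1e-3) / 1e9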
import torch
-from lightx2v_kernel.gemm import scaled_fp8_quant, cutlass_scaled_mxfp8_mm
+from lightx2v_kernel.gemm import scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm
import time
@@ -17,7 +17,7 @@ class MMWeightMxfp8:
    @torch.no_grad()
    def load_fp8_weight(self, weight, bias):
-        self.weight, self.weight_scale = scaled_fp8_quant(weight)
+        self.weight, self.weight_scale = scaled_mxfp8_quant(weight)
        self.bias = bias
    def set_alpha(self):
@@ -25,7 +25,7 @@ class MMWeightMxfp8:
    @torch.no_grad()
    def act_quant_fp8(self, x):
-        return scaled_fp8_quant(x)
+        return scaled_mxfp8_quant(x)
def test_speed(m, k, n):
def test_speed(m, k, n):
......
import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp8_mm
-from lightx2v_kernel.gemm import scaled_fp8_quant
+from lightx2v_kernel.gemm import scaled_mxfp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark
@@ -22,10 +22,10 @@ class TestQuantBF162MXFP8(unittest.TestCase):
        for n in self.channels:
            with self.subTest(shape=[m, k, n]):
                activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
-                activation_quant_pred, activation_scale_pred = scaled_fp8_quant(activation)
+                activation_quant_pred, activation_scale_pred = scaled_mxfp8_quant(activation)
                weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
-                weight_quant_pred, weight_scale_pred = scaled_fp8_quant(weight)
+                weight_quant_pred, weight_scale_pred = scaled_mxfp8_quant(weight)
                bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
@@ -44,7 +44,7 @@ class TestQuantBF162MXFP8(unittest.TestCase):
        input = torch.randn(m, k, dtype=self.dtype, device=self.device)
        shape = [m, k]
        tflops = 2 * (m * k / 1024**4)
-        benchmark(scaled_fp8_quant, shape, tflops, 100, input)
+        benchmark(scaled_mxfp8_quant, shape, tflops, 100, input)
if __name__ == "__main__":
......
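A sketch (not part of the diff) of the accuracy-check pattern the MXFP8 test above implies: run the quantized GEMM and compare against a bf16 reference with a relative-error metric. The bias keyword on the mm call is an assumption:

import torch
from torch.nn.functional import linear
from lightx2v_kernel.gemm import scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm

m, k, n = 1024, 4096, 4096
activation = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
bias = torch.rand(1, n, dtype=torch.bfloat16, device="cuda") * 10

a_q, a_scale = scaled_mxfp8_quant(activation)
w_q, w_scale = scaled_mxfp8_quant(weight)
pred = cutlass_scaled_mxfp8_mm(a_q, w_q, a_scale, w_scale, bias=bias)  # bias kwarg assumed

ref = linear(activation, weight, bias)  # bf16 reference, as in the test
rel_err = (pred.float() - ref.float()).abs().mean() / ref.float().abs().mean()
print(f"mean relative error: {rel_err.item():.4f}")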
import torch
-from lightx2v_kernel.gemm import scaled_fp8_quant
+from lightx2v_kernel.gemm import scaled_mxfp8_quant
def quantize_fp8(x):
-    return scaled_fp8_quant(x)
+    return scaled_mxfp8_quant(x)
def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
......
import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
+from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm
FLOAT4_E2M1_MAX = 6.0
@@ -110,8 +110,8 @@ def test_nvfp4_gemm(
    print(f"b_global_scale : {b_global_scale}, {b_global_scale.shape}")
    alpha = 1.0 / (a_global_scale * b_global_scale)
-    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
-    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
+    a_fp4, a_scale_interleaved = scaled_nvfp4_quant(a_dtype, a_global_scale)
+    b_fp4, b_scale_interleaved = scaled_nvfp4_quant(b_dtype, b_global_scale)
    expected_out = get_ref_results(
        a_fp4,
@@ -130,7 +130,7 @@ def test_nvfp4_gemm(
    print(f"alpha {alpha}, {alpha.shape}, {alpha.dtype}")
-    out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
+    out = cutlass_scaled_nvfp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
print(f"out : {out}, {out.shape}, {out.dtype}")
print(f"expected_out : {expected_out}, {expected_out.shape}, {expected_out.dtype}")
......
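The scale arithmetic in the hunk above is worth spelling out. A hedged restatement follows: 448.0 (the FP8 E4M3 maximum) is an assumption, but it is consistent with the 2688.0 constant in test_bench2 below, since 448.0 * 6.0 == 2688.0:

import torch

FLOAT4_E2M1_MAX = 6.0    # largest finite FP4 E2M1 value (from the test above)
FLOAT8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value (assumption)

def global_scale(x: torch.Tensor) -> torch.Tensor:
    # Map the tensor's absmax onto the largest value representable by an
    # FP4 element times an FP8 block scale.
    return (FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / x.abs().max()).to(torch.float32)

# The GEMM epilogue then undoes both operands' global scaling in one
# multiplier, matching the diff: alpha = 1.0 / (a_global_scale * b_global_scale).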
import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
+from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm
import time
@@ -14,13 +14,13 @@ class MMWeightFp4:
    @torch.no_grad()
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
+        output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor
    @torch.no_grad()
    def load_fp4_weight(self, weight, bias):
        self.weight_global_scale = (2688.0 / torch.max(torch.abs(weight))).to(torch.float32)
-        self.weight, self.weight_scale = scaled_fp4_quant(weight, self.weight_global_scale)
+        self.weight, self.weight_scale = scaled_nvfp4_quant(weight, self.weight_global_scale)
        self.bias = bias
    def calibrate_x_absmax(self):
@@ -30,7 +30,7 @@ class MMWeightFp4:
    @torch.no_grad()
    def act_quant_fp4(self, x):
-        return scaled_fp4_quant(x, self.input_global_scale)
+        return scaled_nvfp4_quant(x, self.input_global_scale)
def test_speed(m, k, n):
......
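A functional restatement (a sketch, not the file's test_speed) of the MMWeightFp4 flow above, assuming the signatures shown in the diff; the activation global scale here stands in for whatever calibrate_x_absmax computes:

import torch
from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
x = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")

# Per-tensor global scales; 2688.0 = FP4 E2M1 max (6.0) * FP8 E4M3 max (448.0).
w_gs = (2688.0 / torch.max(torch.abs(weight))).to(torch.float32)
x_gs = (2688.0 / torch.max(torch.abs(x))).to(torch.float32)  # calibration stand-in

w_q, w_scale = scaled_nvfp4_quant(weight, w_gs)
x_q, x_scale = scaled_nvfp4_quant(x, x_gs)

# alpha undoes both global scales, as in test_nvfp4_gemm above.
alpha = 1.0 / (w_gs * x_gs)
out = cutlass_scaled_nvfp4_mm(x_q, w_q, x_scale, w_scale, alpha=alpha, bias=None)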
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
import time
from test_bench2 import MMWeightFp4
......
import torch
-from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm
+from lightx2v_kernel.gemm import cutlass_scaled_nvfp4_mm
"""
@@ -16,7 +16,7 @@ bias = None
def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
-    output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
+    output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
    return output_tensor
......
import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant
+from lightx2v_kernel.gemm import scaled_nvfp4_quant
input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda()
def quantize_fp4(x):
-    return scaled_fp4_quant(x, input_global_scale)
+    return scaled_nvfp4_quant(x, input_global_scale)
def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
......
#!/bin/bash
# set paths first
lightx2v_path=/path/to/lightx2v
model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v
# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/bench/lightx2v_5.json \
--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \
--save_video_path ${lightx2v_path}/save_results/lightx2v_5.mp4
#!/bin/bash
# set paths first
lightx2v_path=/path/to/lightx2v
model_path=/path/to/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v
# model_path=/path/to/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1_distill \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/bench/lightx2v_5_distill.json \
--prompt "A close-up cinematic view of a person cooking in a warm,sunlit kitchen, using a wooden spatula to stir-fry a colorful mix of freshvegetables—carrots, broccoli, and bell peppers—in a black frying pan on amodern induction stove. The scene captures the glistening texture of thevegetables, steam gently rising, and subtle reflections on the stove surface.In the background, soft-focus jars, fruits, and a window with natural daylightcreate a cozy atmosphere. The hand motions are smooth and rhythmic, with a realisticsense of motion blur and lighting." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_2.jpg \
--save_video_path ${lightx2v_path}/save_results/lightx2v_5_distill.mp4
......
@@ -32,7 +32,7 @@ python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
-    --config_json ${lightx2v_path}/configs/changing_resolution/wan_t2v.json \
+    --config_json ${lightx2v_path}/configs/changing_resolution/wan_t2v_U.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_changing_resolution.mp4
......