Project: xuwx1/LightX2V

Commit cc2a283a, authored Jul 21, 2025 by gushiqiao; committed by GitHub on Jul 21, 2025.
Parents: 29a90944, adf8df9d

Add torchao kernel

Dev quant
Showing 8 changed files with 140 additions and 34 deletions (+140, -34).
configs/quantization/wan_i2v_torchao.json  (+21, -0)
docs/EN/source/deploy_guides/deploy_local_windows.md  (+0, -16)
lightx2v/common/ops/mm/mm_weight.py  (+35, -0)
lightx2v/models/input_encoders/hf/q_linear.py  (+54, -3)
lightx2v/models/input_encoders/hf/t5/model.py  (+9, -5)
lightx2v/models/input_encoders/hf/xlm_roberta/model.py  (+9, -5)
lightx2v/models/networks/wan/distill_model.py  (+4, -0)
lightx2v/models/runners/wan/wan_runner.py  (+8, -5)
configs/quantization/wan_i2v_torchao.json (new file, mode 100755)

```json
{
    "infer_steps": 4,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "sage_attn2",
    "cross_attn_1_type": "sage_attn2",
    "cross_attn_2_type": "sage_attn2",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao"
    },
    "t5_quantized": true,
    "t5_quant_scheme": "int8-torchao",
    "clip_quantized": true,
    "clip_quant_scheme": "int8-torchao"
}
```
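The `mm_type` string encodes the quantization recipe: int8 per-channel symmetric weights, int8 per-token dynamic symmetric activations, executed with the Torchao kernel. The following is a minimal plain-PyTorch sketch of what the weight half ("W-int8-channel-sym") means numerically; it is illustrative only and is not LightX2V's quantization tooling.

```python
# Illustrative sketch (not LightX2V code): per-channel symmetric int8 quantization
# of a linear weight shaped (out_features, in_features). Each output channel gets
# one scale derived from its absolute maximum; there is no zero point (symmetric).
import torch

def quantize_weight_int8_per_channel_sym(weight: torch.Tensor):
    scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(weight / scale), -127, 127).to(torch.int8)
    return q, scale.to(torch.float32)

w = torch.randn(8192, 4096)
w_int8, w_scale = quantize_weight_int8_per_channel_sym(w)
# Dequantized reconstruction error stays small relative to the tensor's range.
print((w - w_int8.float() * w_scale).abs().max())
```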
docs/EN/source/deploy_guides/deploy_local_windows.md

```diff
@@ -19,22 +19,6 @@ This document provides detailed instructions for deploying LightX2V locally on W
 - **CUDA**: 12.4 or higher version
 - **Dependencies**: Refer to LightX2V project's requirements_win.txt
 
-### Installation Steps
-
-1. **Clone Project**
-   ```cmd
-   git clone https://github.com/ModelTC/LightX2V.git
-   cd LightX2V
-   ```
-
-2. **Install Dependencies**
-   ```cmd
-   pip install -r requirements_win.txt
-   ```
-
-3. **Download Models**
-   Refer to [Model Download Guide](../getting_started/quickstart.md) to download required models
-
 ## 🎯 Usage Methods
 
 ### Method 1: Using Batch File Inference
```
lightx2v/common/ops/mm/mm_weight.py

```diff
@@ -25,6 +25,11 @@ try:
 except ImportError:
     deep_gemm = None
 
+try:
+    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
+except ModuleNotFoundError:
+    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
+
 
 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
@@ -232,6 +237,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     # =========================
     # act quant kernels
     # =========================
+    def act_quant_int8_perchannel_sym_torchao(self, x):
+        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
+        return input_tensor_quant, input_tensor_scale
+
     def act_quant_fp8_perchannel_sym_vllm(self, x):
         input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
@@ -624,6 +632,33 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
         return output_tensor
 
 
+@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
+class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
+    """
+    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
+
+    Quant MM:
+        Weight: int8 perchannel sym
+        Act: int8 perchannel dynamic sym
+        Kernel: Torchao
+    """
+
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
+        self.load_func = self.load_int8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
+
+    def apply(self, input_tensor):
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
+        if self.bias is not None:
+            output_tensor = output_tensor + self.bias
+        return output_tensor
+
+
 if __name__ == "__main__":
     weight_dict = {
         "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
```
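The new kernel pairs torchao's per-token activation quantization with its int8 matmul helper. Below is a minimal sketch exercising that pair outside LightX2V, assuming torchao is installed (the two helpers are the ones imported in this file) and a CUDA device is available; it mirrors the weight layout the kernel expects (int8 weight stored transposed because `weight_need_transpose` is set, per-output-channel float scales). The tensor shapes are arbitrary examples.

```python
# Minimal sketch of the Torchao int8 path used by the new kernel, not LightX2V code.
import torch
from torchao.quantization.utils import (
    quant_int8_per_token_matmul,
    quantize_activation_per_token_absmax,
)

x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")   # activations (tokens, in)
w = torch.randn(8192, 4096, dtype=torch.bfloat16, device="cuda") # weight (out, in)

# Weight side: per-channel symmetric int8, normally produced offline by the
# quantization tooling; stored transposed as (in, out) like the kernel expects.
w_scale = w.float().abs().amax(dim=1, keepdim=True) / 127.0       # (out, 1)
w_int8_t = torch.clamp(torch.round(w.float() / w_scale), -127, 127).to(torch.int8).t().contiguous()

# Activation side: per-token dynamic int8, done at runtime by the kernel.
x_int8, x_scale = quantize_activation_per_token_absmax(x)

y = quant_int8_per_token_matmul(x_int8, x_scale, w_int8_t, w_scale.t().float(), output_dtype=torch.bfloat16)

# Compare against the full-precision reference matmul.
ref = x.float() @ w.float().t()
print(torch.nn.functional.mse_loss(y.float(), ref))
```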
lightx2v/models/input_encoders/hf/q_linear.py

```diff
 import torch
 import torch.nn as nn
-from vllm import _custom_ops as ops
+try:
+    from vllm import _custom_ops as ops
+except ModuleNotFoundError:
+    ops = None
+
+try:
+    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
+except ModuleNotFoundError:
+    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
 
 
-class QuantLinearInt8(nn.Module):
+class VllmQuantLinearInt8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
         super().__init__()
         self.in_features = in_features
@@ -54,7 +63,7 @@ class QuantLinearInt8(nn.Module):
         return self
 
 
-class QuantLinearFp8(nn.Module):
+class VllmQuantLinearFp8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
         super().__init__()
         self.in_features = in_features
@@ -101,3 +110,45 @@ class QuantLinearFp8(nn.Module):
         self.weight_scale = maybe_cast(self.weight_scale)
         self.bias = maybe_cast(self.bias)
         return self
+
+
+class TorchaoQuantLinearInt8(nn.Module):
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
+        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
+        if bias:
+            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
+        else:
+            self.register_buffer("bias", None)
+
+    def act_quant_func(self, x):
+        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
+        return input_tensor_quant, input_tensor_scale
+
+    def forward(self, input_tensor):
+        input_tensor = input_tensor.squeeze(0)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16)
+        if self.bias is not None:
+            output_tensor = output_tensor + self.bias
+        return output_tensor.unsqueeze(0)
+
+    def _apply(self, fn):
+        for module in self.children():
+            module._apply(fn)
+
+        def maybe_cast(t):
+            if t is not None and t.device != fn(t).device:
+                return fn(t)
+            return t
+
+        self.weight = maybe_cast(self.weight)
+        self.weight_scale = maybe_cast(self.weight_scale)
+        self.bias = maybe_cast(self.bias)
+        return self
```
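A hedged usage sketch for the new `TorchaoQuantLinearInt8` follows. The `fake_quantized_state` helper and its checkpoint layout are assumptions for illustration only (the real int8 buffers come from LightX2V's offline quantization of the encoder weights); it also assumes torchao is installed and a CUDA device is available.

```python
# Hypothetical usage sketch for TorchaoQuantLinearInt8; the helper below is made up.
import torch
from lightx2v.models.input_encoders.hf.q_linear import TorchaoQuantLinearInt8

def fake_quantized_state(out_features, in_features, dtype=torch.bfloat16):
    # Stand-in for an offline-quantized checkpoint entry: int8 weight, per-channel scale, bias.
    w = torch.randn(out_features, in_features)
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    return {
        "weight": torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8),
        "weight_scale": scale.float(),
        "bias": torch.zeros(out_features, dtype=dtype),
    }

layer = TorchaoQuantLinearInt8(4096, 8192).cuda()
state = fake_quantized_state(8192, 4096)
layer.weight.copy_(state["weight"])
layer.weight_scale.copy_(state["weight_scale"])
layer.bias.copy_(state["bias"])

x = torch.randn(1, 77, 4096, dtype=torch.bfloat16, device="cuda")  # (batch=1, seq, hidden)
y = layer(x)
print(y.shape, y.dtype)  # torch.Size([1, 77, 8192]) torch.bfloat16
```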
lightx2v/models/input_encoders/hf/t5/model.py

```diff
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 from .tokenizer import HuggingfaceTokenizer
 from loguru import logger
-from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
+from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
 
 __all__ = [
@@ -83,9 +83,11 @@ class T5Attention(nn.Module):
         if quantized:
             if quant_scheme == "int8":
-                linear_cls = QuantLinearInt8
+                linear_cls = VllmQuantLinearInt8
             elif quant_scheme == "fp8":
-                linear_cls = QuantLinearFp8
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
             else:
                 linear_cls = nn.Linear
@@ -144,9 +146,11 @@ class T5FeedForward(nn.Module):
         if quantized:
             if quant_scheme == "int8":
-                linear_cls = QuantLinearInt8
+                linear_cls = VllmQuantLinearInt8
             elif quant_scheme == "fp8":
-                linear_cls = QuantLinearFp8
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
             else:
                 linear_cls = nn.Linear
 
         # layers
```
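The same three-way dispatch is repeated inline in `T5Attention`, `T5FeedForward`, and the xlm_roberta blocks below. A hypothetical standalone helper capturing that selection logic, for illustration only (not part of the commit):

```python
# Sketch of the quant_scheme -> linear class dispatch added by this commit.
import torch.nn as nn
from lightx2v.models.input_encoders.hf.q_linear import (
    VllmQuantLinearInt8,
    VllmQuantLinearFp8,
    TorchaoQuantLinearInt8,
)

def select_linear_cls(quantized, quant_scheme):
    if not quantized:
        return nn.Linear
    if quant_scheme == "int8":
        return VllmQuantLinearInt8
    if quant_scheme == "fp8":
        return VllmQuantLinearFp8
    if quant_scheme == "int8-torchao":
        return TorchaoQuantLinearInt8
    return nn.Linear  # unknown schemes fall back to a plain linear layer

assert select_linear_cls(True, "int8-torchao") is TorchaoQuantLinearInt8
assert select_linear_cls(False, None) is nn.Linear
```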
lightx2v/models/input_encoders/hf/xlm_roberta/model.py

```diff
@@ -10,7 +10,7 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from loguru import logger
-from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
+from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
 from einops import rearrange
 from torch import Tensor
 from transformers import CLIPVisionModel
@@ -63,9 +63,11 @@ class SelfAttention(nn.Module):
         # layers
         if quantized:
             if quant_scheme == "int8":
-                linear_cls = QuantLinearInt8
+                linear_cls = VllmQuantLinearInt8
             elif quant_scheme == "fp8":
-                linear_cls = QuantLinearFp8
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
             else:
                 linear_cls = nn.Linear
@@ -135,9 +137,11 @@ class AttentionBlock(nn.Module):
         # layers
         if quantized:
             if quant_scheme == "int8":
-                linear_cls = QuantLinearInt8
+                linear_cls = VllmQuantLinearInt8
             elif quant_scheme == "fp8":
-                linear_cls = QuantLinearFp8
+                linear_cls = VllmQuantLinearFp8
+            elif quant_scheme == "int8-torchao":
+                linear_cls = TorchaoQuantLinearInt8
             else:
                 linear_cls = nn.Linear
```
lightx2v/models/networks/wan/distill_model.py

```diff
@@ -11,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
 from lightx2v.utils.envs import *
+from loguru import logger
 
 
 class WanDistillModel(WanModel):
@@ -31,8 +32,11 @@ class WanDistillModel(WanModel):
             return weight_dict
 
         ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")
         if os.path.exists(ckpt_path):
+            logger.info(f"Loading weights from {ckpt_path}")
             weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+            print(weight_dict.keys())
             weight_dict = {
                 key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device)
                 for key in weight_dict.keys()
             }
```
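The loading path casts tensors to bfloat16 unless their key matches an entry in `skip_bf16` (or `use_bf16` forces the cast for everything), then pins and moves them to the target device. A toy illustration of just the dtype-selection comprehension; the key names and `skip_bf16` contents below are invented, and the pinning/device move is omitted:

```python
# Toy illustration of the selective bf16 cast used when loading distill_model.pt.
import torch

use_bf16 = False
skip_bf16 = ("norm",)  # hypothetical substrings whose tensors keep their original dtype

weight_dict = {
    "blocks.0.attn.qkv.weight": torch.randn(8, 8, dtype=torch.float32),
    "blocks.0.norm.weight": torch.randn(8, dtype=torch.float32),
}

weight_dict = {
    key: (t.to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else t)
    for key, t in weight_dict.items()
}
print({k: v.dtype for k, v in weight_dict.items()})
# {'blocks.0.attn.qkv.weight': torch.bfloat16, 'blocks.0.norm.weight': torch.float32}
```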
lightx2v/models/runners/wan/wan_runner.py

```diff
@@ -57,17 +57,19 @@ class WanRunner(DefaultRunner):
         if clip_quantized:
             clip_quant_scheme = self.config.get("clip_quant_scheme", None)
             assert clip_quant_scheme is not None
+            tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
             clip_quantized_ckpt = self.config.get(
                 "clip_quantized_ckpt",
                 os.path.join(
-                    os.path.join(self.config.model_path, clip_quant_scheme),
-                    f"clip-{clip_quant_scheme}.pth",
+                    os.path.join(self.config.model_path, tmp_clip_quant_scheme),
+                    f"clip-{tmp_clip_quant_scheme}.pth",
                 ),
             )
         else:
             clip_quantized_ckpt = None
             clip_quant_scheme = None
+        print(clip_quant_scheme)
         image_encoder = CLIPModel(
             dtype=torch.float16,
             device=self.init_device,
@@ -93,18 +95,19 @@ class WanRunner(DefaultRunner):
         t5_quantized = self.config.get("t5_quantized", False)
         if t5_quantized:
             t5_quant_scheme = self.config.get("t5_quant_scheme", None)
+            tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
             assert t5_quant_scheme is not None
             t5_quantized_ckpt = self.config.get(
                 "t5_quantized_ckpt",
                 os.path.join(
-                    os.path.join(self.config.model_path, t5_quant_scheme),
-                    f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth",
+                    os.path.join(self.config.model_path, tmp_t5_quant_scheme),
+                    f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth",
                ),
             )
         else:
             t5_quant_scheme = None
             t5_quantized_ckpt = None
+        print(t5_quant_scheme)
         text_encoder = T5EncoderModel(
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
```
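The runner now strips the kernel suffix from the scheme before building default checkpoint paths, so `int8-torchao` resolves to the same quantized encoder files as plain `int8`. A small sketch of the resulting defaults, with a placeholder `model_path`:

```python
# Sketch of the default CLIP checkpoint paths the runner now builds; model_path is
# a placeholder, the file-name pattern mirrors the code above.
import os

model_path = "/models/Wan2.1-I2V-14B-480P"  # assumed location

for clip_quant_scheme in ("int8", "int8-torchao", "fp8"):
    tmp = clip_quant_scheme.split("-")[0]  # "int8-torchao" -> "int8"
    ckpt = os.path.join(os.path.join(model_path, tmp), f"clip-{tmp}.pth")
    print(clip_quant_scheme, "->", ckpt)
# "int8" and "int8-torchao" both resolve to .../int8/clip-int8.pth, so the torchao
# kernel reuses the existing int8 checkpoints.
```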