Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
aece8f87
Commit
aece8f87
authored
Mar 29, 2025
by
Muyang Li
Committed by
Zhekai Zhang
Apr 01, 2025
Browse files
Fix the docker image; remove comfyui folder;
parent
cdf5a19b
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
134 additions
and
12 deletions
+134
-12
nunchaku/__version__.py
nunchaku/__version__.py
+1
-1
nunchaku/lora/flux/compose.py
nunchaku/lora/flux/compose.py
+0
-2
nunchaku/lora/flux/diffusers_converter.py
nunchaku/lora/flux/diffusers_converter.py
+1
-1
nunchaku/models/transformers/transformer_flux.py
nunchaku/models/transformers/transformer_flux.py
+41
-1
nunchaku/test.py
nunchaku/test.py
+2
-2
scripts/build_all_linux_wheels.sh
scripts/build_all_linux_wheels.sh
+14
-0
scripts/build_docker.sh
scripts/build_docker.sh
+40
-0
scripts/build_linux_wheel.sh
scripts/build_linux_wheel.sh
+14
-1
scripts/build_windows_wheels.ps1
scripts/build_windows_wheels.ps1
+19
-3
src/kernels/zgemm/gemm_w4a4_launch.cuh
src/kernels/zgemm/gemm_w4a4_launch.cuh
+2
-1
No files found.
nunchaku/__version__.py
View file @
aece8f87
__version__
=
"0.
1.5
"
__version__
=
"0.
2.0dev0
"
nunchaku/lora/flux/compose.py
View file @
aece8f87
...
@@ -28,8 +28,6 @@ def compose_lora(
...
@@ -28,8 +28,6 @@ def compose_lora(
composed
[
k
]
=
previous_tensor
+
v
*
strength
composed
[
k
]
=
previous_tensor
+
v
*
strength
else
:
else
:
assert
v
.
ndim
==
2
assert
v
.
ndim
==
2
if
"lora_A"
in
k
:
v
=
v
*
strength
if
".to_q."
in
k
or
".add_q_proj."
in
k
:
# qkv must all exist
if
".to_q."
in
k
or
".add_q_proj."
in
k
:
# qkv must all exist
if
"lora_B"
in
k
:
if
"lora_B"
in
k
:
continue
continue
...
...
nunchaku/lora/flux/diffusers_converter.py
View file @
aece8f87
...
@@ -13,7 +13,7 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
...
@@ -13,7 +13,7 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
if
isinstance
(
input_lora
,
str
):
if
isinstance
(
input_lora
,
str
):
tensors
=
load_state_dict_in_safetensors
(
input_lora
,
device
=
"cpu"
)
tensors
=
load_state_dict_in_safetensors
(
input_lora
,
device
=
"cpu"
)
else
:
else
:
tensors
=
input_lora
tensors
=
{
k
:
v
for
k
,
v
in
input_lora
.
items
()}
new_tensors
,
alphas
=
FluxLoraLoaderMixin
.
lora_state_dict
(
tensors
,
return_alphas
=
True
)
new_tensors
,
alphas
=
FluxLoraLoaderMixin
.
lora_state_dict
(
tensors
,
return_alphas
=
True
)
if
alphas
is
not
None
and
len
(
alphas
)
>
0
:
if
alphas
is
not
None
and
len
(
alphas
)
>
0
:
...
...
nunchaku/models/transformers/transformer_flux.py
View file @
aece8f87
...
@@ -7,7 +7,7 @@ from diffusers import FluxTransformer2DModel
...
@@ -7,7 +7,7 @@ from diffusers import FluxTransformer2DModel
from
diffusers.configuration_utils
import
register_to_config
from
diffusers.configuration_utils
import
register_to_config
from
huggingface_hub
import
utils
from
huggingface_hub
import
utils
from
packaging.version
import
Version
from
packaging.version
import
Version
from
safetensors.torch
import
load_file
from
safetensors.torch
import
load_file
,
save_file
from
torch
import
nn
from
torch
import
nn
from
.utils
import
get_precision
,
NunchakuModelLoaderMixin
,
pad_tensor
from
.utils
import
get_precision
,
NunchakuModelLoaderMixin
,
pad_tensor
...
@@ -233,6 +233,11 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
...
@@ -233,6 +233,11 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
self
.
_unquantized_part_loras
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_unquantized_part_loras
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_quantized_part_sd
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_quantized_part_sd
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_quantized_part_vectors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_quantized_part_vectors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_original_in_channels
=
in_channels
# Comfyui LoRA related
self
.
comfy_lora_meta_list
=
[]
self
.
comfy_lora_sd_list
=
[]
@
classmethod
@
classmethod
@
utils
.
validate_hf_hub_args
@
utils
.
validate_hf_hub_args
...
@@ -433,3 +438,38 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
...
@@ -433,3 +438,38 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
if
len
(
self
.
_quantized_part_vectors
)
>
0
:
if
len
(
self
.
_quantized_part_vectors
)
>
0
:
vector_dict
=
fuse_vectors
(
self
.
_quantized_part_vectors
,
self
.
_quantized_part_sd
,
strength
)
vector_dict
=
fuse_vectors
(
self
.
_quantized_part_vectors
,
self
.
_quantized_part_sd
,
strength
)
block
.
m
.
loadDict
(
vector_dict
,
True
)
block
.
m
.
loadDict
(
vector_dict
,
True
)
def
reset_x_embedder
(
self
):
# if change the model in channels, we need to update the x_embedder
if
self
.
_original_in_channels
!=
self
.
config
.
in_channels
:
assert
self
.
_original_in_channels
<
self
.
config
.
in_channels
old_module
=
self
.
x_embedder
new_module
=
nn
.
Linear
(
in_features
=
self
.
_original_in_channels
,
out_features
=
old_module
.
out_features
,
bias
=
old_module
.
bias
is
not
None
,
device
=
old_module
.
weight
.
device
,
dtype
=
old_module
.
weight
.
dtype
,
)
new_module
.
weight
.
data
.
copy_
(
old_module
.
weight
.
data
[:
new_module
.
out_features
,
:
new_module
.
in_features
])
self
.
_unquantized_part_sd
[
"x_embedder.weight"
]
=
new_module
.
weight
.
data
.
clone
()
if
new_module
.
bias
is
not
None
:
new_module
.
bias
.
data
.
zero_
()
new_module
.
bias
.
data
.
copy_
(
old_module
.
bias
.
data
[:
new_module
.
out_features
])
self
.
_unquantized_part_sd
[
"x_embedder.bias"
]
=
new_module
.
bias
.
data
.
clone
()
self
.
x_embedder
=
new_module
setattr
(
self
.
config
,
"in_channels"
,
self
.
_original_in_channels
)
def
reset_lora
(
self
):
unquantized_part_loras
=
{}
if
len
(
self
.
_unquantized_part_loras
)
>
0
or
len
(
unquantized_part_loras
)
>
0
:
self
.
_unquantized_part_loras
=
unquantized_part_loras
self
.
_update_unquantized_part_lora_params
(
1
)
state_dict
=
{
k
:
v
for
k
,
v
in
self
.
_quantized_part_sd
.
items
()
if
"lora"
in
k
}
quantized_part_vectors
=
{}
if
len
(
self
.
_quantized_part_vectors
)
>
0
or
len
(
quantized_part_vectors
)
>
0
:
self
.
_quantized_part_vectors
=
quantized_part_vectors
updated_vectors
=
fuse_vectors
(
quantized_part_vectors
,
self
.
_quantized_part_sd
,
1
)
state_dict
.
update
(
updated_vectors
)
self
.
transformer_blocks
[
0
].
m
.
loadDict
(
state_dict
,
True
)
self
.
reset_x_embedder
()
nunchaku/test.py
View file @
aece8f87
...
@@ -9,12 +9,12 @@ if __name__ == "__main__":
...
@@ -9,12 +9,12 @@ if __name__ == "__main__":
precision
=
"fp4"
if
sm
==
"120"
else
"int4"
precision
=
"fp4"
if
sm
==
"120"
else
"int4"
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
f
"mit-han-lab/svdq-
{
precision
}
-flux.1-schnell"
,
offload
=
True
,
precision
=
precision
f
"mit-han-lab/svdq-
{
precision
}
-flux.1-schnell"
,
offload
=
True
)
)
pipeline
=
FluxPipeline
.
from_pretrained
(
pipeline
=
FluxPipeline
.
from_pretrained
(
"black-forest-labs/FLUX.1-schnell"
,
transformer
=
transformer
,
torch_dtype
=
torch
.
bfloat16
"black-forest-labs/FLUX.1-schnell"
,
transformer
=
transformer
,
torch_dtype
=
torch
.
bfloat16
)
)
pipeline
.
enable_sequential_cpu_offload
()
image
=
pipeline
(
image
=
pipeline
(
"A cat holding a sign that says hello world"
,
width
=
1024
,
height
=
1024
,
num_inference_steps
=
4
,
guidance_scale
=
0
"A cat holding a sign that says hello world"
,
width
=
1024
,
height
=
1024
,
num_inference_steps
=
4
,
guidance_scale
=
0
).
images
[
0
]
).
images
[
0
]
...
...
scripts/build_all_linux_wheels.sh
0 → 100644
View file @
aece8f87
#!/bin/bash
# Define the versions for Python, Torch, and CUDA
python_versions
=(
"3.10"
"3.11"
"3.12"
"3.13"
)
torch_versions
=(
"2.5"
"2.6"
)
cuda_versions
=(
"12.4"
)
# Loop through all combinations of Python, Torch, and CUDA versions
for
python_version
in
"
${
python_versions
[@]
}
"
;
do
for
torch_version
in
"
${
torch_versions
[@]
}
"
;
do
for
cuda_version
in
"
${
cuda_versions
[@]
}
"
;
do
bash scripts/build_linux_wheel.sh
"
$python_version
"
"
$torch_version
"
"
$cuda_version
"
done
done
done
\ No newline at end of file
scripts/build_docker.sh
0 → 100644
View file @
aece8f87
#!/bin/bash
PYTHON_VERSION
=
$1
TORCH_VERSION
=
$2
CUDA_VERSION
=
$3
NUNCHAKU_VERSION
=
$4
# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
if
[
"
$TORCH_VERSION
"
==
"2.5"
]
;
then
TORCHVISION_VERSION
=
"0.20"
TORCHAUDIO_VERSION
=
"2.5"
echo
"TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to
$TORCHVISION_VERSION
and TORCHAUDIO_VERSION to
$TORCHAUDIO_VERSION
"
elif
[
"
$TORCH_VERSION
"
==
"2.6"
]
;
then
TORCHVISION_VERSION
=
"0.21"
TORCHAUDIO_VERSION
=
"2.6"
echo
"TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to
$TORCHVISION_VERSION
and TORCHAUDIO_VERSION to
$TORCHAUDIO_VERSION
"
else
echo
"TORCH_VERSION is not 2.5 or 2.6. Exit."
exit
2
fi
if
[
"
$CUDA_VERSION
"
==
"12.8"
]
;
then
CUDA_IMAGE
=
"12.8.1-devel-ubuntu24.04"
echo
"CUDA_VERSION is 12.8, setting CUDA_IMAGE to
$CUDA_IMAGE
"
elif
[
"
$CUDA_VERSION
"
==
"12.4"
]
;
then
CUDA_IMAGE
=
"12.4.1-devel-ubuntu22.04"
echo
"CUDA_VERSION is 12.4, setting CUDA_IMAGE to
$CUDA_IMAGE
"
else
echo
"CUDA_VERSION is not 12.8 or 12.4. Exit."
exit
2
fi
docker build
--no-cache
\
--build-arg
PYTHON_VERSION
=
${
PYTHON_VERSION
}
\
--build-arg
CUDA_SHORT_VERSION
=
${
CUDA_VERSION
//.
}
\
--build-arg
CUDA_IMAGE
=
${
CUDA_IMAGE
}
\
--build-arg
TORCH_VERSION
=
${
TORCH_VERSION
}
\
--build-arg
TORCHVISION_VERSION
=
${
TORCHVISION_VERSION
}
\
--build-arg
TORCHAUDIO_VERSION
=
${
TORCHAUDIO_VERSION
}
\
-t
nunchaku:
${
NUNCHAKU_VERSION
}
-py
${
PYTHON_VERSION
}
-torch
${
TORCH_VERSION
}
-cuda
${
CUDA_VERSION
}
.
scripts/build_linux_wheel
s
.sh
→
scripts/build_linux_wheel.sh
View file @
aece8f87
...
@@ -7,6 +7,19 @@ CUDA_VERSION=$3
...
@@ -7,6 +7,19 @@ CUDA_VERSION=$3
MAX_JOBS
=
${
4
:-}
# optional
MAX_JOBS
=
${
4
:-}
# optional
PYTHON_ROOT_PATH
=
/opt/python/cp
${
PYTHON_VERSION
//.
}
-cp
${
PYTHON_VERSION
//.
}
PYTHON_ROOT_PATH
=
/opt/python/cp
${
PYTHON_VERSION
//.
}
-cp
${
PYTHON_VERSION
//.
}
# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
if
[
"
$TORCH_VERSION
"
==
"2.5"
]
;
then
TORCHVISION_VERSION
=
"0.20"
TORCHAUDIO_VERSION
=
"2.5"
echo
"TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to
$TORCHVISION_VERSION
and TORCHAUDIO_VERSION to
$TORCHAUDIO_VERSION
"
elif
[
"
$TORCH_VERSION
"
==
"2.6"
]
;
then
TORCHVISION_VERSION
=
"0.21"
TORCHAUDIO_VERSION
=
"2.6"
echo
"TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to
$TORCHVISION_VERSION
and TORCHAUDIO_VERSION to
$TORCHAUDIO_VERSION
"
else
echo
"TORCH_VERSION is not 2.5 or 2.6, no changes to versions."
fi
docker run
--rm
\
docker run
--rm
\
-v
"
$(
pwd
)
"
:/nunchaku
\
-v
"
$(
pwd
)
"
:/nunchaku
\
pytorch/manylinux-builder:cuda
${
CUDA_VERSION
}
\
pytorch/manylinux-builder:cuda
${
CUDA_VERSION
}
\
...
@@ -16,7 +29,7 @@ docker run --rm \
...
@@ -16,7 +29,7 @@ docker run --rm \
yum install -y devtoolset-11 &&
\
yum install -y devtoolset-11 &&
\
source scl_source enable devtoolset-11 &&
\
source scl_source enable devtoolset-11 &&
\
gcc --version && g++ --version &&
\
gcc --version && g++ --version &&
\
${
PYTHON_ROOT_PATH
}
/bin/pip install --no-cache-dir torch==
${
TORCH_VERSION
}
numpy
--index-url https://download.pytorch.org/whl/cu
${
CUDA_VERSION
//.
}
&&
\
${
PYTHON_ROOT_PATH
}
/bin/pip install --no-cache-dir torch==
${
TORCH_VERSION
}
torchvision==
${
TORCHVISION_VERSION
}
torchaudio==
${
TORCHAUDIO_VERSION
}
--index-url https://download.pytorch.org/whl/cu
${
CUDA_VERSION
//.
}
&&
\
${
PYTHON_ROOT_PATH
}
/bin/pip install build ninja wheel setuptools &&
\
${
PYTHON_ROOT_PATH
}
/bin/pip install build ninja wheel setuptools &&
\
export NUNCHAKU_INSTALL_MODE=ALL &&
\
export NUNCHAKU_INSTALL_MODE=ALL &&
\
export NUNCHAKU_BUILD_WHEELS=1 &&
\
export NUNCHAKU_BUILD_WHEELS=1 &&
\
...
...
scripts/build_windows_wheels.ps1
View file @
aece8f87
...
@@ -5,8 +5,23 @@ param (
...
@@ -5,8 +5,23 @@ param (
[
string
]
$MAX_JOBS
=
""
[
string
]
$MAX_JOBS
=
""
)
)
# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
if
(
$TORCH_VERSION
-eq
"2.5"
)
{
$TORCHVISION_VERSION
=
"0.20"
$TORCHAUDIO_VERSION
=
"2.5"
Write-Output
"TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to
$TORCHVISION_VERSION
and TORCHAUDIO_VERSION to
$TORCHAUDIO_VERSION
"
}
elseif
(
$TORCH_VERSION
-eq
"2.6"
)
{
$TORCHVISION_VERSION
=
"0.21"
$TORCHAUDIO_VERSION
=
"2.6"
Write-Output
"TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to
$TORCHVISION_VERSION
and TORCHAUDIO_VERSION to
$TORCHAUDIO_VERSION
"
}
else
{
Write-Output
"TORCH_VERSION is not 2.5 or 2.6, no changes to versions."
}
# Conda 环境名称
# Conda 环境名称
$ENV_NAME
=
"build_env_
$PYTHON_VERSION
"
$ENV_NAME
=
"build_env_
$PYTHON_VERSION
_$TORCH_VERSION
"
# 创建 Conda 环境
# 创建 Conda 环境
conda
create
-y
-n
$ENV_NAME
python
=
$PYTHON_VERSION
conda
create
-y
-n
$ENV_NAME
python
=
$PYTHON_VERSION
...
@@ -14,7 +29,7 @@ conda activate $ENV_NAME
...
@@ -14,7 +29,7 @@ conda activate $ENV_NAME
# 安装依赖
# 安装依赖
conda
install
-y
ninja
setuptools
wheel
pip
conda
install
-y
ninja
setuptools
wheel
pip
pip
install
--no-cache-dir
torch
==
$TORCH_VERSION
numpy
--index-url
"https://download.pytorch.org/whl/cu
$(
$CUDA_VERSION
.
Substring
(
0
,
2
)
)
/"
pip
install
--no-cache-dir
torch
==
$TORCH_VERSION
torchvision
==
$TORCHVISION_VERSION
torchaudio
==
$TORCHAUDIO_VERSION
--index-url
"https://download.pytorch.org/whl/cu
$(
$CUDA_VERSION
.
Substring
(
0
,
2
)
)
/"
# 设置环境变量
# 设置环境变量
$
env
:
NUNCHAKU_INSTALL_MODE
=
"ALL"
$
env
:
NUNCHAKU_INSTALL_MODE
=
"ALL"
...
@@ -22,11 +37,12 @@ $env:NUNCHAKU_BUILD_WHEELS="1"
...
@@ -22,11 +37,12 @@ $env:NUNCHAKU_BUILD_WHEELS="1"
$
env
:
MAX_JOBS
=
$MAX_JOBS
$
env
:
MAX_JOBS
=
$MAX_JOBS
# 进入当前脚本所在目录并构建 wheels
# 进入当前脚本所在目录并构建 wheels
Set-Location
-Path
$PSScriptRoot
Set-Location
-Path
"
$PSScriptRoot
\.."
if
(
Test-Path
"build"
)
{
Remove-Item
-Recurse
-Force
"build"
}
if
(
Test-Path
"build"
)
{
Remove-Item
-Recurse
-Force
"build"
}
python
-m
build
--wheel
--no-isolation
python
-m
build
--wheel
--no-isolation
# 退出 Conda 环境
# 退出 Conda 环境
conda
deactivate
conda
deactivate
conda
remove
-y
-n
$ENV_NAME
--all
Write-Output
"Build complete!"
Write-Output
"Build complete!"
src/kernels/zgemm/gemm_w4a4_launch.cuh
View file @
aece8f87
...
@@ -6,7 +6,8 @@ template<typename Config, bool USE_FP4>
...
@@ -6,7 +6,8 @@ template<typename Config, bool USE_FP4>
class
GEMM_W4A4_Launch
{
class
GEMM_W4A4_Launch
{
using
GEMM
=
GEMM_W4A4
<
Config
>
;
using
GEMM
=
GEMM_W4A4
<
Config
>
;
// using LoraRanks = std::integer_sequence<int, 0, 32>;
// using LoraRanks = std::integer_sequence<int, 0, 32>;
using
LoraRanks
=
std
::
integer_sequence
<
int
,
0
,
32
,
48
,
64
,
80
,
96
,
112
,
128
,
160
,
176
,
224
>
;
// using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96, 112, 128, 160, 176, 224>;
using
LoraRanks
=
std
::
integer_sequence
<
int
,
0
,
32
,
48
,
64
,
80
,
96
,
112
,
128
,
144
,
160
,
176
,
192
,
208
,
224
>
;
// using LoraRanks = std::integer_sequence<int,
// using LoraRanks = std::integer_sequence<int,
// 0, 32, 48, 64, 80, 96, 112, 128, 144, 160,
// 0, 32, 48, 64, 80, 96, 112, 128, 144, 160,
// 176, 192, 208, 224, 240, 256, 272, 288, 304, 320,
// 176, 192, 208, 224, 240, 256, 272, 288, 304, 320,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment