fengzch-das / nunchaku / Commits / 51650a36

Commit 51650a36, authored Feb 14, 2025 by muyangli

[major] lora conversion script released; upgrade the model; release flux.1-redux model

parent 26e1d00f

Showing 5 changed files with 452 additions and 9 deletions (+452 -9)
app/sana/t2i/assets/frame2.css    +4    -1
examples/flux.1-fill-dev.py       +1    -1
examples/flux.1-schnell.py        +1    -6
nunchaku/__version__.py           +1    -1
nunchaku/convert_lora.py          +445  -0
app/sana/t2i/assets/frame2.css

-.gradio-container { max-width: 1200px !important }
+.gradio-container {
+    max-width: 1200px !important;
+    margin: auto; /* Centers the element horizontally */
+}
examples/flux.1-fill-dev.py

@@ -12,7 +12,7 @@ pipe = FluxFillPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
 image = pipe(
-    prompt="A wooden basket of several individual cartons of blueberries.",
+    prompt="A wooden basket of a cat.",
     image=image,
     mask_image=mask,
     height=1024,
examples/flux.1-schnell.py

@@ -8,11 +8,6 @@ pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
 image = pipeline(
-    "A cat holding a sign that says hello world",
-    width=1024,
-    height=1024,
-    num_inference_steps=4,
-    guidance_scale=0,
-    generator=torch.Generator().manual_seed(2333),
+    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
 ).images[0]
 image.save("flux.1-schnell-int4.png")
nunchaku/__version__.py

-__version__ = "0.0.2beta3"
+__version__ = "0.0.2beta4"
nunchaku/convert_lora.py    0 → 100644 (new file)

"""Convert LoRA weights to Nunchaku format."""

import argparse
import os
import typing as tp

import safetensors
import safetensors.torch
import torch
import tqdm

# region utilities


def ceil_divide(x: int, divisor: int) -> int:
    """Ceiling division.

    Args:
        x (`int`):
            dividend.
        divisor (`int`):
            divisor.

    Returns:
        `int`:
            ceiling division result.
    """
    return (x + divisor - 1) // divisor


def pad(
    tensor: tp.Optional[torch.Tensor],
    divisor: int | tp.Sequence[int],
    dim: int | tp.Sequence[int],
    fill_value: float | int = 0,
) -> torch.Tensor | None:
    if isinstance(divisor, int):
        if divisor <= 1:
            return tensor
    elif all(d <= 1 for d in divisor):
        return tensor
    if tensor is None:
        return None
    shape = list(tensor.shape)
    if isinstance(dim, int):
        assert isinstance(divisor, int)
        shape[dim] = ceil_divide(shape[dim], divisor) * divisor
    else:
        if isinstance(divisor, int):
            divisor = [divisor] * len(dim)
        for d, div in zip(dim, divisor, strict=True):
            shape[d] = ceil_divide(shape[d], div) * div
    result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
    result[[slice(0, extent) for extent in tensor.shape]] = tensor
    return result


def update_state_dict(
    lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = ""
) -> dict[str, torch.Tensor]:
    for rkey, value in rhs.items():
        lkey = f"{prefix}.{rkey}" if prefix else rkey
        assert lkey not in lhs, f"Key {lkey} already exists in the state dict."
        lhs[lkey] = value
    return lhs


def load_state_dict_in_safetensors(
    path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
) -> dict[str, torch.Tensor]:
    """Load state dict in SafeTensors.

    Args:
        path (`str`):
            file path.
        device (`str` | `torch.device`, optional, defaults to `"cpu"`):
            device.
        filter_prefix (`str`, optional, defaults to `""`):
            filter prefix.

    Returns:
        `dict`:
            loaded SafeTensors.
    """
    state_dict = {}
    with safetensors.safe_open(path, framework="pt", device=device) as f:
        for k in f.keys():
            if filter_prefix and not k.startswith(filter_prefix):
                continue
            state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
    return state_dict


# endregion
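# Illustration (added note, not part of the original script): pad() rounds the selected
# dimensions up to the next multiple of the divisor and fills the new entries with
# fill_value, e.g. ceil_divide(10, 16) == 1 and
# pad(torch.zeros(10, 6), divisor=(16, 16), dim=(0, 1)).shape == torch.Size([16, 16]).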
def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
    """Pack Low-Rank Weight.

    Args:
        weight (`torch.Tensor`):
            low-rank weight tensor.
        down (`bool`):
            whether the weight is for down projection in low-rank branch.
    """
    assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
    lane_n, lane_k = 1, 2  # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
    n_pack_size, k_pack_size = 2, 2
    num_n_lanes, num_k_lanes = 8, 4
    frag_n = n_pack_size * num_n_lanes * lane_n
    frag_k = k_pack_size * num_k_lanes * lane_k
    weight = pad(weight, divisor=(frag_n, frag_k), dim=(0, 1))
    if down:
        r, c = weight.shape
        r_frags, c_frags = r // frag_n, c // frag_k
        weight = weight.view(r_frags, frag_n, c_frags, frag_k).permute(2, 0, 1, 3)
    else:
        c, r = weight.shape
        c_frags, r_frags = c // frag_n, r // frag_k
        weight = weight.view(c_frags, frag_n, r_frags, frag_k).permute(0, 2, 1, 3)
    weight = weight.reshape(c_frags, r_frags, n_pack_size, num_n_lanes, k_pack_size, num_k_lanes, lane_k)
    weight = weight.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
    return weight.view(c, r)


def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
    """Unpack Low-Rank Weight.

    Args:
        weight (`torch.Tensor`):
            low-rank weight tensor.
        down (`bool`):
            whether the weight is for down projection in low-rank branch.
    """
    c, r = weight.shape
    assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
    lane_n, lane_k = 1, 2  # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
    n_pack_size, k_pack_size = 2, 2
    num_n_lanes, num_k_lanes = 8, 4
    frag_n = n_pack_size * num_n_lanes * lane_n
    frag_k = k_pack_size * num_k_lanes * lane_k
    if down:
        r_frags, c_frags = r // frag_n, c // frag_k
    else:
        c_frags, r_frags = c // frag_n, r // frag_k
    weight = weight.view(c_frags, r_frags, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, lane_k)
    weight = weight.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
    weight = weight.view(c_frags, r_frags, frag_n, frag_k)
    if down:
        weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
    else:
        weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
    return weight
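# Illustration (added note, not part of the original script): pack_lowrank_weight and
# unpack_lowrank_weight invert each other once a tensor is padded to the 16x16 fragment
# size used above (frag_n == frag_k == 16). For a hypothetical half-precision weight
#   w = torch.randn(64, 32, dtype=torch.bfloat16)
# one would expect
#   unpack_lowrank_weight(pack_lowrank_weight(w, down=True), down=True).equal(w)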
def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor:
    c, r = lora_up.shape
    assert c % splits == 0
    return lora_up.view(splits, c // splits, r).transpose(0, 1).reshape(c, r).contiguous()
def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    local_name_map: dict[str, str | list[str]],
    convert_map: dict[str, str],
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    print(f"Converting LoRA branch for block {candidate_block_name}...")
    converted: dict[str, torch.Tensor] = {}
    for converted_local_name, candidate_local_names in tqdm.tqdm(
        local_name_map.items(), desc=f"Converting {candidate_block_name}", dynamic_ncols=True
    ):
        if isinstance(candidate_local_names, str):
            candidate_local_names = [candidate_local_names]
        # region original LoRA
        orig_lora = (
            orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_down", None),
            orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_up", None),
        )
        if orig_lora[0] is None or orig_lora[1] is None:
            assert orig_lora[0] is None and orig_lora[1] is None
            orig_lora = None
        else:
            assert orig_lora[0] is not None and orig_lora[1] is not None
            orig_lora = (
                unpack_lowrank_weight(orig_lora[0], down=True),
                unpack_lowrank_weight(orig_lora[1], down=False),
            )
            print(f" - Found {converted_block_name} LoRA of {converted_local_name} (rank: {orig_lora[0].shape[0]})")
        # endregion
        # region extra LoRA
        extra_lora = [
            (
                extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None),
                extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None),
            )
            for candidate_local_name in candidate_local_names
        ]
        if any(lora[0] is not None or lora[1] is not None for lora in extra_lora):
            # merge extra LoRAs into one LoRA
            if len(extra_lora) > 1:
                first_lora = None
                for lora in extra_lora:
                    if lora[0] is not None:
                        assert lora[1] is not None
                        first_lora = lora
                        break
                assert first_lora is not None
                for lora_index in range(len(extra_lora)):
                    if extra_lora[lora_index][0] is None:
                        assert extra_lora[lora_index][1] is None
                        extra_lora[lora_index] = (first_lora[0].clone(), torch.zeros_like(first_lora[1]))
                if all(lora[0].equal(extra_lora[0][0]) for lora in extra_lora):
                    # if all extra LoRAs have the same lora_down, use it
                    extra_lora_down = extra_lora[0][0]
                    extra_lora_up = torch.cat([lora[1] for lora in extra_lora], dim=0)
                else:
                    extra_lora_down = torch.cat([lora[0] for lora in extra_lora], dim=0)
                    extra_lora_up_c = sum(lora[1].shape[0] for lora in extra_lora)
                    extra_lora_up_r = sum(lora[1].shape[1] for lora in extra_lora)
                    assert extra_lora_up_r == extra_lora_down.shape[0]
                    extra_lora_up = torch.zeros((extra_lora_up_c, extra_lora_up_r), dtype=extra_lora_down.dtype)
                    c, r = 0, 0
                    for lora in extra_lora:
                        c_next, r_next = c + lora[1].shape[0], r + lora[1].shape[1]
                        extra_lora_up[c:c_next, r:r_next] = lora[1]
                        c, r = c_next, r_next
            else:
                extra_lora_down, extra_lora_up = extra_lora[0]
            extra_lora: tuple[torch.Tensor, torch.Tensor] = (extra_lora_down, extra_lora_up)
            print(f" - Found {candidate_block_name} LoRA of {candidate_local_names} (rank: {extra_lora[0].shape[0]})")
        else:
            extra_lora = None
        # endregion
        # region merge LoRA
        if orig_lora is None:
            if extra_lora is None:
                lora = None
            else:
                print(" - Using extra LoRA")
                lora = (extra_lora[0].to(default_dtype), extra_lora[1].to(default_dtype))
        elif extra_lora is None:
            print(" - Using original LoRA")
            lora = orig_lora
        else:
            lora = (
                torch.cat([orig_lora[0], extra_lora[0].to(orig_lora[0].dtype)], dim=0),
                torch.cat([orig_lora[1], extra_lora[1].to(orig_lora[1].dtype)], dim=1),
            )
            print(f" - Merging original and extra LoRA (rank: {lora[0].shape[0]})")
        # endregion
        if lora is not None:
            if convert_map[converted_local_name] == "adanorm_single":
                update_state_dict(
                    converted,
                    {"lora_down": lora[0], "lora_up": reorder_adanorm_lora_up(lora[1], splits=3)},
                    prefix=converted_local_name,
                )
            elif convert_map[converted_local_name] == "adanorm_zero":
                update_state_dict(
                    converted,
                    {"lora_down": lora[0], "lora_up": reorder_adanorm_lora_up(lora[1], splits=6)},
                    prefix=converted_local_name,
                )
            elif convert_map[converted_local_name] == "linear":
                update_state_dict(
                    converted,
                    {
                        "lora_down": pack_lowrank_weight(lora[0], down=True),
                        "lora_up": pack_lowrank_weight(lora[1], down=False),
                    },
                    prefix=converted_local_name,
                )
    return converted
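# Added note (not part of the original script): when several diffusers projections map to
# one fused Nunchaku layer (e.g. attn.to_q/to_k/to_v -> qkv_proj), the loop above
# concatenates their lora_A matrices along the rank dimension and assembles lora_B as a
# block-diagonal matrix, so three rank-r LoRAs become a single rank-3r LoRA on the fused
# layer (or stay at rank r when all three share the same lora_A).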
def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict:
        assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict
        assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict
        n1 = orig_state_dict[f"{converted_block_name}.out_proj.qweight"].shape[1] * 2
        n2 = orig_state_dict[f"{converted_block_name}.mlp_fc2.qweight"].shape[1] * 2
        lora_down = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_A.weight"]
        lora_up = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_B.weight"]
        assert lora_down.shape[1] == n1 + n2
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_A.weight"] = lora_down[:, :n1].clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_B.weight"] = lora_up.clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = lora_down[:, n1:].clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = lora_up.clone()
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
    return convert_to_nunchaku_transformer_block_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        converted_block_name=converted_block_name,
        candidate_block_name=candidate_block_name,
        local_name_map={
            "norm.linear": "norm.linear",
            "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
            "norm_q": "attn.norm_q",
            "norm_k": "attn.norm_k",
            "out_proj": "proj_out.linears.0",
            "mlp_fc1": "proj_mlp",
            "mlp_fc2": "proj_out.linears.1",
        },
        convert_map={
            "norm.linear": "adanorm_single",
            "qkv_proj": "linear",
            "out_proj": "linear",
            "mlp_fc1": "linear",
            "mlp_fc2": "linear",
        },
        default_dtype=default_dtype,
    )
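# Added note (not part of the original script): diffusers fuses the attention output
# projection and the MLP down projection of a FLUX single block into one proj_out linear,
# whereas Nunchaku keeps them separate as out_proj and mlp_fc2. The wrapper above therefore
# splits proj_out's lora_A column-wise at n1 (the input width of out_proj, inferred from
# its qweight shape) and reuses the full lora_B for both halves before the generic block
# conversion runs.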
def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    return convert_to_nunchaku_transformer_block_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        converted_block_name=converted_block_name,
        candidate_block_name=candidate_block_name,
        local_name_map={
            "norm1.linear": "norm1.linear",
            "norm1_context.linear": "norm1_context.linear",
            "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
            "qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"],
            "norm_q": "attn.norm_q",
            "norm_k": "attn.norm_k",
            "norm_added_q": "attn.norm_added_q",
            "norm_added_k": "attn.norm_added_k",
            "out_proj": "attn.to_out.0",
            "out_proj_context": "attn.to_add_out",
            "mlp_fc1": "ff.net.0.proj",
            "mlp_fc2": "ff.net.2",
            "mlp_context_fc1": "ff_context.net.0.proj",
            "mlp_context_fc2": "ff_context.net.2",
        },
        convert_map={
            "norm1.linear": "adanorm_zero",
            "norm1_context.linear": "adanorm_zero",
            "qkv_proj": "linear",
            "qkv_proj_context": "linear",
            "out_proj": "linear",
            "out_proj_context": "linear",
            "mlp_fc1": "linear",
            "mlp_fc2": "linear",
            "mlp_context_fc1": "linear",
            "mlp_context_fc2": "linear",
        },
        default_dtype=default_dtype,
    )
def convert_to_nunchaku_flux_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    block_names: set[str] = set()
    for param_name in orig_state_dict.keys():
        if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
            block_names.add(".".join(param_name.split(".")[:2]))
    block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1])))
    print(f"Converting {len(block_names)} transformer blocks...")
    converted: dict[str, torch.Tensor] = {}
    for block_name in block_names:
        if block_name.startswith("transformer_blocks"):
            convert_fn = convert_to_nunchaku_flux_transformer_block_lowrank_dict
        else:
            convert_fn = convert_to_nunchaku_flux_single_transformer_block_lowrank_dict
        update_state_dict(
            converted,
            convert_fn(
                orig_state_dict=orig_state_dict,
                extra_lora_dict=extra_lora_dict,
                converted_block_name=block_name,
                candidate_block_name=block_name,
                default_dtype=default_dtype,
            ),
            prefix=block_name,
        )
    return converted
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant-path", type=str, required=True, help="path to the quantized model safetensor file")
    parser.add_argument("--lora-path", type=str, required=True, help="path to LoRA weights safetensor file")
    parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
    parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16"],
        help="data type of the converted weights",
    )
    args = parser.parse_args()

    if not args.output_root:
        # output to the parent directory of the quantized model safetensor file
        args.output_root = os.path.dirname(args.quant_path)

    if args.lora_name is None:
        assert args.lora_path is not None, "LoRA name or path must be provided"
        lora_name = args.lora_path.rstrip(os.sep).split(os.sep)[-1].replace(".safetensors", "")
        print(f"Lora name not provided, using {lora_name} as the LoRA name")
    else:
        lora_name = args.lora_name
    assert lora_name, "LoRA name must be provided."

    assert args.quant_path.endswith(".safetensors"), "Quantized model must be a safetensor file"
    assert args.lora_path.endswith(".safetensors"), "LoRA weights must be a safetensor file"

    orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
    extra_lora_dict = load_state_dict_in_safetensors(args.lora_path, filter_prefix="transformer.")

    converted = convert_to_nunchaku_flux_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        default_dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16,
    )

    os.makedirs(args.output_root, exist_ok=True)
    safetensors.torch.save_file(converted, os.path.join(args.output_root, f"{lora_name}.safetensors"))
    print(f"Saved LoRA weights to {args.output_root}.")
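Usage note: the script is driven by the argparse flags defined above (--quant-path, --lora-path, --output-root, --lora-name, --dtype), so it can be run directly, e.g. python nunchaku/convert_lora.py --quant-path <quantized-model>.safetensors --lora-path <lora>.safetensors. The conversion can also be called from Python; the sketch below mirrors what the __main__ block does, assuming the module is importable as nunchaku.convert_lora and using placeholder file names.

# Minimal sketch (assumed import path, placeholder file names).
import safetensors.torch
import torch

from nunchaku.convert_lora import (
    convert_to_nunchaku_flux_lowrank_dict,
    load_state_dict_in_safetensors,
)

# Quantized Nunchaku FLUX checkpoint and a diffusers-style LoRA (placeholder paths).
orig_state_dict = load_state_dict_in_safetensors("svdq-int4-flux.1-dev.safetensors")
extra_lora_dict = load_state_dict_in_safetensors("my-flux-lora.safetensors", filter_prefix="transformer.")

# Merge the LoRA into Nunchaku's packed low-rank format and save it.
converted = convert_to_nunchaku_flux_lowrank_dict(
    orig_state_dict=orig_state_dict,
    extra_lora_dict=extra_lora_dict,
    default_dtype=torch.bfloat16,
)
safetensors.torch.save_file(converted, "my-flux-lora-nunchaku.safetensors")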