fengzch-das / nunchaku

Commit e3597f7e
authored Nov 07, 2024 by Muyang Li

remove dev scripts

parent 8431762a
Showing 9 changed files with 0 additions and 1618 deletions
dev-scripts/convert_nf4_flux.py      +0 -191
dev-scripts/dump_flux.py             +0 -552
dev-scripts/dump_flux_lora.py        +0 -311
dev-scripts/eval_perf.sh             +0 -46
dev-scripts/fakequant.py             +0 -135
dev-scripts/merge_non_lora_weight.py +0 -27
dev-scripts/qmodule.py               +0 -177
dev-scripts/run_flux.py              +0 -129
dev-scripts/run_flux_generate.py     +0 -50
dev-scripts/convert_nf4_flux.py deleted 100644 → 0
"""
Utilities adapted from
* https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_bnb_4bit.py
* https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py
"""
import
torch
import
bitsandbytes
as
bnb
from
transformers.quantizers.quantizers_utils
import
get_module_from_name
import
torch.nn
as
nn
from
accelerate
import
init_empty_weights
def
_replace_with_bnb_linear
(
model
,
method
=
"nf4"
,
has_been_replaced
=
False
,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
for
name
,
module
in
model
.
named_children
():
if
isinstance
(
module
,
nn
.
Linear
):
with
init_empty_weights
():
in_features
=
module
.
in_features
out_features
=
module
.
out_features
if
method
==
"llm_int8"
:
model
.
_modules
[
name
]
=
bnb
.
nn
.
Linear8bitLt
(
in_features
,
out_features
,
module
.
bias
is
not
None
,
has_fp16_weights
=
False
,
threshold
=
6.0
,
)
has_been_replaced
=
True
else
:
model
.
_modules
[
name
]
=
bnb
.
nn
.
Linear4bit
(
in_features
,
out_features
,
module
.
bias
is
not
None
,
compute_dtype
=
torch
.
bfloat16
,
compress_statistics
=
False
,
quant_type
=
"nf4"
,
)
has_been_replaced
=
True
# Store the module class in case we need to transpose the weight later
model
.
_modules
[
name
].
source_cls
=
type
(
module
)
# Force requires grad to False to avoid unexpected errors
model
.
_modules
[
name
].
requires_grad_
(
False
)
if
len
(
list
(
module
.
children
()))
>
0
:
_
,
has_been_replaced
=
_replace_with_bnb_linear
(
module
,
has_been_replaced
=
has_been_replaced
,
)
# Remove the last key for recursion
return
model
,
has_been_replaced
def
check_quantized_param
(
model
,
param_name
:
str
,
)
->
bool
:
module
,
tensor_name
=
get_module_from_name
(
model
,
param_name
)
if
isinstance
(
module
.
_parameters
.
get
(
tensor_name
,
None
),
bnb
.
nn
.
Params4bit
):
# Add here check for loaded components' dtypes once serialization is implemented
return
True
elif
isinstance
(
module
,
bnb
.
nn
.
Linear4bit
)
and
tensor_name
==
"bias"
:
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
# but it would wrongly use uninitialized weight there.
return
True
else
:
return
False
def
create_quantized_param
(
model
,
param_value
:
"torch.Tensor"
,
param_name
:
str
,
target_device
:
"torch.device"
,
state_dict
=
None
,
unexpected_keys
=
None
,
pre_quantized
=
False
):
module
,
tensor_name
=
get_module_from_name
(
model
,
param_name
)
if
tensor_name
not
in
module
.
_parameters
:
raise
ValueError
(
f
"
{
module
}
does not have a parameter or a buffer named
{
tensor_name
}
."
)
old_value
=
getattr
(
module
,
tensor_name
)
if
tensor_name
==
"bias"
:
if
param_value
is
None
:
new_value
=
old_value
.
to
(
target_device
)
else
:
new_value
=
param_value
.
to
(
target_device
)
new_value
=
torch
.
nn
.
Parameter
(
new_value
,
requires_grad
=
old_value
.
requires_grad
)
module
.
_parameters
[
tensor_name
]
=
new_value
return
if
not
isinstance
(
module
.
_parameters
[
tensor_name
],
bnb
.
nn
.
Params4bit
):
raise
ValueError
(
"this function only loads `Linear4bit components`"
)
if
(
old_value
.
device
==
torch
.
device
(
"meta"
)
and
target_device
not
in
[
"meta"
,
torch
.
device
(
"meta"
)]
and
param_value
is
None
):
raise
ValueError
(
f
"
{
tensor_name
}
is on the meta device, we need a `value` to put in on
{
target_device
}
."
)
if
pre_quantized
:
if
(
param_name
+
".quant_state.bitsandbytes__fp4"
not
in
state_dict
)
and
(
param_name
+
".quant_state.bitsandbytes__nf4"
not
in
state_dict
):
raise
ValueError
(
f
"Supplied state dict for
{
param_name
}
does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
)
quantized_stats
=
{}
for
k
,
v
in
state_dict
.
items
():
# `startswith` to counter for edge cases where `param_name`
# substring can be present in multiple places in the `state_dict`
if
param_name
+
"."
in
k
and
k
.
startswith
(
param_name
):
quantized_stats
[
k
]
=
v
if
unexpected_keys
is
not
None
and
k
in
unexpected_keys
:
unexpected_keys
.
remove
(
k
)
new_value
=
bnb
.
nn
.
Params4bit
.
from_prequantized
(
data
=
param_value
,
quantized_stats
=
quantized_stats
,
requires_grad
=
False
,
device
=
target_device
,
)
else
:
new_value
=
param_value
.
to
(
"cpu"
)
kwargs
=
old_value
.
__dict__
new_value
=
bnb
.
nn
.
Params4bit
(
new_value
,
requires_grad
=
False
,
**
kwargs
).
to
(
target_device
)
print
(
f
"
{
param_name
}
: new_value.quant_type=
{
new_value
.
quant_type
}
quant_state=
{
new_value
.
quant_state
}
storage=
{
new_value
.
quant_storage
}
blocksize=
{
new_value
.
blocksize
}
"
)
state
=
new_value
.
quant_state
print
(
f
" -- state.code=
{
state
.
code
}
dtype=
{
state
.
dtype
}
blocksize=
{
state
.
blocksize
}
"
)
module
.
_parameters
[
tensor_name
]
=
new_value
# generate.py
# from huggingface_hub import hf_hub_download
# from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
# from accelerate import init_empty_weights
# from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
# from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
# from diffusers import FluxTransformer2DModel, FluxPipeline
# import safetensors.torch
# import gc
# import torch
# dtype = torch.bfloat16
# ckpt_path = hf_hub_download("black-forest-labs/flux.1-dev", filename="flux1-dev.safetensors")
# original_state_dict = safetensors.torch.load_file(ckpt_path)
# converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
# del original_state_dict
# gc.collect()
# with init_empty_weights():
# config = FluxTransformer2DModel.load_config("black-forest-labs/flux.1-dev", subfolder="transformer")
# model = FluxTransformer2DModel.from_config(config).to(dtype)
# _replace_with_bnb_linear(model, "nf4")
# for param_name, param in converted_state_dict.items():
# param = param.to(dtype)
# if not check_quantized_param(model, param_name):
# set_module_tensor_to_device(model, param_name, device=0, value=param)
# else:
# create_quantized_param(model, param, param_name, target_device=0)
# del converted_state_dict
# gc.collect()
# print(compute_module_sizes(model)[""] / 1024 / 1204)
# pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
# pipe.enable_model_cpu_offload()
# prompt = "A mystic cat with a sign that says hello world!"
# image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
# image.save("flux-nf4-dev.png")
# model.push_to_hub("sayakpaul/flux.1-dev-nf4")
\ No newline at end of file
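
Editor's note: the helpers above swap every nn.Linear for a bitsandbytes 4-bit layer before the real weights are loaded (the commented generate.py snippet shows the full FLUX loading flow). A minimal sketch of just the replacement step on a toy module, assuming bitsandbytes is installed and this file is importable as convert_nf4_flux; the toy model and shapes are illustrative and not part of the commit:

    import torch.nn as nn
    import bitsandbytes as bnb
    from convert_nf4_flux import _replace_with_bnb_linear

    # After replacement, the Linear layers become bnb.nn.Linear4bit shells whose weights
    # live on the meta device until create_quantized_param() fills them in.
    toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    toy, replaced = _replace_with_bnb_linear(toy, method="nf4")
    assert replaced
    assert isinstance(toy[0], bnb.nn.Linear4bit)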
dev-scripts/dump_flux.py deleted 100644 → 0
This diff is collapsed.
dev-scripts/dump_flux_lora.py deleted 100644 → 0
import torch
import safetensors
import torch.nn.functional as F
from dump_flux import DeepCompressorModel, TensorDict, pack_wscales, pack_lora, merge_dict, unsmooth
from typing import Optional

Lora = tuple[torch.Tensor, torch.Tensor]


def load_svdq_lora(path: str, lora_path: str) -> DeepCompressorModel:
    result = DeepCompressorModel(
        model=torch.load(f"{path}/model.pt", map_location="cpu"),
        smooth=torch.load(f"{path}/smooth.pt", map_location="cpu"),
        branch=torch.load(f"{path}/branch.pt", map_location="cpu"),
        lora={}
    )
    with safetensors.safe_open(lora_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            prefix = "transformer."
            if k.startswith(prefix):
                result.lora[k.removeprefix(prefix)] = f.get_tensor(k)

    dtype = next(iter(result.branch.values()))["a.weight"].dtype
    for k, v in result.lora.items():
        if v.dtype != dtype:
            print(f"Convert lora weight {k} from {v.dtype} to {dtype}")
            result.lora[k] = v.to(dtype)

    # for k, v in result.lora.items():
    #     v.fill_(0)

    return result


# q/k/v [3072, ...] -> qkv [3072 * 3, ...]
def extend_qkv(input: torch.Tensor, id: int) -> torch.Tensor:
    oc, ic = input.shape
    tmp = torch.zeros([oc * 3, ic], dtype=input.dtype, device=input.device)
    tmp[id * oc:(id + 1) * oc, ...] = input
    return tmp


def merge_lora(inputs: list[Lora]) -> Optional[Lora]:
    if len(inputs) == 0:
        return None
    lora_downs = [x[0] for x in inputs]
    lora_ups = [x[1] for x in inputs]
    lora_down = torch.cat(lora_downs, dim=0)
    lora_up = torch.cat(lora_ups, dim=1)
    return (lora_down, lora_up)


def merge_lora_qkv(inputs: list[Lora]) -> list[Lora]:
    if len(inputs) == 0:
        return []
    for x in inputs:
        if not x[0].equal(inputs[0][0]):
            return inputs
    lora_down = inputs[0][0]
    lora_ups = [x[1] for x in inputs]
    lora_up = torch.sum(torch.stack(lora_ups), dim=0).to(lora_down.dtype)
    return [(lora_down, lora_up)]


def dump_lora(lora_down: Optional[torch.Tensor], lora_up: Optional[torch.Tensor]) -> TensorDict:
    if lora_down is None:
        return {}

    rank, ic = lora_down.shape
    oc = lora_up.shape[0]
    assert lora_up.shape == (oc, rank)

    if rank % 16 != 0:
        rank_pad = (rank + 16 - 1) // 16 * 16
        tmp_down = torch.zeros([rank_pad, ic], dtype=lora_down.dtype, device=lora_down.device)
        tmp_up = torch.zeros([oc, rank_pad], dtype=lora_down.dtype, device=lora_down.device)
        tmp_down[:rank, ...] = lora_down
        tmp_up[..., :rank] = lora_up
        lora_down = tmp_down
        lora_up = tmp_up
        print(f"Pad lora rank from {rank} to {rank_pad}")

    lora_down = pack_lora(lora_down.transpose(0, 1), is_lora_down=True)
    lora_up = pack_lora(lora_up, is_lora_down=False)

    tensors = {}
    tensors["lora_down"] = lora_down
    tensors["lora_up"] = lora_up

    return tensors


def get_original_lora(qmodel: DeepCompressorModel, key_branch: str, key_smooth: Optional[str]) -> Lora:
    dtype = qmodel.branch[key_branch]["a.weight"].dtype
    smooth = qmodel.smooth[key_smooth].to(dtype).float() if key_smooth else None
    return (
        unsmooth(qmodel.branch[key_branch]["a.weight"], smooth),
        qmodel.branch[key_branch]["b.weight"]
    )


def dump_linear_lora(
    qmodel: DeepCompressorModel,
    key_lora: str,
    key_branch: str,
    key_smooth: str,
    shift_bias: bool = False,
    key_bias: Optional[str] = None,
    range_ic: slice = slice(None, None, None)
) -> TensorDict:
    lora_original = get_original_lora(qmodel, key_branch, key_smooth)

    if f"{key_lora}.lora_A.weight" in qmodel.lora:
        # lora_down = qmodel.lora[f"{key}.lora_A.weight"][..., range_ic]
        # lora_up = qmodel.lora[f"{key}.lora_B.weight"]
        lora_new = (
            qmodel.lora[f"{key_lora}.lora_A.weight"][..., range_ic],
            qmodel.lora[f"{key_lora}.lora_B.weight"]
        )
        lora_down, lora_up = merge_lora([lora_original, lora_new])

        rank, ic = lora_down.shape
        oc = lora_up.shape[0]
        assert lora_up.shape == (oc, rank)

        print(f"linear at {key_lora} has rank {rank}")

        tensors = dump_lora(lora_down, lora_up)

        if shift_bias and False:
            # no longer need shift bias
            if key_bias is None:
                key_bias = f"{key_branch}.bias"
            if key_bias in qmodel.model:
                bias = qmodel.model[key_bias]
                print(f"linear at {key_lora} apply shift_bias from original bias at {key_bias}")
            else:
                bias = torch.zeros([oc], dtype=lora_up.dtype, device=lora_up.device)
                print(f"linear at {key_lora} apply shift_bias from empty original bias")

            shift = torch.empty([ic], dtype=lora_down.dtype, device=lora_down.device)
            shift = shift.fill_(0.171875)
            delta = F.linear(F.linear(shift, lora_new[0]), lora_new[1])
            print(f"shift_bias delta = {delta}")
            bias -= delta

            tensors["bias"] = pack_wscales(bias[..., None])[0]

        return tensors
    else:
        print(f"linear at {key_lora} use original lora")
        return dump_lora(*lora_original)


def dump_qkv_proj_svdq_lora(
    qmodel: DeepCompressorModel,
    key_qkv: tuple[str, str, str],
    key_smooth: str,
    key_smooth_out: str
) -> TensorDict:
    dtype = qmodel.branch[key_smooth]["a.weight"].dtype
    smooth_out = qmodel.smooth[key_smooth_out].to(dtype).float()

    lora_original = get_original_lora(qmodel, key_smooth, key_smooth)

    loras = []
    for i in range(3):
        key = key_qkv[i]
        if f"{key}.lora_A.weight" in qmodel.lora:
            lora_down = qmodel.lora[f"{key}.lora_A.weight"]
            lora_up = qmodel.lora[f"{key}.lora_B.weight"]
            if i == 2:
                lora_up = (lora_up / smooth_out[..., None]).to(lora_up.dtype)
            loras.append((lora_down, extend_qkv(lora_up, i)))

    # print(loras)

    lora_down, lora_up = merge_lora([lora_original, *merge_lora_qkv(loras)])

    print(f"qkv_proj at {key_smooth} has rank {lora_down.shape[0]}")

    return dump_lora(lora_down, lora_up)


def dump_transformer_svdq_lora(qmodel: DeepCompressorModel, layer_id: int) -> TensorDict:
    tensors = {}

    def reorder_adanorm_linear(weight: torch.Tensor) -> torch.Tensor:
        oc, ic = weight.shape
        assert oc % 6 == 0
        return weight.reshape(6, oc // 6, ic).transpose(0, 1).reshape(oc, ic).contiguous()

    def linear(key: str, **kwargs):
        key_lora = key
        key_branch = kwargs.pop("key_branch", key_lora)
        key_smooth = kwargs.pop("key_smooth", key_branch)
        return dump_linear_lora(qmodel, key_lora, key_branch, key_smooth, **kwargs)

    prefix = f"transformer_blocks.{layer_id}"

    if f"{prefix}.norm1.linear.lora_A.weight" in qmodel.lora:
        lora_down = qmodel.lora[f"{prefix}.norm1.linear.lora_A.weight"]
        lora_up = qmodel.lora[f"{prefix}.norm1.linear.lora_B.weight"]
        tensors[f"norm1.linear.lora_down"] = lora_down
        tensors[f"norm1.linear.lora_up"] = reorder_adanorm_linear(lora_up)
    if f"{prefix}.norm1_context.linear.lora_A.weight" in qmodel.lora:
        lora_down = qmodel.lora[f"{prefix}.norm1_context.linear.lora_A.weight"]
        lora_up = qmodel.lora[f"{prefix}.norm1_context.linear.lora_B.weight"]
        tensors[f"norm1_context.linear.lora_down"] = lora_down
        tensors[f"norm1_context.linear.lora_up"] = reorder_adanorm_linear(lora_up)

    merge_dict(tensors, dump_qkv_proj_svdq_lora(
        qmodel,
        (f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
        f"{prefix}.attn.to_q",
        f"{prefix}.attn.to_out.0"), "qkv_proj.")
    merge_dict(tensors, dump_qkv_proj_svdq_lora(
        qmodel,
        (f"{prefix}.attn.add_q_proj", f"{prefix}.attn.add_k_proj", f"{prefix}.attn.add_v_proj"),
        f"{prefix}.attn.add_k_proj",
        f"{prefix}.attn.to_out.0"), "qkv_proj_context.")
    merge_dict(tensors, linear(f"{prefix}.attn.to_out.0", key_smooth=None), "out_proj.")
    merge_dict(tensors, linear(f"{prefix}.attn.to_add_out", key_smooth=None), "out_proj_context.")
    merge_dict(tensors, linear(f"{prefix}.ff.net.0.proj"), "mlp_fc1.")
    merge_dict(tensors, linear(f"{prefix}.ff.net.2", key_branch=f"{prefix}.ff.net.2.linear", shift_bias=True), "mlp_fc2.")
    merge_dict(tensors, linear(f"{prefix}.ff_context.net.0.proj"), "mlp_context_fc1.")
    merge_dict(tensors, linear(f"{prefix}.ff_context.net.2", key_branch=f"{prefix}.ff_context.net.2.linear", shift_bias=True), "mlp_context_fc2.")

    return tensors


def dump_single_transformer_svdq_lora(qmodel: DeepCompressorModel, layer_id: int) -> TensorDict:
    tensors = {}

    def reorder_adanorm_linear(weight: torch.Tensor) -> torch.Tensor:
        oc, ic = weight.shape
        assert oc % 3 == 0
        return weight.reshape(3, oc // 3, ic).transpose(0, 1).reshape(oc, ic).contiguous()

    def linear(key: str, **kwargs):
        key_lora = key
        key_branch = kwargs.pop("key_branch", key_lora)
        key_smooth = kwargs.pop("key_smooth", key_branch)
        return dump_linear_lora(qmodel, key_lora, key_branch, key_smooth, **kwargs)

    prefix = f"single_transformer_blocks.{layer_id}"

    if f"{prefix}.norm.linear.lora_A.weight" in qmodel.lora:
        lora_down = qmodel.lora[f"{prefix}.norm.linear.lora_A.weight"]
        lora_up = qmodel.lora[f"{prefix}.norm.linear.lora_B.weight"]
        tensors[f"norm.linear.lora_down"] = lora_down
        tensors[f"norm.linear.lora_up"] = reorder_adanorm_linear(lora_up)

    merge_dict(tensors, dump_qkv_proj_svdq_lora(
        qmodel,
        (f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
        f"{prefix}.attn.to_q",
        f"{prefix}.proj_out.linears.0"), "qkv_proj.")
    merge_dict(tensors, linear(f"{prefix}.proj_mlp", key_smooth=f"{prefix}.attn.to_q"), "mlp_fc1.")

    # TODO
    out_dim = 3072
    merge_dict(tensors, linear(
        f"{prefix}.proj_out",
        key_branch=f"{prefix}.proj_out.linears.0",
        key_smooth=None,
        range_ic=slice(0, out_dim)), "out_proj.")
    merge_dict(tensors, linear(
        f"{prefix}.proj_out",
        key_branch=f"{prefix}.proj_out.linears.1.linear",
        shift_bias=True,
        range_ic=slice(out_dim, None)), "mlp_fc2.")

    return tensors


@torch.inference_mode()
def dump_flux_svdq_lora(qmodel: DeepCompressorModel, **kwargs) -> TensorDict:
    tensors = {}
    for i in range(19):
        merge_dict(tensors, dump_transformer_svdq_lora(qmodel, i, **kwargs), f"transformer_blocks.{i}.")
    for i in range(38):
        merge_dict(tensors, dump_single_transformer_svdq_lora(qmodel, i, **kwargs), f"single_transformer_blocks.{i}.")
    return tensors


if __name__ == "__main__":
    lora_name = "realism"

    if lora_name == "sketch":
        qmodel = load_svdq_lora("model-dev", "../third_party/FLUX.1-dev-LoRA-Collections/sketch.safetensors")
    elif lora_name == "realism":
        qmodel = load_svdq_lora("model-dev", "../third_party/FLUX.1-dev-LoRA-Collections/realism.safetensors")
    elif lora_name == "anime":
        qmodel = load_svdq_lora("model-dev", "../third_party/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors")
    elif lora_name == "ghibsky":
        qmodel = load_svdq_lora("model-dev", "../third_party/flux-ghibsky-illustration/lora.safetensors")
    elif lora_name == "yarn":
        qmodel = load_svdq_lora("model-dev", "../third_party/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors")
    elif lora_name == "sketch2image":
        qmodel = load_svdq_lora("model-dev", "sketch2image.safetensors")
    else:
        raise NotImplementedError

    tensors = dump_flux_svdq_lora(qmodel)
    for k, v in tensors.items():
        assert not v.isnan().any()
        assert not v.isinf().any()
    safetensors.torch.save_file(tensors, f"/tmp/flux-lora-{lora_name}-bf16.safetensors")
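
Editor's note: merge_lora above relies on the identity that concatenating two LoRA branches along the rank dimension equals summing their low-rank products, which is how the new LoRA is fused with the original SVD branch. A small self-contained check of that identity (shapes chosen arbitrarily for illustration, not from the commit):

    import torch

    ic, oc, r1, r2 = 8, 6, 2, 3
    down1, up1 = torch.randn(r1, ic), torch.randn(oc, r1)   # first branch (e.g. the SVD branch)
    down2, up2 = torch.randn(r2, ic), torch.randn(oc, r2)   # second branch (e.g. the new LoRA)

    merged_down = torch.cat([down1, down2], dim=0)   # merge_lora: concatenate lora_downs on dim 0
    merged_up = torch.cat([up1, up2], dim=1)         # merge_lora: concatenate lora_ups on dim 1

    # The merged rank-(r1+r2) product equals the sum of the two separate products.
    assert torch.allclose(up1 @ down1 + up2 @ down2, merged_up @ merged_down, atol=1e-5)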
dev-scripts/eval_perf.sh deleted 100755 → 0
#!/bin/bash

rundir=$(date +"run-$(hostname -s)-%Y%m%d-%H%M%S")
mkdir -p $rundir

function run() {
    echo config=$config
    echo args=$@
    python3 run_flux.py --steps 4 "$@" > >(tee $rundir/stdout-s4-$config.log) 2> >(tee $rundir/stderr-s4-$config.log)
    python3 run_flux.py --steps 25 "$@" > >(tee $rundir/stdout-s25-$config.log) 2> >(tee $rundir/stderr-s25-$config.log)
    python3 run_flux.py --steps 50 "$@" > >(tee $rundir/stdout-s50-$config.log) 2> >(tee $rundir/stderr-s50-$config.log)
    if [ $? -eq 0 ]; then
        nsys profile --cuda-memory-usage true -o $rundir/report-$config.nsys-rep python3 run_flux.py --steps 4 "$@"
    fi
}

config=bf16-compile
run --config bf16 --compile

config=bf16-t5-compile
run --config bf16-t5 --compile

config=int8dq-compile
run --config bf16 --torchao --compile

config=int8dq-t5-compile
run --config bf16-t5 --torchao --compile

config=int8dq-nocompile
run --config bf16 --torchao

config=int8dq-t5-nocompile
run --config bf16-t5 --torchao

for cfg in svdq svdq-t5 w4a4 w4a4-t5 bf16 bf16-t5 nf4 nf4-t5; do
    config=$cfg
    run --config $cfg
    config=$cfg-ol1
    run --config $cfg --offload 1
    config=$cfg-ol2
    run --config $cfg --offload 2
done
\ No newline at end of file
dev-scripts/fakequant.py deleted 100644 → 0
import torch
from torch.nn import functional as F

from dump_flux import group_scale


def compare(ref: torch.Tensor, v: torch.Tensor, refname: str, vname: str, list_diff: bool = False):
    print(f"== COMPARE v={vname} vs ref={refname}")
    diff = v - ref
    print(f" - diff = {diff}")
    if list_diff:
        print(f" - diffs at {diff.nonzero()}")
    mse = diff.square().mean()
    print(f" - mse = {mse}")
    nmse = mse / ref.square().mean()
    print(f" - nmse = {nmse}")
    print(f" - mean(v/ref)={v.mean()} / {ref.mean()}")
    print(f" - var(v/ref)={v.var()} / {ref.var()}")
    print(f"== ")
    print()


def print_debug_results(debug_results: dict[str, torch.Tensor], is_ref: bool = False):
    ref = 'REF' if is_ref else ''
    for k, v in debug_results.items():
        has_nan = v.isnan().any()
        has_inf = v.isinf().any()
        if v.dtype.is_floating_point:
            print(f"{ref} {k}: {v.shape} ({v.dtype}) has_nan={has_nan} has_inf={has_inf} max={v.max()} min={v.min()} mean={v.mean()} var={v.var()}")
        else:
            print(f"{ref} {k}: {v.shape} ({v.dtype})")
        if has_nan:
            cnt = v.isnan().count_nonzero()
            print(f"{ref} -- {cnt} ({cnt / v.numel() * 100}%) nans at {v.isnan().nonzero()[0:10]}")
        if has_inf:
            cnt = v.isinf().count_nonzero()
            print(f"{ref} -- {cnt} ({cnt / v.numel() * 100}%) infs at {v.isinf().nonzero()[0:10]}")
        print(f"{ref} -- {v}")
    print()


def fakequant(
    act: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    group_size: int = 64,
    force_cuda: bool = False,
):
    oc, ic = weight.shape
    batch_size = act.shape[0]
    assert act.shape[1] == ic

    # [oc, ic // group_size]
    wscales = group_scale(weight, num_bits=4, group_size=group_size)
    qweight = weight.reshape(oc, ic // group_size, group_size).to(dtype=torch.float32) / wscales[..., None]
    qweight = qweight.round().clamp(-8, 7)
    qweight_i = qweight.int()
    qweight = qweight * wscales[..., None]
    qweight = qweight.to(weight.dtype)
    qweight = qweight.reshape(oc, ic)
    # print(f"qweight = {qweight}")
    print_debug_results({"qweight": qweight})

    # [batch_size, ic // group_size]
    ascales = group_scale(act, num_bits=4, group_size=group_size).to(dtype=weight.dtype)
    qact = act.reshape(batch_size, ic // group_size, group_size).to(dtype=torch.float32) / ascales[..., None]
    qact = qact.round().clamp(-8, 7)
    qact_i = qact.int()
    print_debug_results({"qact_i": qact_i})
    qact = qact * ascales[..., None]
    qact = qact.to(act.dtype)
    qact = qact.reshape(batch_size, ic)
    # print(f"qact = {qact}")
    print_debug_results({"qact": qact})

    outref_q = F.linear(qact.to(qweight.dtype), qweight, bias)
    # print(f"outref_q = {outref_q}")
    print_debug_results({"outref_q": outref_q})

    ###

    if force_cuda:
        qweight_i = qweight_i.to("cuda")
        qact_i = qact_i.to("cuda")
        wscales = wscales.to("cuda")
        ascales = ascales.to("cuda")
        bias = bias.to("cuda")

    qweight = qweight_i
    qact = qact_i

    qweight = qweight.reshape(oc, ic // group_size, group_size).transpose(0, 1).transpose(1, 2)
    qact = qact.reshape(batch_size, ic // group_size, group_size).transpose(0, 1)

    # [ic // group_size, batch_size, oc]
    psum = torch.bmm(qact.float(), qweight.float())
    print(f"psum_i ({psum.shape}) = {psum}")
    # print(psum[:, 0, 23])

    # print(f"ascales = {ascales}")
    print_debug_results({"ascales": ascales})
    print(f"ascales[0:16] = {ascales[0:16, 0]}")

    ws1 = wscales.transpose(0, 1).reshape(ic // group_size, 1, oc).repeat(1, batch_size, 1)
    as1 = ascales.transpose(0, 1).reshape(ic // group_size, batch_size, 1).repeat(1, 1, oc)
    scales = ws1 * as1
    print(f"scales = {scales}")
    # print(scales[:, 0, 23])

    psum = psum.to(dtype=act.dtype).float()
    psum = psum * scales
    print(f"psum ({psum.shape}) = {psum}")
    # print(psum[:, 0, 23])

    # outref_q2 = psum.sum(dim=0)  # .to(layer.weight.dtype)
    outref_q2 = torch.zeros_like(psum[0])
    for i in range(psum.shape[0]):
        outref_q2 = (outref_q2 + psum[i]).to(act.dtype)
    outref_q2 += bias[None, ...]
    # print(f"outref_q2 = {outref_q2}")
    print_debug_results({"outref_q2": outref_q2})
    # print(outref_q2[0, 23])

    if force_cuda:
        outref_q2 = outref_q2.to(act.device)

    return outref_q, outref_q2
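
Editor's note: group_scale is imported from dump_flux.py, whose diff is collapsed above, so its exact definition is not visible here. The sketch below shows the symmetric per-group 4-bit fake quantization that fakequant() applies to both weights and activations; the group_scale stand-in (per-group absmax divided by 7 for signed 4-bit) is an assumption, not the original implementation:

    import torch

    def group_scale(x: torch.Tensor, num_bits: int = 4, group_size: int = 64) -> torch.Tensor:
        # Assumed stand-in: one scale per group of `group_size` elements along the last dim.
        qmax = 2 ** (num_bits - 1) - 1            # 7 for signed 4-bit
        groups = x.reshape(*x.shape[:-1], -1, group_size)
        return groups.abs().amax(dim=-1) / qmax

    x = torch.randn(2, 128)
    scales = group_scale(x)                                            # [2, 128 // 64]
    q = (x.reshape(2, -1, 64) / scales[..., None]).round().clamp(-8, 7)
    x_hat = (q * scales[..., None]).reshape(2, 128)                    # dequantized ("fake quant") values
    print((x - x_hat).abs().max())                                     # error stays around half a step per group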
dev-scripts/merge_non_lora_weight.py deleted 100644 → 0
from safetensors.torch import safe_open, save_file


def main():
    input_path1 = "app/i2i/pretrained/converted/sketch.safetensors"
    input_path2 = "app/i2i/pretrained/original/flux-lora-sketch2image-bf16.safetensors"

    sd1 = {}
    with safe_open(input_path1, framework="pt") as f:
        for k in f.keys():
            sd1[k] = f.get_tensor(k)

    sd2 = {}
    with safe_open(input_path2, framework="pt") as f:
        for k in f.keys():
            sd2[k] = f.get_tensor(k)

    for k in sd1.keys():
        if "lora" not in k:
            print(k)
            sd2[k.replace("transformer.", "")] = sd1[k]

    save_file(sd2, "svdq-flux.1-pix2pix-turbo-sketch2image.safetensors")


if __name__ == "__main__":
    main()
dev-scripts/qmodule.py deleted 100644 → 0
import math

import torch
import torch.nn as nn


def make_divisible(c, divisor):
    return (c + divisor - 1) // divisor


def calculate_zeros_width(in_features, group_size=128, pack_num=8):
    if group_size >= 128:
        size_multiplier = 1
    elif group_size == 64:
        size_multiplier = 2
    elif group_size == 32:
        size_multiplier = 4
    else:
        raise NotImplementedError

    base_width = make_divisible(in_features // group_size, pack_num)
    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
    return base_width


def pack_intweight(unpacked_qweight, interleave, kstride):
    # unpacked_qweight: [N, K]
    N = unpacked_qweight.shape[0]
    K = unpacked_qweight.shape[1]

    Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
    # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)

    # reorder each 8 weights for fast dequantization
    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
    Packed_Kernel = Packed_Kernel.reshape(N, K)

    # interleaving every four rows
    Packed_Kernel = Packed_Kernel.reshape(N // interleave, interleave, K // kstride, kstride)
    # N // 4, K // 64, 4, 64
    Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
    Packed_Kernel = Packed_Kernel.reshape(N // interleave, K // kstride, kstride, interleave)
    # Packing -> (N // 4, K // 64, 64)
    Packed_Kernel = (
        Packed_Kernel[..., 0]
        | (Packed_Kernel[..., 1] << 4)
        | (Packed_Kernel[..., 2] << 8)
        | (Packed_Kernel[..., 3] << 12)
    )
    # reshape to (N // 4, K), FP16 format
    Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
    qweight = (
        torch.tensor(Packed_Kernel.astype("int16"))
        .to(unpacked_qweight.device)
        .contiguous()
    )
    return qweight


def pseudo_quantize_tensor(
    w,
    n_bit=8,
    zero_point=True,
    q_group_size=-1,
) -> tuple[torch.Tensor, torch.Tensor]:
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:
        # we actually never used this
        # assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = -max_int
        scales = max_val / max_int
        zeros = torch.full_like(scales, fill_value=-min_int)

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    return scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)


def dump_linear_awq(
    weight: torch.Tensor, bias: torch.Tensor, w_bit: int, group_size: int, zero_point: bool = True
) -> dict[str, torch.Tensor]:
    scales, zeros = pseudo_quantize_tensor(weight, w_bit, zero_point, group_size)

    print(scales.shape)
    print(zeros.shape)

    tensors = {}

    dtype = weight.dtype
    oc, ic = weight.shape

    # need scales and zeros info for real quantization
    assert scales is not None and zeros is not None
    scale_zeros = zeros * scales

    pack_num = 32 // w_bit
    qscales = torch.zeros(
        (
            scales.shape[0],
            calculate_zeros_width(ic, group_size) * pack_num,
        ),
        dtype=dtype,
        device=scales.device,
    )
    qscales[:, : scales.shape[1]] = scales
    # awq_linear.scales = scales.clone().half()
    tensors["wscales"] = qscales.transpose(1, 0).contiguous()

    if bias is not None:
        tensors["bias"] = bias.clone()

    if False:
        intweight = []
        for idx in range(ic):
            intweight.append(
                torch.round(
                    (weight.data[:, idx] + scale_zeros[:, idx // group_size]) / qscales[:, idx // group_size]
                ).clamp(0, 15 if zero_point else 14).to(torch.int)[:, None]
            )
        print(intweight[0].shape)
        intweight = torch.cat(intweight, dim=1)
        print(intweight.shape)
        intweight_ref = intweight
        # intweight = intweight.t().contiguous()

    assert ic % group_size == 0
    intweight = weight.reshape(oc, ic // group_size, group_size)
    # print(f"{intweight.shape} {scale_zeros[..., None].shape} {qscales[..., None].shape}")
    intweight = (intweight + scale_zeros[..., None]) / qscales[..., None]
    intweight = intweight.round_()
    intweight = intweight.clamp_(0, 15 if zero_point else 14)
    intweight = intweight.to(dtype=torch.int32)
    intweight = intweight.reshape(oc, ic)

    if False:
        print(intweight_ref - intweight)
        assert not (intweight_ref - intweight != 0).any()

    tensors["qweight"] = pack_intweight(intweight.contiguous(), interleave=4, kstride=64)

    zeros = zeros.to(dtype=torch.int32)
    scaled_zeros = torch.zeros_like(qscales)
    # scaled_zeros[:, :scales.shape[1]] = -(qscales[:, :scales.shape[1]] * (zeros.to(torch.float32) - 8.0)).to(torch.float16)
    scaled_zeros[:, : scales.shape[1]] = -(qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))).to(dtype)
    tensors["wzeros"] = scaled_zeros.transpose(1, 0).contiguous()

    return tensors
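
Editor's note: a short round-trip sketch of the asymmetric group quantization that pseudo_quantize_tensor and dump_linear_awq implement, where q = round(w / scales) + zeros is clamped to [0, 2**n_bit - 1] and the dequantized weight is (q - zeros) * scales. Shapes are arbitrary and the import assumes this deleted file is on the path as qmodule; illustrative only, not part of the commit:

    import torch
    from qmodule import pseudo_quantize_tensor

    w = torch.randn(4, 128)
    group_size = 64
    scales, zeros = pseudo_quantize_tensor(w, n_bit=4, zero_point=True, q_group_size=group_size)

    # scales/zeros come back as [oc, ic // group_size]; broadcast them per group to quantize.
    s = scales.reshape(4, -1, 1)
    z = zeros.reshape(4, -1, 1)
    q = (torch.round(w.reshape(4, -1, group_size) / s) + z).clamp(0, 15)
    w_hat = ((q - z) * s).reshape(4, 128)
    print((w - w_hat).abs().max())   # reconstruction error is on the order of one quantization step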
dev-scripts/run_flux.py deleted 100644 → 0
import time
import argparse

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

import nunchaku.pipelines.flux


def get_pipe(config: str, dev: bool) -> FluxPipeline:
    version = "dev" if dev else "schnell"
    dtype = torch.bfloat16
    qencoder_path = "/NFS/raid0/user/zhangzk/models/flux-t5-tinychat-v2.pt"
    if config.startswith("svdq"):
        pipe = nunchaku.pipelines.flux.from_pretrained(
            f"black-forest-labs/FLUX.1-{version}",
            torch_dtype=dtype,
            qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if dev else ''}-svdq-19-38-divsmooth-shift-ada-bf16.safetensors",
            qencoder_path=qencoder_path if config == "svdq-t5" else None
        )
    elif config.startswith("w4a4"):
        pipe = nunchaku.pipelines.flux.from_pretrained(
            f"black-forest-labs/FLUX.1-{version}",
            torch_dtype=dtype,
            qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if dev else ''}-divsmooth-shift-ada-bf16.safetensors",
            qencoder_path=qencoder_path if config == "w4a4-t5" else None
        )
    elif config.startswith("bf16"):
        pipe = FluxPipeline.from_pretrained(
            f"black-forest-labs/FLUX.1-{version}",
            torch_dtype=dtype,
        )
        if config == "bf16-t5":
            nunchaku.pipelines.flux.quantize_t5(pipe, qencoder_path)
    elif config.startswith("nf4"):
        from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
        from accelerate import init_empty_weights
        from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param

        converted_state_dict = torch.load(f"/NFS/raid0/user/zhangzk/models/flux1-{version}-nf4.pt")
        with init_empty_weights():
            config = FluxTransformer2DModel.load_config(f"black-forest-labs/flux.1-{version}", subfolder="transformer")
            model = FluxTransformer2DModel.from_config(config).to(dtype)
        _replace_with_bnb_linear(model, "nf4")
        for param_name, param in converted_state_dict.items():
            param = param.to(dtype)
            print(f"{param_name}: {param.shape} check_quantized_param={check_quantized_param(model, param_name)}")
            if not check_quantized_param(model, param_name):
                set_module_tensor_to_device(model, param_name, device=0, value=param)
            else:
                create_quantized_param(model, param, param_name, target_device=0)
        pipe = FluxPipeline.from_pretrained(f"black-forest-labs/flux.1-{version}", transformer=model, torch_dtype=dtype)
        if config == "nf4-t5":
            nunchaku.pipelines.flux.quantize_t5(pipe, qencoder_path)
    else:
        raise NotImplementedError
    return pipe


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="svdq",
                        choices=["svdq", "svdq-t5", "w4a4", "w4a4-t5", "bf16", "bf16-t5", "nf4", "nf4-t5"])
    parser.add_argument("--offload", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--dev", action="store_true")
    parser.add_argument("--torchao", action="store_true")
    parser.add_argument("--compile", action="store_true")
    args = parser.parse_args()

    print(f"Use config {args.config}")
    if args.offload > 0:
        print(f"Use offloading level {args.offload}")

    pipe = get_pipe(args.config, args.dev)
    print(pipe)

    if args.torchao:
        from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
        # pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
        quantize_(pipe.transformer, int8_dynamic_activation_int8_weight())

    if args.offload == 2:
        pipe.enable_sequential_cpu_offload()
    elif args.offload == 1:
        pipe.enable_model_cpu_offload()
    elif args.offload == 0:
        pipe.to("cuda:0")
    else:
        raise NotImplementedError

    # assert isinstance(pipe, FluxPipeline)

    if args.compile:
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)

    prompt = "A cat holding a sign that says hello world"
    print(f"Using prompt '{prompt}'")
    print(f"Run {args.steps} steps")

    latencies = []
    for i in range(5):
        start_time = time.time()
        out = pipe(
            prompt=prompt,
            guidance_scale=0,
            num_inference_steps=args.steps,
            generator=torch.Generator(device="cpu").manual_seed(233),
        ).images[0]
        end_time = time.time()
        latencies.append(end_time - start_time)
        torch.cuda.empty_cache()

    latencies = sorted(latencies)
    latencies = latencies[1:-1]

    out.save("output.png")
    print(f"Elapsed: {sum(latencies) / len(latencies)} seconds")
    print(f"Torch max_memory_allocated={torch.cuda.max_memory_allocated()}")
\ No newline at end of file
dev-scripts/run_flux_generate.py deleted 100644 → 0
import time

import torch
import diffusers
from diffusers import FluxPipeline

import nunchaku.pipelines.flux

if __name__ == "__main__":
    QUANT = False
    SEED = 1
    DEV = True
    LORA_NAME = "anime"

    pipe = nunchaku.pipelines.flux.from_pretrained(
        f"black-forest-labs/FLUX.1-{'dev' if DEV else 'schnell'}",
        torch_dtype=torch.bfloat16,
        qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if DEV else ''}-svdq-19-38-divsmooth-shift-ada-bf16.safetensors",
        qencoder_path="/NFS/raid0/user/zhangzk/models/flux-t5-tinychat-v2.pt" if QUANT else None,
    )

    if LORA_NAME:
        pipe.transformer.nunchaku_update_params(f"/tmp/flux-lora-{LORA_NAME}-bf16.safetensors")
        pipe.transformer.nunchaku_set_lora_scale(0.4)

    print("Moving model to CUDA")
    pipe.to("cuda:0")
    print("Done")

    # prompt = "A cat holding a sign that says hello world"
    # prompt = "A cyberpunk cat holding a huge neon sign that says \"SVDQuant is lite and fast\""
    prompt = "girl, neck tuft, white hair ,sheep horns, blue eyes, nm22 style"
    # prompt = "GHIBSKY style, the most beautiful place in the universe"
    # prompt = "the joker, yarn art style"
    print(f"Using prompt '{prompt}'")

    latencies = []

    diffusers.training_utils.set_seed(SEED)

    start_time = time.time()
    out = pipe(
        prompt=prompt,
        guidance_scale=3.5 if DEV else 0,
        num_inference_steps=50 if DEV else 4,
        generator=torch.Generator(device="cpu").manual_seed(SEED),
    ).images[0]
    end_time = time.time()
    latencies.append(end_time - start_time)

    out.save(f"output{'-dev' if DEV else ''}-{SEED}-{'quant' if QUANT else 'noquant'}.png")
    print(f"Elapsed: {sum(latencies) / len(latencies)} seconds")