Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
37a27712
Unverified
Commit
37a27712
authored
May 01, 2025
by
Muyang Li
Committed by
GitHub
May 01, 2025
Browse files
Merge pull request #340 from mit-han-lab/dev
feat: support PuLID, Double FBCache and TeaCache; better linter
parents
c1d6fc84
760ab022
Changes
192
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2209 additions
and
481 deletions
+2209
-481
nunchaku/caching/utils.py
nunchaku/caching/utils.py
+400
-93
nunchaku/csrc/flux.h
nunchaku/csrc/flux.h
+82
-52
nunchaku/csrc/gemm.h
nunchaku/csrc/gemm.h
+13
-12
nunchaku/csrc/gemm88.h
nunchaku/csrc/gemm88.h
+5
-4
nunchaku/csrc/module.h
nunchaku/csrc/module.h
+5
-5
nunchaku/csrc/ops.h
nunchaku/csrc/ops.h
+119
-162
nunchaku/csrc/pybind.cpp
nunchaku/csrc/pybind.cpp
+56
-66
nunchaku/csrc/sana.h
nunchaku/csrc/sana.h
+54
-58
nunchaku/csrc/utils.h
nunchaku/csrc/utils.h
+27
-27
nunchaku/lora/flux/__init__.py
nunchaku/lora/flux/__init__.py
+2
-0
nunchaku/lora/flux/nunchaku_converter.py
nunchaku/lora/flux/nunchaku_converter.py
+1
-1
nunchaku/lora/flux/packer.py
nunchaku/lora/flux/packer.py
+1
-1
nunchaku/models/__init__.py
nunchaku/models/__init__.py
+2
-0
nunchaku/models/pulid/encoders_transformer.py
nunchaku/models/pulid/encoders_transformer.py
+210
-0
nunchaku/models/pulid/eva_clip/__init__.py
nunchaku/models/pulid/eva_clip/__init__.py
+4
-0
nunchaku/models/pulid/eva_clip/constants.py
nunchaku/models/pulid/eva_clip/constants.py
+2
-0
nunchaku/models/pulid/eva_clip/eva_vit_model.py
nunchaku/models/pulid/eva_clip/eva_vit_model.py
+622
-0
nunchaku/models/pulid/eva_clip/factory.py
nunchaku/models/pulid/eva_clip/factory.py
+406
-0
nunchaku/models/pulid/eva_clip/hf_configs.py
nunchaku/models/pulid/eva_clip/hf_configs.py
+57
-0
nunchaku/models/pulid/eva_clip/hf_model.py
nunchaku/models/pulid/eva_clip/hf_model.py
+141
-0
No files found.
nunchaku/caching/utils.py
View file @
37a27712
...
...
@@ -3,11 +3,16 @@
import
contextlib
import
dataclasses
from
collections
import
defaultdict
from
typing
import
DefaultDict
,
Dict
,
Optional
from
typing
import
DefaultDict
,
Dict
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
nunchaku.models.transformers.utils
import
pad_tensor
num_transformer_blocks
=
19
# FIXME
num_single_transformer_blocks
=
38
# FIXME
@
dataclasses
.
dataclass
class
CacheContext
:
...
...
@@ -75,38 +80,123 @@ def cache_context(cache_context):
def
are_two_tensors_similar
(
t1
,
t2
,
*
,
threshold
,
parallelized
=
False
):
mean_diff
=
(
t1
-
t2
).
abs
().
mean
()
mean_t1
=
t1
.
abs
().
mean
()
diff
=
mean_diff
/
mean_t1
return
diff
.
item
()
<
threshold
diff
=
(
mean_diff
/
mean_t1
).
item
()
return
diff
<
threshold
,
diff
@
torch
.
compiler
.
disable
def
apply_prev_hidden_states_residual
(
hidden_states
:
torch
.
Tensor
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
hidden_states_residual
=
get_buffer
(
"hidden_states_residual"
)
assert
hidden_states_residual
is
not
None
,
"hidden_states_residual must be set before"
hidden_states
=
hidden_states_residual
+
hidden_states
hidden_states
=
hidden_states
.
contiguous
()
if
encoder_hidden_states
is
not
None
:
encoder_hidden_states_residual
=
get_buffer
(
"encoder_hidden_states_residual"
)
assert
encoder_hidden_states_residual
is
not
None
,
"encoder_hidden_states_residual must be set before"
encoder_hidden_states
=
encoder_hidden_states_residual
+
encoder_hidden_states
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
()
hidden_states
:
torch
.
Tensor
,
encoder_hidden_states
:
torch
.
Tensor
=
None
,
mode
:
str
=
"multi"
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
mode
==
"multi"
:
hidden_states_residual
=
get_buffer
(
"multi_hidden_states_residual"
)
assert
hidden_states_residual
is
not
None
,
"multi_hidden_states_residual must be set before"
hidden_states
=
hidden_states
+
hidden_states_residual
hidden_states
=
hidden_states
.
contiguous
()
if
encoder_hidden_states
is
not
None
:
enc_hidden_res
=
get_buffer
(
"multi_encoder_hidden_states_residual"
)
msg
=
"multi_encoder_hidden_states_residual must be set before"
assert
enc_hidden_res
is
not
None
,
msg
encoder_hidden_states
=
encoder_hidden_states
+
enc_hidden_res
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
()
return
hidden_states
,
encoder_hidden_states
elif
mode
==
"single"
:
single_residual
=
get_buffer
(
"single_hidden_states_residual"
)
msg
=
"single_hidden_states_residual must be set before"
assert
single_residual
is
not
None
,
msg
hidden_states
=
hidden_states
+
single_residual
hidden_states
=
hidden_states
.
contiguous
()
return
hidden_states
return
hidden_states
,
encoder_hidden_states
else
:
raise
ValueError
(
f
"Unknown mode
{
mode
}
; expected 'multi' or 'single'"
)
@
torch
.
compiler
.
disable
def
get_can_use_cache
(
first_hidden_states_residual
,
threshold
,
parallelized
=
False
):
prev_first_hidden_states_residual
=
get_buffer
(
"first_hidden_states_residual"
)
can_use_cache
=
prev_first_hidden_states_residual
is
not
None
and
are_two_tensors_similar
(
prev_first_hidden_states_residual
,
def
get_can_use_cache
(
first_hidden_states_residual
:
torch
.
Tensor
,
threshold
:
float
,
parallelized
:
bool
=
False
,
mode
:
str
=
"multi"
):
if
mode
==
"multi"
:
buffer_name
=
"first_multi_hidden_states_residual"
elif
mode
==
"single"
:
buffer_name
=
"first_single_hidden_states_residual"
else
:
raise
ValueError
(
f
"Unknown mode
{
mode
}
; expected 'multi' or 'single'"
)
prev_res
=
get_buffer
(
buffer_name
)
if
prev_res
is
None
:
return
False
,
threshold
is_similar
,
diff
=
are_two_tensors_similar
(
prev_res
,
first_hidden_states_residual
,
threshold
=
threshold
,
parallelized
=
parallelized
,
)
return
can_use_cache
return
is_similar
,
diff
def
check_and_apply_cache
(
*
,
first_residual
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
threshold
:
float
,
parallelized
:
bool
,
mode
:
str
,
verbose
:
bool
,
call_remaining_fn
,
remaining_kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
float
]:
can_use_cache
,
diff
=
get_can_use_cache
(
first_residual
,
threshold
=
threshold
,
parallelized
=
parallelized
,
mode
=
mode
,
)
torch
.
_dynamo
.
graph_break
()
if
can_use_cache
:
if
verbose
:
print
(
f
"[
{
mode
.
upper
()
}
] Cache hit! diff=
{
diff
:.
4
f
}
, "
f
"new threshold=
{
threshold
:.
4
f
}
"
)
out
=
apply_prev_hidden_states_residual
(
hidden_states
,
encoder_hidden_states
,
mode
=
mode
)
updated_h
,
updated_enc
=
out
if
isinstance
(
out
,
tuple
)
else
(
out
,
None
)
return
updated_h
,
updated_enc
,
threshold
old_threshold
=
threshold
if
verbose
:
print
(
f
"[
{
mode
.
upper
()
}
] Cache miss. diff=
{
diff
:.
4
f
}
, "
f
"was=
{
old_threshold
:.
4
f
}
=> now=
{
threshold
:.
4
f
}
"
)
if
mode
==
"multi"
:
set_buffer
(
"first_multi_hidden_states_residual"
,
first_residual
)
else
:
set_buffer
(
"first_single_hidden_states_residual"
,
first_residual
)
result
=
call_remaining_fn
(
hidden_states
=
hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
**
remaining_kwargs
)
if
mode
==
"multi"
:
updated_h
,
updated_enc
,
hs_res
,
enc_res
=
result
set_buffer
(
"multi_hidden_states_residual"
,
hs_res
)
set_buffer
(
"multi_encoder_hidden_states_residual"
,
enc_res
)
return
updated_h
,
updated_enc
,
threshold
elif
mode
==
"single"
:
updated_cat_states
,
cat_res
=
result
set_buffer
(
"single_hidden_states_residual"
,
cat_res
)
return
updated_cat_states
,
None
,
threshold
raise
ValueError
(
f
"Unknown mode
{
mode
}
"
)
class
SanaCachedTransformerBlocks
(
nn
.
Module
):
...
...
@@ -230,109 +320,326 @@ class FluxCachedTransformerBlocks(nn.Module):
def
__init__
(
self
,
*
,
transformer
=
None
,
residual_diff_threshold
,
return_hidden_states_first
=
True
,
return_hidden_states_only
=
False
,
transformer
:
nn
.
Module
=
None
,
use_double_fb_cache
:
bool
=
True
,
residual_diff_threshold_multi
:
float
,
residual_diff_threshold_single
:
float
,
return_hidden_states_first
:
bool
=
True
,
return_hidden_states_only
:
bool
=
False
,
verbose
:
bool
=
False
,
):
super
().
__init__
()
self
.
transformer
=
transformer
self
.
transformer_blocks
=
transformer
.
transformer_blocks
self
.
single_transformer_blocks
=
transformer
.
single_transformer_blocks
self
.
residual_diff_threshold
=
residual_diff_threshold
self
.
use_double_fb_cache
=
use_double_fb_cache
self
.
residual_diff_threshold_multi
=
residual_diff_threshold_multi
self
.
residual_diff_threshold_single
=
residual_diff_threshold_single
self
.
return_hidden_states_first
=
return_hidden_states_first
self
.
return_hidden_states_only
=
return_hidden_states_only
self
.
verbose
=
verbose
def
update_residual_diff_threshold
(
self
,
residual_diff_threshold
=
0.12
):
self
.
residual_diff_threshold
=
residual_diff_threshold
self
.
m
=
self
.
transformer_blocks
[
0
].
m
self
.
dtype
=
torch
.
bfloat16
if
self
.
m
.
isBF16
()
else
torch
.
float16
self
.
device
=
transformer
.
device
@
staticmethod
def
pack_rotemb
(
rotemb
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
rotemb
.
dtype
==
torch
.
float32
B
=
rotemb
.
shape
[
0
]
M
=
rotemb
.
shape
[
1
]
D
=
rotemb
.
shape
[
2
]
*
2
msg_shape
=
"rotemb shape must be (B, M, D//2, 1, 2)"
assert
rotemb
.
shape
==
(
B
,
M
,
D
//
2
,
1
,
2
),
msg_shape
assert
M
%
16
==
0
assert
D
%
8
==
0
rotemb
=
rotemb
.
reshape
(
B
,
M
//
16
,
16
,
D
//
8
,
8
)
rotemb
=
rotemb
.
permute
(
0
,
1
,
3
,
2
,
4
)
# 16*8 pack, FP32 accumulator (C) format
# https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
rotemb
=
rotemb
.
reshape
(
*
rotemb
.
shape
[
0
:
3
],
2
,
8
,
4
,
2
)
rotemb
=
rotemb
.
permute
(
0
,
1
,
2
,
4
,
5
,
3
,
6
)
rotemb
=
rotemb
.
contiguous
()
rotemb
=
rotemb
.
view
(
B
,
M
,
D
)
return
rotemb
def
update_residual_diff_threshold
(
self
,
use_double_fb_cache
=
True
,
residual_diff_threshold_multi
=
0.12
,
residual_diff_threshold_single
=
0.09
):
self
.
use_double_fb_cache
=
use_double_fb_cache
self
.
residual_diff_threshold_multi
=
residual_diff_threshold_multi
self
.
residual_diff_threshold_single
=
residual_diff_threshold_single
def
forward
(
self
,
hidden_states
,
encoder_hidden_states
,
*
args
,
**
kwargs
):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
temb
:
torch
.
Tensor
,
encoder_hidden_states
:
torch
.
Tensor
,
image_rotary_emb
:
torch
.
Tensor
,
joint_attention_kwargs
=
None
,
controlnet_block_samples
=
None
,
controlnet_single_block_samples
=
None
,
skip_first_layer
=
False
,
):
batch_size
=
hidden_states
.
shape
[
0
]
if
self
.
residual_diff_threshold
<=
0.0
or
batch_size
>
1
:
if
batch_size
>
1
:
print
(
"Batch size > 1 currently not supported"
)
txt_tokens
=
encoder_hidden_states
.
shape
[
1
]
img_tokens
=
hidden_states
.
shape
[
1
]
first_transformer_block
=
self
.
transformer_blocks
[
0
]
encoder_hidden_states
,
hidden_states
=
first_transformer_block
(
hidden_states
=
hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
*
args
,
**
kwargs
original_dtype
=
hidden_states
.
dtype
original_device
=
hidden_states
.
device
hidden_states
=
hidden_states
.
to
(
self
.
dtype
).
to
(
self
.
device
)
encoder_hidden_states
=
encoder_hidden_states
.
to
(
self
.
dtype
).
to
(
self
.
device
)
temb
=
temb
.
to
(
self
.
dtype
).
to
(
self
.
device
)
image_rotary_emb
=
image_rotary_emb
.
to
(
self
.
device
)
if
controlnet_block_samples
is
not
None
:
controlnet_block_samples
=
(
torch
.
stack
(
controlnet_block_samples
).
to
(
self
.
device
)
if
len
(
controlnet_block_samples
)
>
0
else
None
)
if
controlnet_single_block_samples
is
not
None
and
len
(
controlnet_single_block_samples
)
>
0
:
controlnet_single_block_samples
=
(
torch
.
stack
(
controlnet_single_block_samples
).
to
(
self
.
device
)
if
len
(
controlnet_single_block_samples
)
>
0
else
None
)
return
(
hidden_states
if
self
.
return_hidden_states_only
else
(
(
hidden_states
,
encoder_hidden_states
)
if
self
.
return_hidden_states_first
else
(
encoder_hidden_states
,
hidden_states
)
)
assert
image_rotary_emb
.
ndim
==
6
assert
image_rotary_emb
.
shape
[
0
]
==
1
assert
image_rotary_emb
.
shape
[
1
]
==
1
# [1, tokens, head_dim/2, 1, 2] (sincos)
total_tokens
=
txt_tokens
+
img_tokens
assert
image_rotary_emb
.
shape
[
2
]
==
1
*
total_tokens
image_rotary_emb
=
image_rotary_emb
.
reshape
([
1
,
txt_tokens
+
img_tokens
,
*
image_rotary_emb
.
shape
[
3
:]])
rotary_emb_txt
=
image_rotary_emb
[:,
:
txt_tokens
,
...]
rotary_emb_img
=
image_rotary_emb
[:,
txt_tokens
:,
...]
rotary_emb_single
=
image_rotary_emb
rotary_emb_txt
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_txt
,
256
,
1
))
rotary_emb_img
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_img
,
256
,
1
))
rotary_emb_single
=
self
.
pack_rotemb
(
pad_tensor
(
rotary_emb_single
,
256
,
1
))
if
(
self
.
residual_diff_threshold_multi
<
0.0
)
or
(
batch_size
>
1
):
if
batch_size
>
1
and
self
.
verbose
:
print
(
"Batch size > 1 currently not supported"
)
hidden_states
=
self
.
m
.
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_txt
,
rotary_emb_single
,
controlnet_block_samples
,
controlnet_single_block_samples
,
skip_first_layer
,
)
original_hidden_states
=
hidden_states
first_transformer_block
=
self
.
transformer_blocks
[
0
]
encoder_hidden_states
,
hidden_states
=
first_transformer_block
.
forward_layer_at
(
0
,
hidden_states
,
encoder_hidden_states
,
*
args
,
**
kwargs
)
hidden_states
=
hidden_states
.
to
(
original_dtype
).
to
(
original_device
)
first
_hidden_states
_residual
=
hidden_states
-
original_hidden_states
del
original_hidden_states
encoder
_hidden_states
=
hidden_states
[:,
:
txt_tokens
,
...]
hidden_states
=
hidden_states
[:,
txt_tokens
:,
...]
can_use_cache
=
get_can_use_cache
(
first_hidden_states_residual
,
threshold
=
self
.
residual_diff_threshold
,
parallelized
=
self
.
transformer
is
not
None
and
getattr
(
self
.
transformer
,
"_is_parallelized"
,
False
),
if
self
.
return_hidden_states_only
:
return
hidden_states
if
self
.
return_hidden_states_first
:
return
hidden_states
,
encoder_hidden_states
return
encoder_hidden_states
,
hidden_states
remaining_kwargs
=
{
"temb"
:
temb
,
"rotary_emb_img"
:
rotary_emb_img
,
"rotary_emb_txt"
:
rotary_emb_txt
,
"rotary_emb_single"
:
rotary_emb_single
,
"controlnet_block_samples"
:
controlnet_block_samples
,
"controlnet_single_block_samples"
:
controlnet_single_block_samples
,
"txt_tokens"
:
txt_tokens
,
}
original_hidden_states
=
hidden_states
first_hidden_states
,
first_encoder_hidden_states
=
self
.
m
.
forward_layer
(
0
,
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_txt
,
controlnet_block_samples
,
controlnet_single_block_samples
,
)
hidden_states
=
first_hidden_states
encoder_hidden_states
=
first_encoder_hidden_states
first_hidden_states_residual_multi
=
hidden_states
-
original_hidden_states
del
original_hidden_states
torch
.
_dynamo
.
graph_break
()
if
can_use_cache
:
del
first_hidden_states_residual
if
self
.
verbose
:
print
(
"Cache hit!!!"
)
hidden_states
,
encoder_hidden_states
=
apply_prev_hidden_states_residual
(
hidden_states
,
encoder_hidden_states
)
if
self
.
use_double_fb_cache
:
call_remaining_fn
=
self
.
call_remaining_multi_transformer_blocks
else
:
if
self
.
verbose
:
print
(
"Cache miss!!!"
)
set_buffer
(
"first_hidden_states_residual"
,
first_hidden_states_residual
)
del
first_hidden_states_residual
(
hidden_states
,
encoder_hidden_states
,
hidden_states_residual
,
encoder_hidden_states_residual
,
)
=
self
.
call_remaining_transformer_blocks
(
hidden_states
,
encoder_hidden_states
,
*
args
,
**
kwargs
)
set_buffer
(
"hidden_states_residual"
,
hidden_states_residual
)
set_buffer
(
"encoder_hidden_states_residual"
,
encoder_hidden_states_residual
)
torch
.
_dynamo
.
graph_break
()
call_remaining_fn
=
self
.
call_remaining_FBCache_transformer_blocks
return
(
hidden_states
if
self
.
return_hidden_states_only
else
(
(
hidden_states
,
encoder_hidden_states
)
if
self
.
return_hidden_states_first
else
(
encoder_hidden_states
,
hidden_states
)
)
torch
.
_dynamo
.
graph_break
()
updated_h
,
updated_enc
,
threshold
=
check_and_apply_cache
(
first_residual
=
first_hidden_states_residual_multi
,
hidden_states
=
hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
threshold
=
self
.
residual_diff_threshold_multi
,
parallelized
=
(
self
.
transformer
is
not
None
and
getattr
(
self
.
transformer
,
"_is_parallelized"
,
False
)),
mode
=
"multi"
,
verbose
=
self
.
verbose
,
call_remaining_fn
=
call_remaining_fn
,
remaining_kwargs
=
remaining_kwargs
,
)
self
.
residual_diff_threshold_multi
=
threshold
if
not
self
.
use_double_fb_cache
:
if
self
.
return_hidden_states_only
:
return
updated_h
if
self
.
return_hidden_states_first
:
return
updated_h
,
updated_enc
return
updated_enc
,
updated_h
# DoubleFBCache
cat_hidden_states
=
torch
.
cat
([
updated_enc
,
updated_h
],
dim
=
1
)
original_cat
=
cat_hidden_states
cat_hidden_states
=
self
.
m
.
forward_single_layer
(
0
,
cat_hidden_states
,
temb
,
rotary_emb_single
)
first_hidden_states_residual_single
=
cat_hidden_states
-
original_cat
del
original_cat
call_remaining_fn_single
=
self
.
call_remaining_single_transformer_blocks
updated_cat
,
_
,
threshold
=
check_and_apply_cache
(
first_residual
=
first_hidden_states_residual_single
,
hidden_states
=
cat_hidden_states
,
encoder_hidden_states
=
None
,
threshold
=
self
.
residual_diff_threshold_single
,
parallelized
=
(
self
.
transformer
is
not
None
and
getattr
(
self
.
transformer
,
"_is_parallelized"
,
False
)),
mode
=
"single"
,
verbose
=
self
.
verbose
,
call_remaining_fn
=
call_remaining_fn_single
,
remaining_kwargs
=
remaining_kwargs
,
)
self
.
residual_diff_threshold_single
=
threshold
def
call_remaining_transformer_blocks
(
self
,
hidden_states
,
encoder_hidden_states
,
*
args
,
**
kwargs
):
first_transformer_block
=
self
.
transformer_blocks
[
0
]
# torch._dynamo.graph_break()
final_enc
=
updated_cat
[:,
:
txt_tokens
,
...]
final_h
=
updated_cat
[:,
txt_tokens
:,
...]
final_h
=
final_h
.
to
(
original_dtype
).
to
(
original_device
)
final_enc
=
final_enc
.
to
(
original_dtype
).
to
(
original_device
)
if
self
.
return_hidden_states_only
:
return
final_h
if
self
.
return_hidden_states_first
:
return
final_h
,
final_enc
return
final_enc
,
final_h
def
call_remaining_FBCache_transformer_blocks
(
self
,
hidden_states
:
torch
.
Tensor
,
temb
:
torch
.
Tensor
,
encoder_hidden_states
:
torch
.
Tensor
,
rotary_emb_img
:
torch
.
Tensor
,
rotary_emb_txt
:
torch
.
Tensor
,
rotary_emb_single
:
torch
.
Tensor
,
controlnet_block_samples
=
None
,
controlnet_single_block_samples
=
None
,
skip_first_layer
=
True
,
txt_tokens
=
None
,
):
original_dtype
=
hidden_states
.
dtype
original_device
=
hidden_states
.
device
original_hidden_states
=
hidden_states
original_encoder_hidden_states
=
encoder_hidden_states
encoder_hidden_states
,
hidden_states
=
first_transformer_block
.
forward
(
hidden_states
=
hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
skip_first_layer
=
True
,
*
args
,
**
kwargs
,
hidden_states
=
self
.
m
.
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_txt
,
rotary_emb_single
,
controlnet_block_samples
,
controlnet_single_block_samples
,
skip_first_layer
,
)
hidden_states
=
hidden_states
.
to
(
original_dtype
).
to
(
original_device
)
encoder_hidden_states
=
hidden_states
[:,
:
txt_tokens
,
...]
hidden_states
=
hidden_states
[:,
txt_tokens
:,
...]
hidden_states
=
hidden_states
.
contiguous
()
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
()
hidden_states_residual
=
hidden_states
-
original_hidden_states
encoder_hidden_states_residual
=
encoder_hidden_states
-
original_encoder_hidden_states
enc_residual
=
encoder_hidden_states
-
original_encoder_hidden_states
return
hidden_states
,
encoder_hidden_states
,
hidden_states_residual
,
enc_residual
return
hidden_states
,
encoder_hidden_states
,
hidden_states_residual
,
encoder_hidden_states_residual
def
call_remaining_multi_transformer_blocks
(
self
,
hidden_states
:
torch
.
Tensor
,
temb
:
torch
.
Tensor
,
encoder_hidden_states
:
torch
.
Tensor
,
rotary_emb_img
:
torch
.
Tensor
,
rotary_emb_txt
:
torch
.
Tensor
,
rotary_emb_single
:
torch
.
Tensor
,
controlnet_block_samples
=
None
,
controlnet_single_block_samples
=
None
,
skip_first_layer
=
False
,
txt_tokens
=
None
,
):
start_idx
=
1
original_hidden_states
=
hidden_states
.
clone
()
original_encoder_hidden_states
=
encoder_hidden_states
.
clone
()
for
idx
in
range
(
start_idx
,
num_transformer_blocks
):
hidden_states
,
encoder_hidden_states
=
self
.
m
.
forward_layer
(
idx
,
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_txt
,
controlnet_block_samples
,
controlnet_single_block_samples
,
)
hidden_states
=
hidden_states
.
contiguous
()
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
()
hs_res
=
hidden_states
-
original_hidden_states
enc_res
=
encoder_hidden_states
-
original_encoder_hidden_states
return
hidden_states
,
encoder_hidden_states
,
hs_res
,
enc_res
def
call_remaining_single_transformer_blocks
(
self
,
hidden_states
:
torch
.
Tensor
,
temb
:
torch
.
Tensor
,
encoder_hidden_states
:
torch
.
Tensor
,
rotary_emb_img
:
torch
.
Tensor
,
rotary_emb_txt
:
torch
.
Tensor
,
rotary_emb_single
:
torch
.
Tensor
,
controlnet_block_samples
=
None
,
controlnet_single_block_samples
=
None
,
skip_first_layer
=
False
,
txt_tokens
=
None
,
):
start_idx
=
1
original_hidden_states
=
hidden_states
.
clone
()
for
idx
in
range
(
start_idx
,
num_single_transformer_blocks
):
hidden_states
=
self
.
m
.
forward_single_layer
(
idx
,
hidden_states
,
temb
,
rotary_emb_single
,
)
hidden_states
=
hidden_states
.
contiguous
()
hs_res
=
hidden_states
-
original_hidden_states
return
hidden_states
,
hs_res
nunchaku/csrc/flux.h
View file @
37a27712
...
...
@@ -20,36 +20,59 @@ public:
ModuleWrapper
::
init
(
deviceId
);
CUDADeviceContext
ctx
(
this
->
deviceId
);
net
=
std
::
make_unique
<
FluxModel
>
(
use_fp4
,
offload
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
net
=
std
::
make_unique
<
FluxModel
>
(
use_fp4
,
offload
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
}
bool
isBF16
()
{
checkModel
();
return
net
->
dtype
==
Tensor
::
BF16
;
}
pybind11
::
function
residual_callback
;
void
set_residual_callback
(
pybind11
::
function
callback
)
{
pybind11
::
gil_scoped_acquire
gil
;
if
(
!
callback
||
callback
.
is_none
())
{
residual_callback
=
pybind11
::
function
();
if
(
net
)
{
net
->
set_residual_callback
(
nullptr
);
}
return
;
}
residual_callback
=
std
::
move
(
callback
);
if
(
net
)
{
pybind11
::
object
cb
=
residual_callback
;
net
->
set_residual_callback
([
cb
](
const
Tensor
&
x
)
->
Tensor
{
pybind11
::
gil_scoped_acquire
gil
;
torch
::
Tensor
torch_x
=
to_torch
(
x
);
pybind11
::
object
result
=
cb
(
torch_x
);
torch
::
Tensor
torch_y
=
result
.
cast
<
torch
::
Tensor
>
();
Tensor
y
=
from_torch
(
torch_y
);
return
y
;
});
}
else
{
}
}
torch
::
Tensor
forward
(
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_context
,
torch
::
Tensor
rotary_emb_single
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
,
bool
skip_first_layer
=
false
)
{
torch
::
Tensor
forward
(
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_context
,
torch
::
Tensor
rotary_emb_single
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
,
bool
skip_first_layer
=
false
)
{
checkModel
();
CUDADeviceContext
ctx
(
deviceId
);
spdlog
::
debug
(
"QuantizedFluxModel forward"
);
hidden_states
=
hidden_states
.
contiguous
();
hidden_states
=
hidden_states
.
contiguous
();
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
();
temb
=
temb
.
contiguous
();
rotary_emb_img
=
rotary_emb_img
.
contiguous
();
rotary_emb_context
=
rotary_emb_context
.
contiguous
();
rotary_emb_single
=
rotary_emb_single
.
contiguous
();
temb
=
temb
.
contiguous
();
rotary_emb_img
=
rotary_emb_img
.
contiguous
();
rotary_emb_context
=
rotary_emb_context
.
contiguous
();
rotary_emb_single
=
rotary_emb_single
.
contiguous
();
Tensor
result
=
net
->
forward
(
from_torch
(
hidden_states
),
...
...
@@ -59,9 +82,10 @@ public:
from_torch
(
rotary_emb_context
),
from_torch
(
rotary_emb_single
),
controlnet_block_samples
.
has_value
()
?
from_torch
(
controlnet_block_samples
.
value
().
contiguous
())
:
Tensor
{},
controlnet_single_block_samples
.
has_value
()
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
:
Tensor
{},
skip_first_layer
);
controlnet_single_block_samples
.
has_value
()
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
:
Tensor
{},
skip_first_layer
);
torch
::
Tensor
output
=
to_torch
(
result
);
Tensor
::
synchronizeDevice
();
...
...
@@ -69,25 +93,24 @@ public:
return
output
;
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
forward_layer
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_context
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
)
{
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
forward_layer
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_context
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
)
{
CUDADeviceContext
ctx
(
deviceId
);
spdlog
::
debug
(
"QuantizedFluxModel forward_layer {}"
,
idx
);
hidden_states
=
hidden_states
.
contiguous
();
hidden_states
=
hidden_states
.
contiguous
();
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
();
temb
=
temb
.
contiguous
();
rotary_emb_img
=
rotary_emb_img
.
contiguous
();
rotary_emb_context
=
rotary_emb_context
.
contiguous
();
temb
=
temb
.
contiguous
();
rotary_emb_img
=
rotary_emb_img
.
contiguous
();
rotary_emb_context
=
rotary_emb_context
.
contiguous
();
auto
&&
[
hidden_states_
,
encoder_hidden_states_
]
=
net
->
forward_layer
(
idx
,
...
...
@@ -97,35 +120,31 @@ public:
from_torch
(
rotary_emb_img
),
from_torch
(
rotary_emb_context
),
controlnet_block_samples
.
has_value
()
?
from_torch
(
controlnet_block_samples
.
value
().
contiguous
())
:
Tensor
{},
controlnet_single_block_samples
.
has_value
()
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
:
Tensor
{}
);
controlnet_single_block_samples
.
has_value
()
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
:
Tensor
{});
hidden_states
=
to_torch
(
hidden_states_
);
hidden_states
=
to_torch
(
hidden_states_
);
encoder_hidden_states
=
to_torch
(
encoder_hidden_states_
);
Tensor
::
synchronizeDevice
();
return
{
hidden_states
,
encoder_hidden_states
};
return
{
hidden_states
,
encoder_hidden_states
};
}
torch
::
Tensor
forward_single_layer
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_single
)
{
torch
::
Tensor
forward_single_layer
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_single
)
{
CUDADeviceContext
ctx
(
deviceId
);
spdlog
::
debug
(
"QuantizedFluxModel forward_single_layer {}"
,
idx
);
hidden_states
=
hidden_states
.
contiguous
();
temb
=
temb
.
contiguous
();
hidden_states
=
hidden_states
.
contiguous
();
temb
=
temb
.
contiguous
();
rotary_emb_single
=
rotary_emb_single
.
contiguous
();
Tensor
result
=
net
->
single_transformer_blocks
.
at
(
idx
)
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
temb
),
from_torch
(
rotary_emb_single
)
);
from_torch
(
hidden_states
),
from_torch
(
temb
),
from_torch
(
rotary_emb_single
));
hidden_states
=
to_torch
(
result
);
Tensor
::
synchronizeDevice
();
...
...
@@ -133,6 +152,18 @@ public:
return
hidden_states
;
}
// expose the norm1 forward method of the transformer blocks
// this is used by TeaCache to get the norm1 output
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
norm_one_forward
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
temb
)
{
AdaLayerNormZero
::
Output
result
=
net
->
transformer_blocks
.
at
(
idx
)
->
norm1
.
forward
(
from_torch
(
hidden_states
),
from_torch
(
temb
));
return
{
to_torch
(
result
.
x
),
to_torch
(
result
.
gate_msa
),
to_torch
(
result
.
shift_mlp
),
to_torch
(
result
.
scale_mlp
),
to_torch
(
result
.
gate_mlp
)};
}
// must be called after loading lora
// skip specific ranks in W4A4 layers
...
...
@@ -174,5 +205,4 @@ public:
throw
std
::
invalid_argument
(
spdlog
::
fmt_lib
::
format
(
"Invalid attention implementation {}"
,
name
));
}
}
};
\ No newline at end of file
};
nunchaku/csrc/gemm.h
View file @
37a27712
...
...
@@ -16,7 +16,12 @@ public:
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
spdlog
::
debug
(
"Stack={}"
,
val
);
net
=
std
::
make_unique
<
GEMM_W4A4
>
((
int
)
in_features
,
(
int
)
out_features
,
bias
,
use_fp4
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
net
=
std
::
make_unique
<
GEMM_W4A4
>
((
int
)
in_features
,
(
int
)
out_features
,
bias
,
use_fp4
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
}
torch
::
Tensor
forward
(
torch
::
Tensor
x
)
{
...
...
@@ -53,11 +58,11 @@ public:
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
constexpr
int
BLOCK_M
=
256
;
constexpr
int
WARP_K
=
64
;
constexpr
int
NUM_WARPS
=
8
;
constexpr
int
BLOCK_M
=
256
;
constexpr
int
WARP_K
=
64
;
constexpr
int
NUM_WARPS
=
8
;
constexpr
int
WARP_M_TILES
=
2
;
constexpr
int
WARP_SIZE
=
32
;
constexpr
int
WARP_SIZE
=
32
;
std
::
stringstream
ss
;
for
(
int
bm
=
0
;
bm
<
M
/
BLOCK_M
;
bm
++
)
{
...
...
@@ -95,13 +100,10 @@ public:
x
=
x
.
contiguous
();
auto
qout
=
net
->
quantize
(
from_torch
(
x
),
fuse_glu
);
auto
qout
=
net
->
quantize
(
from_torch
(
x
),
fuse_glu
);
Tensor
act
=
qout
.
act
.
copy
(
Device
::
cpu
());
Tensor
ascales
=
qout
.
ascales
.
copy
(
Device
::
cpu
());
Tensor
act
=
qout
.
act
.
copy
(
Device
::
cpu
());
Tensor
ascales
=
qout
.
ascales
.
copy
(
Device
::
cpu
());
Tensor
lora_act
=
qout
.
lora_act
.
copy
(
Device
::
cpu
());
Tensor
::
synchronizeDevice
();
...
...
@@ -109,5 +111,4 @@ public:
spdlog
::
debug
(
"act = {}"
,
dumpTensorINT4
(
act
));
spdlog
::
debug
(
"ascales = {}"
,
dumpTensorBF16
(
ascales
));
}
};
nunchaku/csrc/gemm88.h
View file @
37a27712
...
...
@@ -10,13 +10,14 @@ class QuantizedGEMM88 : public ModuleWrapper<GEMM_W8A8> {
public:
void
init
(
int64_t
in_features
,
int64_t
out_features
,
bool
bias
,
bool
bf16
,
int8_t
deviceId
)
{
spdlog
::
info
(
"Initializing QuantizedGEMM88"
);
size_t
val
=
0
;
checkCUDA
(
cudaDeviceSetLimit
(
cudaLimitStackSize
,
8192
));
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
spdlog
::
debug
(
"Stack={}"
,
val
);
net
=
std
::
make_unique
<
GEMM_W8A8
>
((
int
)
in_features
,
(
int
)
out_features
,
bias
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
net
=
std
::
make_unique
<
GEMM_W8A8
>
(
(
int
)
in_features
,
(
int
)
out_features
,
bias
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
}
torch
::
Tensor
forward
(
torch
::
Tensor
x
)
{
...
...
@@ -27,10 +28,10 @@ public:
x
=
x
.
contiguous
();
Tensor
result
=
net
->
forward
(
from_torch
(
x
));
torch
::
Tensor
output
=
to_torch
(
result
);
Tensor
::
synchronizeDevice
();
return
output
;
}
};
\ No newline at end of file
};
nunchaku/csrc/module.h
View file @
37a27712
...
...
@@ -18,7 +18,7 @@ public:
debugContext
.
reset
();
net
.
reset
();
Tensor
::
synchronizeDevice
();
nunchaku
::
utils
::
trim_memory
();
Tensor
::
synchronizeDevice
();
}
...
...
@@ -28,7 +28,7 @@ public:
CUDADeviceContext
ctx
(
this
->
deviceId
);
spdlog
::
info
(
"{} weights from {}"
,
partial
?
"Loading partial"
:
"Loading"
,
path
);
std
::
shared_ptr
<
SafeTensors
>
provider
=
std
::
make_shared
<
SafeTensors
>
(
path
);
net
->
loadParams
(
*
provider
,
partial
);
Tensor
::
synchronizeDevice
();
...
...
@@ -41,7 +41,7 @@ public:
CUDADeviceContext
ctx
(
this
->
deviceId
);
spdlog
::
info
(
"{} weights from pytorch"
,
partial
?
"Loading partial"
:
"Loading"
);
std
::
shared_ptr
<
TensorsProviderTorch
>
provider
=
std
::
make_shared
<
TensorsProviderTorch
>
(
std
::
move
(
dict
));
net
->
loadParams
(
*
provider
,
partial
);
Tensor
::
synchronizeDevice
();
...
...
@@ -66,7 +66,7 @@ public:
result
[
key
]
=
to_torch
(
value
);
}
}
return
result
;
}
...
...
@@ -82,4 +82,4 @@ protected:
std
::
unique_ptr
<
DebugContext
>
debugContext
;
int
deviceId
=
-
1
;
};
\ No newline at end of file
};
nunchaku/csrc/ops.h
View file @
37a27712
...
...
@@ -7,175 +7,132 @@
namespace
nunchaku
::
ops
{
void
gemm_w4a4
(
std
::
optional
<
torch
::
Tensor
>
act
,
// packed act [M, K / 2]
std
::
optional
<
torch
::
Tensor
>
wgt
,
// packed act [N, K / 2]
std
::
optional
<
torch
::
Tensor
>
out
,
// linear [M, N]
std
::
optional
<
torch
::
Tensor
>
qout
,
// packed act [M, N / 2]
std
::
optional
<
torch
::
Tensor
>
ascales
,
// packed as [K / 64, M]
std
::
optional
<
torch
::
Tensor
>
wscales
,
// packed ws [K / 64, N]
std
::
optional
<
torch
::
Tensor
>
oscales
,
// packed as [N / 64, M]
std
::
optional
<
torch
::
Tensor
>
poolout
,
// linear [M / PoolSize, N]
std
::
optional
<
torch
::
Tensor
>
lora_act_in
,
// packed lora_act [M, R]
std
::
optional
<
torch
::
Tensor
>
lora_up
,
// packed lora_wgt [N, R]
std
::
optional
<
torch
::
Tensor
>
lora_down
,
// packed lora_wgt [N, R]
std
::
optional
<
torch
::
Tensor
>
lora_act_out
,
// packed lora_act [M, R]
std
::
optional
<
torch
::
Tensor
>
norm_q
,
// linear [HEAD_DIM]
std
::
optional
<
torch
::
Tensor
>
norm_k
,
// linear [HEAD_DIM]
std
::
optional
<
torch
::
Tensor
>
rotary_emb
,
// linear [M, HEAD_DIM / 2, 2, 2]
std
::
optional
<
torch
::
Tensor
>
bias
,
// packed ws [N]
std
::
optional
<
torch
::
Tensor
>
smooth_factor
,
// packed ws [N], for quantization of the next layer
std
::
optional
<
torch
::
Tensor
>
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
std
::
optional
<
torch
::
Tensor
>
out_linearattn
,
// linear [B, (M), N / 3]
bool
act_unsigned
,
std
::
vector
<
float
>
lora_scales
,
bool
fuse_silu
,
bool
fp4
,
float
alpha
,
std
::
optional
<
torch
::
Tensor
>
wcscales
,
std
::
optional
<
torch
::
Tensor
>
out_q
,
// packed attention [B, H, M, D]
std
::
optional
<
torch
::
Tensor
>
out_k
,
// packed attention [B, H, M, D]
std
::
optional
<
torch
::
Tensor
>
out_v
,
// packed attention [B, H, M, D]
int
attn_tokens
)
{
spdlog
::
trace
(
"running gemm_w4a4: "
);
void
gemm_w4a4
(
std
::
optional
<
torch
::
Tensor
>
act
,
// packed act [M, K / 2]
std
::
optional
<
torch
::
Tensor
>
wgt
,
// packed act [N, K / 2]
std
::
optional
<
torch
::
Tensor
>
out
,
// linear [M, N]
std
::
optional
<
torch
::
Tensor
>
qout
,
// packed act [M, N / 2]
std
::
optional
<
torch
::
Tensor
>
ascales
,
// packed as [K / 64, M]
std
::
optional
<
torch
::
Tensor
>
wscales
,
// packed ws [K / 64, N]
std
::
optional
<
torch
::
Tensor
>
oscales
,
// packed as [N / 64, M]
std
::
optional
<
torch
::
Tensor
>
poolout
,
// linear [M / PoolSize, N]
std
::
optional
<
torch
::
Tensor
>
lora_act_in
,
// packed lora_act [M, R]
std
::
optional
<
torch
::
Tensor
>
lora_up
,
// packed lora_wgt [N, R]
std
::
optional
<
torch
::
Tensor
>
lora_down
,
// packed lora_wgt [N, R]
std
::
optional
<
torch
::
Tensor
>
lora_act_out
,
// packed lora_act [M, R]
std
::
optional
<
torch
::
Tensor
>
norm_q
,
// linear [HEAD_DIM]
std
::
optional
<
torch
::
Tensor
>
norm_k
,
// linear [HEAD_DIM]
std
::
optional
<
torch
::
Tensor
>
rotary_emb
,
// linear [M, HEAD_DIM / 2, 2, 2]
std
::
optional
<
torch
::
Tensor
>
bias
,
// packed ws [N]
std
::
optional
<
torch
::
Tensor
>
smooth_factor
,
// packed ws [N], for quantization of the next layer
std
::
optional
<
torch
::
Tensor
>
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
std
::
optional
<
torch
::
Tensor
>
out_linearattn
,
// linear [B, (M), N / 3]
bool
act_unsigned
,
std
::
vector
<
float
>
lora_scales
,
bool
fuse_silu
,
bool
fp4
,
float
alpha
,
std
::
optional
<
torch
::
Tensor
>
wcscales
,
std
::
optional
<
torch
::
Tensor
>
out_q
,
// packed attention [B, H, M, D]
std
::
optional
<
torch
::
Tensor
>
out_k
,
// packed attention [B, H, M, D]
std
::
optional
<
torch
::
Tensor
>
out_v
,
// packed attention [B, H, M, D]
int
attn_tokens
)
{
spdlog
::
trace
(
"running gemm_w4a4: "
);
auto
getTensor
=
[](
std
::
optional
<
torch
::
Tensor
>
&
t
)
{
Tensor
ret
=
t
.
has_value
()
?
from_torch
(
t
.
value
())
:
Tensor
{};
if
(
ret
.
valid
())
{
spdlog
::
trace
(
" {}"
,
ret
.
shape
.
str
());
}
else
{
spdlog
::
trace
(
" <invalid>"
);
}
return
ret
;
};
nunchaku
::
kernels
::
gemm_w4a4
(
getTensor
(
act
),
getTensor
(
wgt
),
getTensor
(
out
),
getTensor
(
qout
),
getTensor
(
ascales
),
getTensor
(
wscales
),
getTensor
(
oscales
),
getTensor
(
poolout
),
getTensor
(
lora_act_in
),
getTensor
(
lora_up
),
getTensor
(
lora_down
),
getTensor
(
lora_act_out
),
getTensor
(
norm_q
),
getTensor
(
norm_k
),
getTensor
(
rotary_emb
),
getTensor
(
bias
),
getTensor
(
smooth_factor
),
getTensor
(
out_vk
),
getTensor
(
out_linearattn
),
act_unsigned
,
lora_scales
,
fuse_silu
,
fp4
,
alpha
,
getTensor
(
wcscales
),
getTensor
(
out_q
),
getTensor
(
out_k
),
getTensor
(
out_v
),
attn_tokens
);
// Tensor::synchronizeDevice();
}
auto
getTensor
=
[](
std
::
optional
<
torch
::
Tensor
>
&
t
)
{
Tensor
ret
=
t
.
has_value
()
?
from_torch
(
t
.
value
())
:
Tensor
{};
if
(
ret
.
valid
())
{
spdlog
::
trace
(
" {}"
,
ret
.
shape
.
str
());
}
else
{
spdlog
::
trace
(
" <invalid>"
);
}
return
ret
;
};
nunchaku
::
kernels
::
gemm_w4a4
(
getTensor
(
act
),
getTensor
(
wgt
),
getTensor
(
out
),
getTensor
(
qout
),
getTensor
(
ascales
),
getTensor
(
wscales
),
getTensor
(
oscales
),
getTensor
(
poolout
),
getTensor
(
lora_act_in
),
getTensor
(
lora_up
),
getTensor
(
lora_down
),
getTensor
(
lora_act_out
),
getTensor
(
norm_q
),
getTensor
(
norm_k
),
getTensor
(
rotary_emb
),
getTensor
(
bias
),
getTensor
(
smooth_factor
),
getTensor
(
out_vk
),
getTensor
(
out_linearattn
),
act_unsigned
,
lora_scales
,
fuse_silu
,
fp4
,
alpha
,
getTensor
(
wcscales
),
getTensor
(
out_q
),
getTensor
(
out_k
),
getTensor
(
out_v
),
attn_tokens
);
// Tensor::synchronizeDevice();
}
void
attention_fp16
(
torch
::
Tensor
q
,
// packed [Batch, Head, TokensQ, HEAD_DIM]
torch
::
Tensor
k
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
torch
::
Tensor
v
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
torch
::
Tensor
o
,
// linear [Batch, TokensQ, Head * HEAD_DIM]
float
scale
)
{
nunchaku
::
kernels
::
attention_fp16
(
from_torch
(
q
),
from_torch
(
k
),
from_torch
(
v
),
from_torch
(
o
),
scale
);
}
void
attention_fp16
(
torch
::
Tensor
q
,
// packed [Batch, Head, TokensQ, HEAD_DIM]
torch
::
Tensor
k
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
torch
::
Tensor
v
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
torch
::
Tensor
o
,
// linear [Batch, TokensQ, Head * HEAD_DIM]
float
scale
)
{
nunchaku
::
kernels
::
attention_fp16
(
from_torch
(
q
),
from_torch
(
k
),
from_torch
(
v
),
from_torch
(
o
),
scale
);
}
torch
::
Tensor
gemv_awq
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
int64_t
group_size
)
{
Tensor
result
=
::
gemv_awq
(
from_torch
(
_in_feats
.
contiguous
()),
from_torch
(
_kernel
.
contiguous
()),
from_torch
(
_scaling_factors
.
contiguous
()),
from_torch
(
_zeros
.
contiguous
()),
(
int
)
m
,
(
int
)
n
,
(
int
)
k
,
(
int
)
group_size
);
torch
::
Tensor
gemv_awq
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
int64_t
group_size
)
{
Tensor
result
=
::
gemv_awq
(
from_torch
(
_in_feats
.
contiguous
()),
from_torch
(
_kernel
.
contiguous
()),
from_torch
(
_scaling_factors
.
contiguous
()),
from_torch
(
_zeros
.
contiguous
()),
(
int
)
m
,
(
int
)
n
,
(
int
)
k
,
(
int
)
group_size
);
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
return
output
;
}
return
output
;
}
torch
::
Tensor
gemm_awq
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
)
{
Tensor
result
=
::
awq_gemm_forward_cuda
(
from_torch
(
_in_feats
.
contiguous
()),
from_torch
(
_kernel
.
contiguous
()),
from_torch
(
_scaling_factors
.
contiguous
()),
from_torch
(
_zeros
.
contiguous
())
);
torch
::
Tensor
gemm_awq
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
)
{
Tensor
result
=
::
awq_gemm_forward_cuda
(
from_torch
(
_in_feats
.
contiguous
()),
from_torch
(
_kernel
.
contiguous
()),
from_torch
(
_scaling_factors
.
contiguous
()),
from_torch
(
_zeros
.
contiguous
()));
// TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
// TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
return
output
;
}
return
output
;
}
void
test_rmsnorm_rope
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
norm_q
,
torch
::
Tensor
norm_k
,
torch
::
Tensor
rotary_emb
)
{
nunchaku
::
kernels
::
test_rmsnorm_rope
(
from_torch
(
input
),
from_torch
(
output
),
from_torch
(
norm_q
),
from_torch
(
norm_k
),
from_torch
(
rotary_emb
)
);
}
void
test_rmsnorm_rope
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
norm_q
,
torch
::
Tensor
norm_k
,
torch
::
Tensor
rotary_emb
)
{
nunchaku
::
kernels
::
test_rmsnorm_rope
(
from_torch
(
input
),
from_torch
(
output
),
from_torch
(
norm_q
),
from_torch
(
norm_k
),
from_torch
(
rotary_emb
));
}
void
test_pack_qkv
(
torch
::
Tensor
input
,
torch
::
Tensor
out_q
,
torch
::
Tensor
out_k
,
torch
::
Tensor
out_v
,
int
numTokens
)
{
nunchaku
::
kernels
::
test_pack_qkv
(
from_torch
(
input
),
from_torch
(
out_q
),
from_torch
(
out_k
),
from_torch
(
out_v
),
numTokens
);
}
};
\ No newline at end of file
void
test_pack_qkv
(
torch
::
Tensor
input
,
torch
::
Tensor
out_q
,
torch
::
Tensor
out_k
,
torch
::
Tensor
out_v
,
int
numTokens
)
{
nunchaku
::
kernels
::
test_pack_qkv
(
from_torch
(
input
),
from_torch
(
out_q
),
from_torch
(
out_k
),
from_torch
(
out_v
),
numTokens
);
}
};
// namespace nunchaku::ops
nunchaku/csrc/pybind.cpp
View file @
37a27712
...
...
@@ -5,80 +5,75 @@
#include "ops.h"
#include "utils.h"
#include <torch/extension.h>
#include "interop/torch.h"
#include <pybind11/pybind11.h>
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
py
::
class_
<
QuantizedFluxModel
>
(
m
,
"QuantizedFluxModel"
)
.
def
(
py
::
init
<>
())
.
def
(
"init"
,
&
QuantizedFluxModel
::
init
,
py
::
arg
(
"use_fp4"
),
py
::
arg
(
"offload"
),
py
::
arg
(
"bf16"
),
py
::
arg
(
"deviceId"
)
)
.
def
(
"init"
,
&
QuantizedFluxModel
::
init
,
py
::
arg
(
"use_fp4"
),
py
::
arg
(
"offload"
),
py
::
arg
(
"bf16"
),
py
::
arg
(
"deviceId"
))
.
def
(
"set_residual_callback"
,
[](
QuantizedFluxModel
&
self
,
pybind11
::
object
call_back
)
{
if
(
call_back
.
is_none
())
{
self
.
set_residual_callback
(
pybind11
::
function
());
}
else
{
self
.
set_residual_callback
(
call_back
);
}
})
.
def
(
"reset"
,
&
QuantizedFluxModel
::
reset
)
.
def
(
"load"
,
&
QuantizedFluxModel
::
load
,
py
::
arg
(
"path"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"loadDict"
,
&
QuantizedFluxModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"forward"
,
&
QuantizedFluxModel
::
forward
,
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
py
::
arg
(
"temb"
),
py
::
arg
(
"rotary_emb_img"
),
py
::
arg
(
"rotary_emb_context"
),
py
::
arg
(
"rotary_emb_single"
),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"skip_first_layer"
)
=
false
)
.
def
(
"forward_layer"
,
&
QuantizedFluxModel
::
forward_layer
,
py
::
arg
(
"idx"
),
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
py
::
arg
(
"temb"
),
py
::
arg
(
"rotary_emb_img"
),
py
::
arg
(
"rotary_emb_context"
),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
()
)
.
def
(
"load"
,
&
QuantizedFluxModel
::
load
,
py
::
arg
(
"path"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"loadDict"
,
&
QuantizedFluxModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"forward"
,
&
QuantizedFluxModel
::
forward
,
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
py
::
arg
(
"temb"
),
py
::
arg
(
"rotary_emb_img"
),
py
::
arg
(
"rotary_emb_context"
),
py
::
arg
(
"rotary_emb_single"
),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"skip_first_layer"
)
=
false
)
.
def
(
"forward_layer"
,
&
QuantizedFluxModel
::
forward_layer
,
py
::
arg
(
"idx"
),
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
py
::
arg
(
"temb"
),
py
::
arg
(
"rotary_emb_img"
),
py
::
arg
(
"rotary_emb_context"
),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
())
.
def
(
"forward_single_layer"
,
&
QuantizedFluxModel
::
forward_single_layer
)
.
def
(
"norm_one_forward"
,
&
QuantizedFluxModel
::
norm_one_forward
)
.
def
(
"startDebug"
,
&
QuantizedFluxModel
::
startDebug
)
.
def
(
"stopDebug"
,
&
QuantizedFluxModel
::
stopDebug
)
.
def
(
"getDebugResults"
,
&
QuantizedFluxModel
::
getDebugResults
)
.
def
(
"setLoraScale"
,
&
QuantizedFluxModel
::
setLoraScale
)
.
def
(
"setAttentionImpl"
,
&
QuantizedFluxModel
::
setAttentionImpl
)
.
def
(
"isBF16"
,
&
QuantizedFluxModel
::
isBF16
)
;
.
def
(
"isBF16"
,
&
QuantizedFluxModel
::
isBF16
);
py
::
class_
<
QuantizedSanaModel
>
(
m
,
"QuantizedSanaModel"
)
.
def
(
py
::
init
<>
())
.
def
(
"init"
,
&
QuantizedSanaModel
::
init
,
py
::
arg
(
"config"
)
,
py
::
arg
(
"
pag_layers
"
),
py
::
arg
(
"
use_fp4
"
),
py
::
arg
(
"
bf16
"
),
py
::
arg
(
"
deviceId
"
)
)
.
def
(
"init"
,
&
QuantizedSanaModel
::
init
,
py
::
arg
(
"
config
"
),
py
::
arg
(
"
pag_layers
"
),
py
::
arg
(
"
use_fp4
"
),
py
::
arg
(
"
bf16
"
)
,
py
::
arg
(
"deviceId"
)
)
.
def
(
"reset"
,
&
QuantizedSanaModel
::
reset
)
.
def
(
"load"
,
&
QuantizedSanaModel
::
load
,
py
::
arg
(
"path"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"loadDict"
,
&
QuantizedSanaModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"load"
,
&
QuantizedSanaModel
::
load
,
py
::
arg
(
"path"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"loadDict"
,
&
QuantizedSanaModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"forward"
,
&
QuantizedSanaModel
::
forward
)
.
def
(
"forward_layer"
,
&
QuantizedSanaModel
::
forward_layer
)
.
def
(
"startDebug"
,
&
QuantizedSanaModel
::
startDebug
)
.
def
(
"stopDebug"
,
&
QuantizedSanaModel
::
stopDebug
)
.
def
(
"getDebugResults"
,
&
QuantizedSanaModel
::
getDebugResults
)
;
.
def
(
"getDebugResults"
,
&
QuantizedSanaModel
::
getDebugResults
);
py
::
class_
<
QuantizedGEMM
>
(
m
,
"QuantizedGEMM"
)
.
def
(
py
::
init
<>
())
.
def
(
"init"
,
&
QuantizedGEMM
::
init
)
...
...
@@ -88,8 +83,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"quantize"
,
&
QuantizedGEMM
::
quantize
)
.
def
(
"startDebug"
,
&
QuantizedGEMM
::
startDebug
)
.
def
(
"stopDebug"
,
&
QuantizedGEMM
::
stopDebug
)
.
def
(
"getDebugResults"
,
&
QuantizedGEMM
::
getDebugResults
)
;
.
def
(
"getDebugResults"
,
&
QuantizedGEMM
::
getDebugResults
)
;
py
::
class_
<
Tensor
>
(
m
,
"Tensor"
)
;
py
::
class_
<
QuantizedGEMM88
>
(
m
,
"QuantizedGEMM88"
)
.
def
(
py
::
init
<>
())
.
def
(
"init"
,
&
QuantizedGEMM88
::
init
)
...
...
@@ -98,8 +93,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"forward"
,
&
QuantizedGEMM88
::
forward
)
.
def
(
"startDebug"
,
&
QuantizedGEMM88
::
startDebug
)
.
def
(
"stopDebug"
,
&
QuantizedGEMM88
::
stopDebug
)
.
def
(
"getDebugResults"
,
&
QuantizedGEMM88
::
getDebugResults
)
;
.
def
(
"getDebugResults"
,
&
QuantizedGEMM88
::
getDebugResults
);
m
.
def_submodule
(
"ops"
)
.
def
(
"gemm_w4a4"
,
nunchaku
::
ops
::
gemm_w4a4
)
...
...
@@ -108,16 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"gemv_awq"
,
nunchaku
::
ops
::
gemv_awq
)
.
def
(
"test_rmsnorm_rope"
,
nunchaku
::
ops
::
test_rmsnorm_rope
)
.
def
(
"test_pack_qkv"
,
nunchaku
::
ops
::
test_pack_qkv
)
;
.
def
(
"test_pack_qkv"
,
nunchaku
::
ops
::
test_pack_qkv
);
m
.
def_submodule
(
"utils"
)
.
def
(
"set_log_level"
,
[](
const
std
::
string
&
level
)
{
spdlog
::
set_level
(
spdlog
::
level
::
from_str
(
level
));
})
.
def
(
"set_log_level"
,
[](
const
std
::
string
&
level
)
{
spdlog
::
set_level
(
spdlog
::
level
::
from_str
(
level
));
})
.
def
(
"set_cuda_stack_limit"
,
nunchaku
::
utils
::
set_cuda_stack_limit
)
.
def
(
"disable_memory_auto_release"
,
nunchaku
::
utils
::
disable_memory_auto_release
)
.
def
(
"trim_memory"
,
nunchaku
::
utils
::
trim_memory
)
.
def
(
"set_faster_i2f_mode"
,
nunchaku
::
utils
::
set_faster_i2f_mode
)
;
.
def
(
"set_faster_i2f_mode"
,
nunchaku
::
utils
::
set_faster_i2f_mode
);
}
nunchaku/csrc/sana.h
View file @
37a27712
...
...
@@ -11,13 +11,13 @@ public:
void
init
(
pybind11
::
dict
config
,
std
::
vector
<
int
>
pag_layers
,
bool
use_fp4
,
bool
bf16
,
int8_t
deviceId
)
{
spdlog
::
info
(
"Initializing QuantizedSanaModel on device {}"
,
deviceId
);
SanaConfig
cfg
{
.
num_layers
=
config
[
"num_layers"
].
cast
<
int
>
(),
.
num_attention_heads
=
config
[
"num_attention_heads"
].
cast
<
int
>
(),
.
attention_head_dim
=
config
[
"attention_head_dim"
].
cast
<
int
>
(),
.
num_layers
=
config
[
"num_layers"
].
cast
<
int
>
(),
.
num_attention_heads
=
config
[
"num_attention_heads"
].
cast
<
int
>
(),
.
attention_head_dim
=
config
[
"attention_head_dim"
].
cast
<
int
>
(),
.
num_cross_attention_heads
=
config
[
"num_cross_attention_heads"
].
cast
<
int
>
(),
.
expand_ratio
=
config
[
"mlp_ratio"
].
cast
<
double
>
(),
.
pag_layers
=
pag_layers
,
.
use_fp4
=
use_fp4
,
.
expand_ratio
=
config
[
"mlp_ratio"
].
cast
<
double
>
(),
.
pag_layers
=
pag_layers
,
.
use_fp4
=
use_fp4
,
};
ModuleWrapper
::
init
(
deviceId
);
...
...
@@ -25,39 +25,37 @@ public:
net
=
std
::
make_unique
<
SanaModel
>
(
cfg
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
}
torch
::
Tensor
forward
(
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
timestep
,
torch
::
Tensor
cu_seqlens_img
,
torch
::
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
,
bool
skip_first_layer
=
false
)
{
torch
::
Tensor
forward
(
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
timestep
,
torch
::
Tensor
cu_seqlens_img
,
torch
::
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
,
bool
skip_first_layer
=
false
)
{
checkModel
();
CUDADeviceContext
ctx
(
deviceId
);
spdlog
::
debug
(
"QuantizedSanaModel forward"
);
hidden_states
=
hidden_states
.
contiguous
();
hidden_states
=
hidden_states
.
contiguous
();
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
();
timestep
=
timestep
.
contiguous
();
cu_seqlens_img
=
cu_seqlens_img
.
contiguous
();
cu_seqlens_txt
=
cu_seqlens_txt
.
contiguous
();
timestep
=
timestep
.
contiguous
();
cu_seqlens_img
=
cu_seqlens_img
.
contiguous
();
cu_seqlens_txt
=
cu_seqlens_txt
.
contiguous
();
Tensor
result
=
net
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
encoder_hidden_sta
te
s
),
from_torch
(
timestep
),
from_torch
(
cu_seqlens_
img
),
from_torch
(
cu_seqlens_txt
)
,
H
,
W
,
pag
,
cfg
,
skip_first_layer
);
Tensor
result
=
net
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
encoder_
hidden_states
),
from_torch
(
times
te
p
),
from_torch
(
cu_seqlens_img
),
from_torch
(
cu_seqlens_
txt
),
H
,
W
,
pag
,
cfg
,
skip_first_layer
);
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
...
...
@@ -65,42 +63,40 @@ public:
return
output
;
}
torch
::
Tensor
forward_layer
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
timestep
,
torch
::
Tensor
cu_seqlens_img
,
torch
::
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
)
{
torch
::
Tensor
forward_layer
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
timestep
,
torch
::
Tensor
cu_seqlens_img
,
torch
::
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
)
{
checkModel
();
CUDADeviceContext
ctx
(
deviceId
);
spdlog
::
debug
(
"QuantizedSanaModel forward_layer {}"
,
idx
);
hidden_states
=
hidden_states
.
contiguous
();
hidden_states
=
hidden_states
.
contiguous
();
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
();
timestep
=
timestep
.
contiguous
();
cu_seqlens_img
=
cu_seqlens_img
.
contiguous
();
cu_seqlens_txt
=
cu_seqlens_txt
.
contiguous
();
timestep
=
timestep
.
contiguous
();
cu_seqlens_img
=
cu_seqlens_img
.
contiguous
();
cu_seqlens_txt
=
cu_seqlens_txt
.
contiguous
();
Tensor
result
=
net
->
transformer_blocks
.
at
(
idx
)
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
encoder_hidden_states
),
from_torch
(
timestep
),
from_torch
(
cu_seqlens_
img
),
from_torch
(
cu_seqlens_txt
)
,
H
,
W
,
pag
,
cfg
);
Tensor
result
=
net
->
transformer_blocks
.
at
(
idx
)
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
encoder_
hidden_states
),
from_torch
(
timestep
),
from_torch
(
cu_seqlens_img
),
from_torch
(
cu_seqlens_
txt
),
H
,
W
,
pag
,
cfg
);
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
return
output
;
}
};
\ No newline at end of file
};
nunchaku/csrc/utils.h
View file @
37a27712
...
...
@@ -6,34 +6,34 @@
namespace
nunchaku
::
utils
{
void
set_cuda_stack_limit
(
int64_t
newval
)
{
size_t
val
=
0
;
checkCUDA
(
cudaDeviceSetLimit
(
cudaLimitStackSize
,
(
size_t
)
newval
));
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
spdlog
::
debug
(
"Stack={}"
,
val
);
}
void
set_cuda_stack_limit
(
int64_t
newval
)
{
size_t
val
=
0
;
checkCUDA
(
cudaDeviceSetLimit
(
cudaLimitStackSize
,
(
size_t
)
newval
));
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
spdlog
::
debug
(
"Stack={}"
,
val
);
}
void
disable_memory_auto_release
()
{
int
device
;
checkCUDA
(
cudaGetDevice
(
&
device
));
cudaMemPool_t
mempool
;
checkCUDA
(
cudaDeviceGetDefaultMemPool
(
&
mempool
,
device
));
uint64_t
threshold
=
UINT64_MAX
;
checkCUDA
(
cudaMemPoolSetAttribute
(
mempool
,
cudaMemPoolAttrReleaseThreshold
,
&
threshold
));
}
void
disable_memory_auto_release
()
{
int
device
;
checkCUDA
(
cudaGetDevice
(
&
device
));
cudaMemPool_t
mempool
;
checkCUDA
(
cudaDeviceGetDefaultMemPool
(
&
mempool
,
device
));
uint64_t
threshold
=
UINT64_MAX
;
checkCUDA
(
cudaMemPoolSetAttribute
(
mempool
,
cudaMemPoolAttrReleaseThreshold
,
&
threshold
));
}
void
trim_memory
()
{
int
device
;
checkCUDA
(
cudaGetDevice
(
&
device
));
cudaMemPool_t
mempool
;
checkCUDA
(
cudaDeviceGetDefaultMemPool
(
&
mempool
,
device
));
size_t
bytesToKeep
=
0
;
checkCUDA
(
cudaMemPoolTrimTo
(
mempool
,
bytesToKeep
));
}
void
trim_memory
()
{
int
device
;
checkCUDA
(
cudaGetDevice
(
&
device
));
cudaMemPool_t
mempool
;
checkCUDA
(
cudaDeviceGetDefaultMemPool
(
&
mempool
,
device
));
size_t
bytesToKeep
=
0
;
checkCUDA
(
cudaMemPoolTrimTo
(
mempool
,
bytesToKeep
));
}
void
set_faster_i2f_mode
(
std
::
string
mode
)
{
spdlog
::
info
(
"Set fasteri2f mode to {}"
,
mode
);
kernels
::
set_faster_i2f_mode
(
mode
);
}
void
set_faster_i2f_mode
(
std
::
string
mode
)
{
spdlog
::
info
(
"Set fasteri2f mode to {}"
,
mode
);
kernels
::
set_faster_i2f_mode
(
mode
);
}
};
\ No newline at end of file
};
// namespace nunchaku::utils
nunchaku/lora/flux/__init__.py
View file @
37a27712
from
.diffusers_converter
import
to_diffusers
from
.nunchaku_converter
import
convert_to_nunchaku_flux_lowrank_dict
,
to_nunchaku
from
.utils
import
is_nunchaku_format
__all__
=
[
"to_diffusers"
,
"to_nunchaku"
,
"convert_to_nunchaku_flux_lowrank_dict"
,
"is_nunchaku_format"
]
nunchaku/lora/flux/nunchaku_converter.py
View file @
37a27712
...
...
@@ -7,10 +7,10 @@ import torch
from
safetensors.torch
import
save_file
from
tqdm
import
tqdm
from
...utils
import
filter_state_dict
,
load_state_dict_in_safetensors
from
.diffusers_converter
import
to_diffusers
from
.packer
import
NunchakuWeightPacker
from
.utils
import
is_nunchaku_format
,
pad
from
...utils
import
filter_state_dict
,
load_state_dict_in_safetensors
logger
=
logging
.
getLogger
(
__name__
)
...
...
nunchaku/lora/flux/packer.py
View file @
37a27712
# Copy the packer from https://github.com/mit-han-lab/deepcompressor/
import
torch
from
.utils
import
pad
from
...utils
import
ceil_divide
from
.utils
import
pad
class
MmaWeightPackerBase
:
...
...
nunchaku/models/__init__.py
View file @
37a27712
from
.text_encoders.t5_encoder
import
NunchakuT5EncoderModel
from
.transformers
import
NunchakuFluxTransformer2dModel
,
NunchakuSanaTransformer2DModel
__all__
=
[
"NunchakuFluxTransformer2dModel"
,
"NunchakuSanaTransformer2DModel"
,
"NunchakuT5EncoderModel"
]
nunchaku/models/pulid/encoders_transformer.py
0 → 100644
View file @
37a27712
# Adapted from https://github.com/ToTheBeginning/PuLID
import
math
import
torch
from
torch
import
nn
# FFN
def FeedForward(dim, mult=4):
    """Build a pre-norm feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        dim: input/output channel width.
        mult: expansion factor for the hidden layer (hidden = int(dim * mult)).

    Returns:
        nn.Sequential mapping (..., dim) -> (..., dim); both linear layers are bias-free.
    """
    hidden = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*stages)
def reshape_tensor(x, heads):
    """Split the channel dim of a (bs, length, heads * dim_head) tensor into heads.

    Returns a (bs, heads, length, dim_head) tensor suitable for per-head attention.
    """
    batch, seq_len, _ = x.shape
    # (bs, length, heads, dim_head) -> (bs, heads, length, dim_head)
    per_head = x.view(batch, seq_len, heads, -1).permute(0, 2, 1, 3)
    # reshape (not view): the permuted tensor is non-contiguous
    return per_head.reshape(batch, heads, seq_len, -1)
class PerceiverAttentionCA(nn.Module):
    """Single cross-attention layer: `latents` queries attend to `x` keys/values.

    Defaults correspond to a dim=3072 latent stream attending over kv_dim=2048
    features (presumably the FLUX hidden size and EVA-CLIP width — confirm upstream).
    """

    def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        kv_in = dim if kv_dim is None else kv_dim
        self.norm1 = nn.LayerNorm(kv_in)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_in, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features, shape (b, n1, kv_dim)
            latents (torch.Tensor): latent features, shape (b, n2, dim)

        Returns:
            torch.Tensor of shape (b, n2, dim): updated latents.
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, n_latents, _ = latents.shape

        q = reshape_tensor(self.to_q(latents), self.heads)
        keys, values = self.to_kv(x).chunk(2, dim=-1)
        k = reshape_tensor(keys, self.heads)
        v = reshape_tensor(values, self.heads)

        # Scale q and k by dim_head**-0.25 each (dim_head**-0.5 combined);
        # more stable with f16 than dividing the logits afterwards.
        softmax_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (q * softmax_scale) @ (k * softmax_scale).transpose(-2, -1)
        attn = torch.softmax(logits.float(), dim=-1).type(logits.dtype)

        out = (attn @ v).permute(0, 2, 1, 3).reshape(batch, n_latents, -1)
        return self.to_out(out)
class PerceiverAttention(nn.Module):
    """Perceiver-style attention: latents attend over [x ; latents] jointly.

    Unlike PerceiverAttentionCA, keys/values are computed from the concatenation
    of the context features and the latents, so latents also self-attend.
    """

    def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        kv_in = dim if kv_dim is None else kv_dim
        self.norm1 = nn.LayerNorm(kv_in)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_in, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features, shape (b, n1, D)
            latents (torch.Tensor): latent features, shape (b, n2, D)

        Returns:
            torch.Tensor of shape (b, n2, D): updated latents.
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, n_latents, _ = latents.shape

        q = reshape_tensor(self.to_q(latents), self.heads)
        # Keys/values come from the context AND the latents themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        keys, values = self.to_kv(kv_input).chunk(2, dim=-1)
        k = reshape_tensor(keys, self.heads)
        v = reshape_tensor(values, self.heads)

        # Scale q and k by dim_head**-0.25 each (dim_head**-0.5 combined);
        # more stable with f16 than dividing the logits afterwards.
        softmax_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (q * softmax_scale) @ (k * softmax_scale).transpose(-2, -1)
        attn = torch.softmax(logits.float(), dim=-1).type(logits.dtype)

        out = (attn @ v).permute(0, 2, 1, 3).reshape(batch, n_latents, -1)
        return self.to_out(out)
class IDFormer(nn.Module):
    """
    - perceiver resampler like arch (compared with previous MLP-like arch)
    - we concat id embedding (generated by arcface) and query tokens as latents
    - latents will attend each other and interact with vit features through cross-attention
    - vit features are multi-scaled and inserted into IDFormer in order, currently, each scale corresponds to two
      IDFormer layers
    """

    def __init__(
        self,
        dim=1024,
        depth=10,
        dim_head=64,
        heads=16,
        num_id_token=5,
        num_queries=32,
        output_dim=2048,
        ff_mult=4,
    ):
        super().__init__()

        self.num_id_token = num_id_token
        self.dim = dim
        self.num_queries = num_queries
        # The `depth` layers are distributed evenly across the 5 ViT feature scales.
        assert depth % 5 == 0
        # Number of (attention, FFN) layer pairs consumed per scale.
        self.depth = depth // 5
        scale = dim**-0.5

        # Learnable query tokens, broadcast over the batch in forward().
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
        # Final linear projection applied via matmul (no bias).
        self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        # One projection MLP per ViT feature scale (mapping_0 .. mapping_4).
        # NOTE(review): input width is hard-coded to 1024 — presumably the ViT
        # feature width; confirm against the vision backbone.
        for i in range(5):
            setattr(
                self,
                f"mapping_{i}",
                nn.Sequential(
                    nn.Linear(1024, 1024),
                    nn.LayerNorm(1024),
                    nn.LeakyReLU(),
                    nn.Linear(1024, 1024),
                    nn.LayerNorm(1024),
                    nn.LeakyReLU(),
                    nn.Linear(1024, dim),
                ),
            )

        # Maps a 1280-wide id embedding to `num_id_token` tokens of width `dim`.
        self.id_embedding_mapping = nn.Sequential(
            nn.Linear(1280, 1024),
            nn.LayerNorm(1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LayerNorm(1024),
            nn.LeakyReLU(),
            nn.Linear(1024, dim * num_id_token),
        )

    def forward(self, x, y):
        # x: id embedding(s); y: sequence of 5 multi-scale ViT feature tensors.
        latents = self.latents.repeat(x.size(0), 1, 1)

        # Number of id embeddings per sample when x is batched as (bs, num, 1280),
        # otherwise 1 — "duotu" presumably means multiple images; TODO confirm.
        num_duotu = x.shape[1] if x.ndim == 3 else 1

        x = self.id_embedding_mapping(x)
        x = x.reshape(-1, self.num_id_token * num_duotu, self.dim)

        # Query tokens and id tokens jointly form the latent stream.
        latents = torch.cat((latents, x), dim=1)

        for i in range(5):
            # Project this scale's ViT features and attend over [id tokens ; features].
            vit_feature = getattr(self, f"mapping_{i}")(y[i])
            ctx_feature = torch.cat((x, vit_feature), dim=1)
            # Run this scale's slice of (attention, FFN) pairs with residuals.
            for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
                latents = attn(ctx_feature, latents) + latents
                latents = ff(latents) + latents

        # Keep only the query tokens, then project to the output width.
        latents = latents[:, : self.num_queries]
        latents = latents @ self.proj_out
        return latents
nunchaku/models/pulid/eva_clip/__init__.py
0 → 100644
View file @
37a27712
from
.constants
import
OPENAI_DATASET_MEAN
,
OPENAI_DATASET_STD
from
.factory
import
create_model_and_transforms
__all__
=
[
"create_model_and_transforms"
,
"OPENAI_DATASET_MEAN"
,
"OPENAI_DATASET_STD"
]
nunchaku/models/pulid/eva_clip/constants.py
0 → 100644
View file @
37a27712
# Per-channel (RGB) normalization statistics of the image data used to train
# OpenAI's original CLIP models; consumed by the image preprocessing
# transforms in this package.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
nunchaku/models/pulid/eva_clip/eva_vit_model.py
0 → 100644
View file @
37a27712
# --------------------------------------------------------
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import
math
import
os
from
functools
import
partial
import
torch
import
torch.nn
as
nn
from
torch.nn
import
functional
as
F
# timm moved these helpers from timm.models.layers to timm.layers in newer
# releases; try the old location first for compatibility with both.
try:
    from timm.models.layers import drop_path, to_2tuple, trunc_normal_
except ImportError:
    from timm.layers import drop_path, to_2tuple, trunc_normal_

from .rope import VisionRotaryEmbeddingFast
from .transformer import PatchDropout

# Prefer deepspeed's activation checkpointing when running under deepspeed
# (selected via the ENV_TYPE environment variable); otherwise fall back to
# torch's implementation.
if os.getenv("ENV_TYPE") == "deepspeed":
    try:
        from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
    except ImportError:
        from torch.utils.checkpoint import checkpoint
else:
    from torch.utils.checkpoint import checkpoint

# Optional xformers memory-efficient attention. NOTE: the flag name keeps the
# upstream "AVAILBLE" misspelling because other code in this file reads it.
try:
    import xformers  # noqa: F401
    import xformers.ops as xops

    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False
class DropPath(nn.Module):
    """Stochastic-depth regularizer: randomly drops the residual branch
    per sample during training (delegates to timm's ``drop_path``)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of dropping the path; None behaves like 0 in timm.
        self.drop_prob = drop_prob

    def forward(self, x):
        # timm's helper is a no-op in eval mode or when drop_prob is falsy.
        return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> activation ->
    (optional sub-LayerNorm) -> Linear -> dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        drop=0.0,
        subln=False,
    ):
        super().__init__()
        # Default both widths to the input width when not given.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        # Sub-LayerNorm (used by EVA-02 style models); identity otherwise.
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        # Deliberately no dropout after the activation (differs from the
        # original BERT MLP); dropout is applied only after fc2.
        hidden = self.ffn_ln(hidden)
        return self.drop(self.fc2(hidden))
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: gated hidden state
    ``act(w1(x)) * w2(x)`` -> (optional sub-LayerNorm) -> w3 -> dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.SiLU,
        drop=0.0,
        norm_layer=nn.LayerNorm,
        subln=False,
    ):
        super().__init__()
        # Default both widths to the input width when not given.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # w1 produces the gate input, w2 the value input.
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        # Sub-LayerNorm (EVA-02 style); identity otherwise.
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        gated = self.act(self.w1(x)) * self.w2(x)
        return self.drop(self.w3(self.ffn_ln(gated)))
class Attention(nn.Module):
    """Multi-head self-attention with optional BEiT-style relative position
    bias, rotary embeddings (rope), xformers memory-efficient attention, and
    sub-LayerNorm (separate q/k/v projections) as used by EVA-02."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
        xattn=False,
        rope=None,
        subln=False,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.subln = subln
        if self.subln:
            # Sub-LN variant uses three separate bias-free projections;
            # biases (q_bias/v_bias) are applied manually in forward.
            self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
        else:
            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)

        if qkv_bias:
            # BEiT convention: learnable bias for q and v only; k gets none.
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            # BEiT-style learned relative position bias over a
            # window_size[0] x window_size[1] patch grid plus the cls token.
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            # Flatten the 2-D offset into a single table index.
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
            )
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            # The last three table rows are reserved for cls<->token pairs.
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        # Sub-LN applies a LayerNorm on the attention output before proj.
        self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
        # self.proj = nn.Linear(all_head_dim, all_head_dim)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop
        self.rope = rope

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        """Attend over x of shape (B, N, C); optional additive rel_pos_bias
        and boolean-convertible attn_mask of shape (B, N)."""
        B, N, C = x.shape
        if self.subln:
            # Manual linear so the k projection stays bias-free.
            q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
            k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
            v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

            q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        else:
            qkv_bias = None
            if self.q_bias is not None:
                # Zero bias for k, learned for q and v (BEiT convention).
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))

            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, num_heads, N, C
            q, k, v = qkv[0], qkv[1], qkv[2]

        if self.rope:
            # slightly fast impl: rotate only the patch tokens, the cls token
            # (position 0) is concatenated back unrotated.
            q_t = q[:, :, 1:, :]
            ro_q_t = self.rope(q_t)
            q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)

            k_t = k[:, :, 1:, :]
            ro_k_t = self.rope(k_t)
            k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)

        if self.xattn:
            # xformers path: expects (B, N, num_heads, C) layout.
            # NOTE(review): relative position bias / rel_pos_bias / attn_mask
            # are NOT applied on this path — confirm callers never combine
            # xattn with those features.
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)

            x = xops.memory_efficient_attention(
                q,
                k,
                v,
                p=self.xattn_drop,
                scale=self.scale,
            )
            x = x.reshape(B, N, -1)
            x = self.inner_attn_ln(x)
            x = self.proj(x)
            x = self.proj_drop(x)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

            if self.relative_position_bias_table is not None:
                relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1,
                    -1,
                )  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)

            if rel_pos_bias is not None:
                attn = attn + rel_pos_bias.type_as(attn)

            if attn_mask is not None:
                # Mask out positions where attn_mask is False.
                attn_mask = attn_mask.bool()
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
            x = self.inner_attn_ln(x)
            x = self.proj(x)
            x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """EVA transformer block: attention + MLP with residuals, supporting
    pre-/post-norm, optional layer-scale (gamma_1/gamma_2), stochastic depth,
    and either a standard MLP or SwiGLU feed-forward."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
        xattn=False,
        rope=None,
        postnorm=False,
        subln=False,
        naiveswiglu=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
            xattn=xattn,
            rope=rope,
            subln=subln,
            norm_layer=norm_layer,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        if naiveswiglu:
            self.mlp = SwiGLU(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                subln=subln,
                norm_layer=norm_layer,
            )
        else:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop)

        if init_values is not None and init_values > 0:
            # Layer-scale parameters (CaiT-style), initialized to init_values.
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

        self.postnorm = postnorm

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        """Apply the block to x (B, N, C); rel_pos_bias/attn_mask are
        forwarded to the attention sublayer."""
        if self.gamma_1 is None:
            # No layer-scale.
            if self.postnorm:
                # Post-norm: normalize the sublayer output.
                x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
                x = x + self.drop_path(self.norm2(self.mlp(x)))
            else:
                # Pre-norm: normalize the sublayer input.
                x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
                x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            # Same two variants, with layer-scale multipliers applied.
            if self.postnorm:
                x = x + self.drop_path(
                    self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
                )
                x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
            else:
                x = x + self.drop_path(
                    self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
                )
                x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
class PatchEmbed(nn.Module):
    """Image to Patch Embedding: a strided Conv2d that cuts the image into
    non-overlapping patches and projects each to ``embed_dim`` channels."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # Patch-grid dimensions (rows, cols).
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]
        self.patch_shape = (grid_h, grid_w)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, E, gh, gw) -> (B, gh*gw, E)
        return self.proj(x).flatten(2).transpose(1, 2)
class RelativePositionBias(nn.Module):
    """Shared BEiT-style learned relative position bias over a patch window
    plus the cls token; returned per-head for addition to attention logits."""

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        # One table row per distinct 2-D offset, plus 3 rows for the
        # cls-to-token / token-to-cls / cls-to-cls cases.
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        # Flatten the 2-D offset into a single table index.
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # Reserved rows for cls interactions.
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        """Return the bias tensor of shape (num_heads, N, N) where
        N = Wh*Ww + 1 (patches plus cls token)."""
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
        )  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
class EVAVisionTransformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage
    (EVA/EVA-02 variant: rope, sub-LN, SwiGLU, xformers, patch dropout)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_values=None,
        patch_dropout=0.0,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
        use_shared_rel_pos_bias=False,
        rope=False,
        use_mean_pooling=True,
        init_scale=0.001,
        grad_checkpointing=False,
        xattn=False,
        postnorm=False,
        pt_hw_seq_len=16,
        intp_freq=False,
        naiveswiglu=False,
        subln=False,
    ):
        super().__init__()
        # Silently disable xformers attention when the package is missing.
        if not XFORMERS_IS_AVAILBLE:
            xattn = False
        self.image_size = img_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            # Absolute position embedding for patches + cls token.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            # One relative-position-bias module shared by all blocks.
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        if rope:
            # Rotary embedding over half of each head's channels.
            half_head_dim = embed_dim // num_heads // 2
            hw_seq_len = img_size // patch_size
            self.rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=pt_hw_seq_len,
                ft_seq_len=hw_seq_len if intp_freq else None,
                # patch_dropout=patch_dropout
            )
        else:
            self.rope = None

        self.naiveswiglu = naiveswiglu

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    init_values=init_values,
                    window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
                    xattn=xattn,
                    rope=self.rope,
                    postnorm=postnorm,
                    subln=subln,
                    naiveswiglu=naiveswiglu,
                )
                for i in range(depth)
            ]
        )

        # With mean pooling the final norm is applied after pooling (fc_norm).
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)

        trunc_normal_(self.cls_token, std=0.02)
        # trunc_normal_(self.mask_token, std=.02)

        self.apply(self._init_weights)
        self.fix_init_weight()

        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=0.02)
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()

        self.grad_checkpointing = grad_checkpointing

    def fix_init_weight(self):
        """Depth-dependent rescaling of residual-branch output projections
        (BEiT initialization fix)."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            if self.naiveswiglu:
                rescale(layer.mlp.w3.weight.data, layer_id + 1)
            else:
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_cast_dtype(self) -> torch.dtype:
        # Use the first block's MLP output projection as the reference dtype.
        return self.blocks[0].mlp.fc2.weight.dtype

    def _init_weights(self, m):
        """Standard truncated-normal init for Linear, unit/zero for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # Toggle activation checkpointing for the block loop.
        self.grad_checkpointing = enable

    def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
        """Run the backbone.

        Returns the full token sequence when ``return_all_features`` is True;
        otherwise returns ``(pooled_features, hidden_states)`` where
        ``hidden_states`` holds intermediate activations (every 4th block up
        to block 20) when ``return_hidden`` is True, else an empty list.
        """
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        if shuffle:
            # Randomly permute the patch positions of the position embedding
            # (cls position 0 is kept fixed).
            # NOTE(review): assumes self.pos_embed is not None when
            # shuffle=True — confirm callers never combine shuffle with
            # use_abs_pos_emb=False.
            idx = torch.randperm(x.shape[1]) + 1
            zero = torch.LongTensor(
                [
                    0,
                ]
            )
            idx = torch.cat([zero, idx])
            pos_embed = self.pos_embed[:, idx]

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if shuffle:
            x = x + pos_embed
        elif self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        if os.getenv("RoPE") == "1":
            # When rope is active, the rope module must know which patch
            # indices survive patch dropout, so bind them into its forward.
            if self.training and not isinstance(self.patch_dropout, nn.Identity):
                x, patch_indices_keep = self.patch_dropout(x)
                self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
            else:
                self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
                x = self.patch_dropout(x)
        else:
            x = self.patch_dropout(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        hidden_states = []
        for idx, blk in enumerate(self.blocks):
            # Collect hidden states before blocks 4, 8, 12, 16, 20.
            if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:
                hidden_states.append(x)
            if self.grad_checkpointing:
                # NOTE(review): rel_pos_bias is passed wrapped in a tuple, so
                # the block receives a tuple, not the tensor, on this path —
                # kept as-is from upstream EVA; confirm before relying on
                # grad checkpointing together with rel_pos_bias.
                x = checkpoint(blk, x, (rel_pos_bias,))
            else:
                x = blk(x, rel_pos_bias=rel_pos_bias)

        if not return_all_features:
            x = self.norm(x)
            if self.fc_norm is not None:
                # Mean-pool over tokens, then normalize.
                return self.fc_norm(x.mean(1)), hidden_states
            else:
                # Use the cls token as the pooled representation.
                return x[:, 0], hidden_states
        return x

    def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):
        """Full forward: backbone plus classification head.

        With ``return_all_features`` the raw token sequence is returned and
        the head is skipped; with ``return_hidden`` the collected hidden
        states are returned alongside the head output.
        """
        if return_all_features:
            return self.forward_features(x, return_all_features, return_hidden, shuffle)
        x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)
        x = self.head(x)
        if return_hidden:
            return x, hidden_states
        return x
nunchaku/models/pulid/eva_clip/factory.py
0 → 100644
View file @
37a27712
import
json
import
logging
import
os
import
re
from
copy
import
deepcopy
from
pathlib
import
Path
from
typing
import
Optional
,
Tuple
,
Union
import
torch
from
.constants
import
OPENAI_DATASET_MEAN
,
OPENAI_DATASET_STD
from
.model
import
CLIP
,
CustomCLIP
,
convert_to_custom_text_state_dict
,
get_cast_dtype
from
.pretrained
import
download_pretrained
,
get_pretrained_cfg
,
list_pretrained_tags_by_model
from
.transform
import
image_transform
from
.utils
import
resize_clip_pos_embed
,
resize_eva_pos_embed
,
resize_evaclip_pos_embed
,
resize_visual_pos_embed
# Directories scanned for model architecture JSON files.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
# Registry populated by _rescan_model_configs(): model_name -> config dict.
_MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs
def
_natural_key
(
string_
):
return
[
int
(
s
)
if
s
.
isdigit
()
else
s
for
s
in
re
.
split
(
r
"(\d+)"
,
string_
.
lower
())]
def _rescan_model_configs():
    """Rebuild the _MODEL_CONFIGS registry from the JSON files found under
    _MODEL_CONFIG_PATHS, keeping only configs that define the required keys,
    natural-sorted by model name."""
    global _MODEL_CONFIGS

    config_ext = (".json",)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        # Each entry may be a single config file or a directory of them.
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f"*{ext}"))

    required = ("embed_dim", "vision_cfg", "text_cfg")
    for cf in config_files:
        with open(cf, "r", encoding="utf8") as f:
            model_cfg = json.load(f)
            # Only register complete model configs.
            if all(a in model_cfg for a in required):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    ordered = sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
    _MODEL_CONFIGS = dict(ordered)


_rescan_model_configs()  # initial populate of model config registry
def list_models():
    """enumerate available model architectures based on config files"""
    return [*_MODEL_CONFIGS]
def get_model_config(model_name):
    """Return a deep copy of the registered config for *model_name*,
    or None when the name is unknown."""
    cfg = _MODEL_CONFIGS.get(model_name)
    return deepcopy(cfg) if cfg is not None else None
# loading openai CLIP weights when is_openai=True for training
def load_state_dict(
    checkpoint_path: str,
    map_location: str = "cpu",
    model_key: str = "model|module|state_dict",
    is_openai: bool = False,
    skip_list: Optional[list] = None,
):
    """Load a checkpoint's state dict from disk.

    Args:
        checkpoint_path: path to a torch checkpoint (or TorchScript archive
            when ``is_openai`` is True).
        map_location: device mapping for ``torch.load``.
        is_openai: when True, load via ``torch.jit.load`` and strip the
            OpenAI CLIP metadata keys.
        model_key: '|'-separated candidate keys under which the state dict
            may be nested inside the checkpoint; the first match wins.
        skip_list: optional keys to drop from the loaded state dict.
            (Fixed: was a mutable default argument ``[]``.)

    Returns:
        The state dict, with a leading "module." prefix stripped and, when
        the RoPE env flag is "1", precomputed rope frequency buffers removed.
    """
    if is_openai:
        model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
        state_dict = model.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        # Unwrap the state dict if nested under one of the candidate keys.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for mk in model_key.split("|"):
                if mk in checkpoint:
                    state_dict = checkpoint[mk]
                    break
        # Strip DistributedDataParallel's "module." prefix if present
        # (guarded so an empty state dict no longer raises StopIteration).
        if state_dict and next(iter(state_dict)).startswith("module"):
            state_dict = {k[7:]: v for k, v in state_dict.items()}

    for k in skip_list or []:
        if k in state_dict:
            logging.info(f"Removing key {k} from pretrained checkpoint")
            del state_dict[k]

    if os.getenv("RoPE") == "1":
        # Rope frequency buffers are recomputed at model build time.
        for k in list(state_dict.keys()):
            if "freqs_cos" in k or "freqs_sin" in k:
                del state_dict[k]
    return state_dict
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
    """Load a full CLIP checkpoint into *model*, converting legacy layouts
    and resizing position embeddings to the model's resolution.

    Returns the ``IncompatibleKeys`` result of ``model.load_state_dict``.
    """
    state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
    # detect old format and make compatible with new format
    if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"):
        state_dict = convert_to_custom_text_state_dict(state_dict)
    if "text.logit_scale" in state_dict and hasattr(model, "logit_scale"):
        # Move the logit scale from the text tower to the top level.
        state_dict["logit_scale"] = state_dict["text.logit_scale"]
        del state_dict["text.logit_scale"]

    # resize_clip_pos_embed for CLIP and open CLIP
    if "visual.positional_embedding" in state_dict:
        resize_clip_pos_embed(state_dict, model)
    # specified to eva_vit_model
    elif "visual.pos_embed" in state_dict:
        resize_evaclip_pos_embed(state_dict, model)

    # resize_clip_pos_embed(state_dict, model)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    # logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
    return incompatible_keys
def load_clip_visual_state_dict(
    checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: Optional[list] = None
):
    """Load a checkpoint and keep only the vision tower's weights, with the
    "visual." prefix stripped so they load directly into a visual module.

    Fixed: mutable default argument ``skip_list=[]``; the delete-then-rename
    double pass over the dict is collapsed into one comprehension.
    """
    state_dict = load_state_dict(
        checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list or []
    )
    # Keep visual.* entries only, renamed without the 7-char "visual." prefix.
    return {k[7:]: v for k, v in state_dict.items() if k.startswith("visual.")}
def load_clip_text_state_dict(
    checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: Optional[list] = None
):
    """Load a checkpoint and drop the vision tower's weights, keeping the
    text-side entries (names unchanged).

    Fixed: mutable default argument ``skip_list=[]``; in-place key deletion
    replaced by a comprehension.
    """
    state_dict = load_state_dict(
        checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list or []
    )
    return {k: v for k, v in state_dict.items() if not k.startswith("visual.")}
def get_pretrained_tag(pretrained_model):
    """Classify a pretrained-model identifier into the loader family it
    belongs to: 'open_clip', 'clip' (OpenAI), 'eva_clip', or 'other'."""
    name = pretrained_model.lower()
    if "laion" in name or "open_clip" in name:
        return "open_clip"
    if "openai" in name:
        return "clip"
    if "eva" in name and "clip" in name:
        return "eva_clip"
    return "other"
def load_pretrained_checkpoint(
    model,
    visual_checkpoint_path,
    text_checkpoint_path,
    strict=True,
    visual_model=None,
    text_model=None,
    model_key="model|module|state_dict",
    skip_list=None,
):
    """Load separate pretrained checkpoints into the visual and text towers.

    The loader used for each tower is chosen from the pretrained-model name
    via get_pretrained_tag. Either checkpoint path may be empty, in which
    case that tower is skipped and its result is None.

    Fixes over the previous version:
    - The generic-text fallback branch loaded from ``visual_checkpoint_path``
      instead of ``text_checkpoint_path``.
    - ``skip_list`` was a mutable default argument (``[]``).

    Returns:
        (visual_incompatible_keys, text_incompatible_keys) from the two
        ``load_state_dict`` calls.
    """
    skip_list = skip_list or []
    visual_tag = get_pretrained_tag(visual_model)
    text_tag = get_pretrained_tag(text_model)

    logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
    visual_incompatible_keys, text_incompatible_keys = None, None
    if visual_checkpoint_path:
        if visual_tag == "eva_clip" or visual_tag == "open_clip":
            visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
        elif visual_tag == "clip":
            visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
        else:
            visual_state_dict = load_state_dict(
                visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list
            )
        # resize_clip_pos_embed for CLIP and open CLIP
        if "positional_embedding" in visual_state_dict:
            resize_visual_pos_embed(visual_state_dict, model)
        # specified to EVA model
        elif "pos_embed" in visual_state_dict:
            resize_eva_pos_embed(visual_state_dict, model)

        visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
        logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
        logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")

    if text_checkpoint_path:
        if text_tag == "eva_clip" or text_tag == "open_clip":
            text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
        elif text_tag == "clip":
            text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
        else:
            # BUGFIX: previously read from visual_checkpoint_path here.
            text_state_dict = load_state_dict(
                text_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list
            )

        text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)

        logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
        logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")

    return visual_incompatible_keys, text_incompatible_keys
def
create_model
(
model_name
:
str
,
pretrained
:
Optional
[
str
]
=
None
,
precision
:
str
=
"fp32"
,
device
:
Union
[
str
,
torch
.
device
]
=
"cpu"
,
jit
:
bool
=
False
,
force_quick_gelu
:
bool
=
False
,
force_custom_clip
:
bool
=
False
,
force_patch_dropout
:
Optional
[
float
]
=
None
,
pretrained_image
:
str
=
""
,
pretrained_text
:
str
=
""
,
pretrained_hf
:
bool
=
True
,
pretrained_visual_model
:
str
=
None
,
pretrained_text_model
:
str
=
None
,
cache_dir
:
Optional
[
str
]
=
None
,
skip_list
:
list
=
[],
):
model_name
=
model_name
.
replace
(
"/"
,
"-"
)
# for callers using old naming with / in ViT names
if
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
if
pretrained
and
pretrained
.
lower
()
==
"openai"
:
pass
else
:
model_cfg
=
get_model_config
(
model_name
)
if
model_cfg
is
not
None
:
logging
.
info
(
f
"Loaded
{
model_name
}
model config."
)
else
:
logging
.
error
(
f
"Model config for
{
model_name
}
not found; available models
{
list_models
()
}
."
)
raise
RuntimeError
(
f
"Model config for
{
model_name
}
not found."
)
if
"rope"
in
model_cfg
.
get
(
"vision_cfg"
,
{}):
if
model_cfg
[
"vision_cfg"
][
"rope"
]:
os
.
environ
[
"RoPE"
]
=
"1"
else
:
os
.
environ
[
"RoPE"
]
=
"0"
if
force_quick_gelu
:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg
[
"quick_gelu"
]
=
True
if
force_patch_dropout
is
not
None
:
# override the default patch dropout value
model_cfg
[
"vision_cfg"
][
"patch_dropout"
]
=
force_patch_dropout
cast_dtype
=
get_cast_dtype
(
precision
)
custom_clip
=
(
model_cfg
.
pop
(
"custom_text"
,
False
)
or
force_custom_clip
or
(
"hf_model_name"
in
model_cfg
[
"text_cfg"
])
)
if
custom_clip
:
if
"hf_model_name"
in
model_cfg
.
get
(
"text_cfg"
,
{}):
model_cfg
[
"text_cfg"
][
"hf_model_pretrained"
]
=
pretrained_hf
model
=
CustomCLIP
(
**
model_cfg
,
cast_dtype
=
cast_dtype
)
else
:
model
=
CLIP
(
**
model_cfg
,
cast_dtype
=
cast_dtype
)
pretrained_cfg
=
{}
if
pretrained
:
checkpoint_path
=
""
pretrained_cfg
=
get_pretrained_cfg
(
model_name
,
pretrained
)
if
pretrained_cfg
:
checkpoint_path
=
download_pretrained
(
pretrained_cfg
,
cache_dir
=
cache_dir
)
elif
os
.
path
.
exists
(
pretrained
):
checkpoint_path
=
pretrained
if
checkpoint_path
:
logging
.
info
(
f
"Loading pretrained
{
model_name
}
weights (
{
pretrained
}
)."
)
load_checkpoint
(
model
,
checkpoint_path
,
model_key
=
"model|module|state_dict"
,
strict
=
False
)
else
:
error_str
=
(
f
"Pretrained weights (
{
pretrained
}
) not found for model
{
model_name
}
."
f
"Available pretrained tags (
{
list_pretrained_tags_by_model
(
model_name
)
}
."
)
logging
.
warning
(
error_str
)
raise
RuntimeError
(
error_str
)
else
:
visual_checkpoint_path
=
""
text_checkpoint_path
=
""
if
pretrained_image
:
pretrained_visual_model
=
pretrained_visual_model
.
replace
(
"/"
,
"-"
)
# for callers using old naming with / in ViT names
pretrained_image_cfg
=
get_pretrained_cfg
(
pretrained_visual_model
,
pretrained_image
)
if
"timm_model_name"
in
model_cfg
.
get
(
"vision_cfg"
,
{}):
# pretrained weight loading for timm models set via vision_cfg
model_cfg
[
"vision_cfg"
][
"timm_model_pretrained"
]
=
True
elif
pretrained_image_cfg
:
visual_checkpoint_path
=
download_pretrained
(
pretrained_image_cfg
,
cache_dir
=
cache_dir
)
elif
os
.
path
.
exists
(
pretrained_image
):
visual_checkpoint_path
=
pretrained_image
else
:
logging
.
warning
(
f
"Pretrained weights (
{
visual_checkpoint_path
}
) not found for model
{
model_name
}
.visual."
)
raise
RuntimeError
(
f
"Pretrained weights (
{
visual_checkpoint_path
}
) not found for model
{
model_name
}
.visual."
)
if
pretrained_text
:
pretrained_text_model
=
pretrained_text_model
.
replace
(
"/"
,
"-"
)
# for callers using old naming with / in ViT names
pretrained_text_cfg
=
get_pretrained_cfg
(
pretrained_text_model
,
pretrained_text
)
if
pretrained_image_cfg
:
text_checkpoint_path
=
download_pretrained
(
pretrained_text_cfg
,
cache_dir
=
cache_dir
)
elif
os
.
path
.
exists
(
pretrained_text
):
text_checkpoint_path
=
pretrained_text
else
:
logging
.
warning
(
f
"Pretrained weights (
{
text_checkpoint_path
}
) not found for model
{
model_name
}
.text."
)
raise
RuntimeError
(
f
"Pretrained weights (
{
text_checkpoint_path
}
) not found for model
{
model_name
}
.text."
)
if
visual_checkpoint_path
:
logging
.
info
(
f
"Loading pretrained
{
model_name
}
.visual weights (
{
visual_checkpoint_path
}
)."
)
if
text_checkpoint_path
:
logging
.
info
(
f
"Loading pretrained
{
model_name
}
.text weights (
{
text_checkpoint_path
}
)."
)
if
visual_checkpoint_path
or
text_checkpoint_path
:
load_pretrained_checkpoint
(
model
,
visual_checkpoint_path
,
text_checkpoint_path
,
strict
=
False
,
visual_model
=
pretrained_visual_model
,
text_model
=
pretrained_text_model
,
model_key
=
"model|module|state_dict"
,
skip_list
=
skip_list
,
)
if
"fp16"
in
precision
or
"bf16"
in
precision
:
logging
.
info
(
f
"convert precision to
{
precision
}
"
)
model
=
model
.
to
(
torch
.
bfloat16
)
if
"bf16"
in
precision
else
model
.
to
(
torch
.
float16
)
model
.
to
(
device
=
device
)
# set image / mean metadata from pretrained_cfg if available, or use default
model
.
visual
.
image_mean
=
pretrained_cfg
.
get
(
"mean"
,
None
)
or
OPENAI_DATASET_MEAN
model
.
visual
.
image_std
=
pretrained_cfg
.
get
(
"std"
,
None
)
or
OPENAI_DATASET_STD
if
jit
:
model
=
torch
.
jit
.
script
(
model
)
return
model
def create_model_and_transforms(
    model_name: str,
    pretrained: Optional[str] = None,
    precision: str = "fp32",
    device: Union[str, torch.device] = "cpu",
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_clip: bool = False,
    force_patch_dropout: Optional[float] = None,
    pretrained_image: str = "",
    pretrained_text: str = "",
    pretrained_hf: bool = True,
    pretrained_visual_model: Optional[str] = None,
    pretrained_text_model: Optional[str] = None,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    cache_dir: Optional[str] = None,
    skip_list: Optional[list] = None,
):
    """Build a CLIP model plus matching train/val image preprocessing transforms.

    Thin convenience wrapper around :func:`create_model`: all model-construction
    arguments are forwarded unchanged, then the image normalization statistics
    are resolved and two ``image_transform`` pipelines are created.

    Args:
        model_name: Name of the model config to instantiate.
        pretrained: Pretrained tag or checkpoint path forwarded to ``create_model``.
        precision: Numeric precision string (e.g. ``"fp32"``, ``"fp16"``, ``"bf16"``).
        device: Device the model is moved to.
        jit: Whether ``create_model`` should TorchScript the model.
        force_quick_gelu: Force QuickGELU activation override.
        force_custom_clip: Force the ``CustomCLIP`` class.
        force_patch_dropout: Optional patch-dropout override for the vision tower.
        pretrained_image / pretrained_text: Tower-specific pretrained weight tags/paths.
        pretrained_hf: Whether HF text towers load pretrained weights.
        pretrained_visual_model / pretrained_text_model: Names used to look up
            tower-specific pretrained configs.
        image_mean / image_std: Explicit normalization overrides; when ``None``,
            the values attached to ``model.visual`` (if any) are used.
        cache_dir: Download cache directory.
        skip_list: Parameter names to skip when loading checkpoints. Defaults to
            an empty list. (Previously a mutable default argument ``[]``; a
            ``None`` sentinel avoids sharing one list object across calls.)

    Returns:
        Tuple of ``(model, preprocess_train, preprocess_val)``.
    """
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_clip=force_custom_clip,
        force_patch_dropout=force_patch_dropout,
        pretrained_image=pretrained_image,
        pretrained_text=pretrained_text,
        pretrained_hf=pretrained_hf,
        pretrained_visual_model=pretrained_visual_model,
        pretrained_text_model=pretrained_text_model,
        cache_dir=cache_dir,
        skip_list=[] if skip_list is None else skip_list,
    )

    # Explicit overrides win; otherwise fall back to the stats create_model
    # attached to model.visual (which themselves default to the OpenAI stats).
    image_mean = image_mean or getattr(model.visual, "image_mean", None)
    image_std = image_std or getattr(model.visual, "image_std", None)
    preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
    preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)

    return model, preprocess_train, preprocess_val
nunchaku/models/pulid/eva_clip/hf_configs.py
0 → 100644
View file @
37a27712
# HF architecture dict: per-``model_type`` translation tables mapping the
# canonical text-tower attribute names used by this package to the attribute
# names found on the corresponding HuggingFace config objects.


def _bert_style_entry():
    """Return a fresh arch entry shared by the BERT-family encoders.

    roberta, xlm-roberta and bert expose identical config attribute names,
    so they share this template; a new dict is built per call so the entries
    never alias each other.
    """
    return {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings",
        },
        "pooler": "mean_pooler",
    }


arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": _bert_style_entry(),
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": _bert_style_entry(),
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens",
        },
        "pooler": "mean_pooler",
    },
    "bert": _bert_style_entry(),
}
nunchaku/models/pulid/eva_clip/hf_model.py
0 → 100644
View file @
37a27712
"""huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import
re
import
torch
import
torch.nn
as
nn
from
torch
import
TensorType
# ``transformers`` is an optional dependency: when it is missing we bind the
# module name to ``None`` (checked later before any HF model is built) and
# provide a stub ``PretrainedConfig`` so annotations elsewhere still resolve.
try:
    import transformers
    from transformers import (
        AutoConfig,
        AutoModel,
        AutoModelForMaskedLM,
        AutoTokenizer,
        PretrainedConfig,
    )
except ImportError:
    transformers = None

    class PretrainedConfig:
        pass
from
.hf_configs
import
arch_dict
# utils
def
_camel2snake
(
s
):
return
re
.
sub
(
r
"(?<!^)(?=[A-Z])"
,
"_"
,
s
).
lower
()
# TODO: ?last - for gpt-like models
# Registry mapping pooler-type name -> pooler class; ``HFTextEncoder.__init__``
# instantiates from it (keyed either by the explicit ``pooler_type`` argument or
# by ``arch_dict[...]["pooler"]``). NOTE(review): empty as shown here — presumably
# populated elsewhere (e.g. via a registration decorator); verify that
# "mean_pooler" is registered before ``HFTextEncoder`` is constructed.
_POOLERS = {}
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter.

    Wraps a HuggingFace transformers model so it can serve as the text tower of
    a CLIP model: a backbone transformer, a pooler picked from ``_POOLERS``, and
    an optional projection to ``output_dim``.
    """

    def __init__(
        self,
        model_name_or_path: str,
        output_dim: int,
        tokenizer_name: str = None,
        config: PretrainedConfig = None,
        pooler_type: str = None,
        proj: str = None,
        pretrained: bool = True,
        masked_language_modeling: bool = False,
    ):
        """Build the text tower.

        Args:
            model_name_or_path: HF hub id or local path of the backbone model.
            output_dim: Target embedding width after projection.
            tokenizer_name: HF tokenizer id/path loaded via ``AutoTokenizer``.
            config: Optional pre-built ``PretrainedConfig``; when given, the
                backbone is built from the config (never from pretrained weights).
            pooler_type: Key into ``_POOLERS``; ``None`` selects the default
                pooler for the architecture from ``arch_dict``.
            proj: ``None`` / ``"linear"`` / ``"mlp"`` projection head choice.
            pretrained: Load pretrained weights (only when ``config`` is None).
            masked_language_modeling: Use the ``AutoModelForMaskedLM`` head
                instead of the plain ``AutoModel`` backbone.
        """
        super().__init__()

        self.output_dim = output_dim

        # TODO: find better way to get this information
        # Only the transformer's own pooling layer is needed for cls pooling.
        uses_transformer_pooler = pooler_type == "cls_pooler"

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            # Pick the (constructor, argument) pair: pretrained checkpoints are
            # loaded by name/path, fresh models are built from the config.
            if masked_language_modeling:
                create_func, model_args = (
                    (AutoModelForMaskedLM.from_pretrained, model_name_or_path)
                    if pretrained
                    else (AutoModelForMaskedLM.from_config, self.config)
                )
            else:
                create_func, model_args = (
                    (AutoModel.from_pretrained, model_name_or_path)
                    if pretrained
                    else (AutoModel.from_config, self.config)
                )
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                # Encoder-decoder models (e.g. T5 family): keep only the encoder.
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            # Explicit config supplied: build from config, ignoring ``pretrained``.
            self.config = config
            if masked_language_modeling:
                self.transformer = AutoModelForMaskedLM.from_config(config)
            else:
                self.transformer = AutoModel.from_config(config)
        if pooler_type is None:
            # get default arch pooler for this architecture from arch_dict
            self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
        else:
            self.pooler = _POOLERS[pooler_type]()

        # Backbone hidden width, read via the arch-specific config attribute name.
        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj is None):
            # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == "linear":
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == "mlp":
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )
        # NOTE(review): if proj is none of the handled values and
        # d_model != output_dim, self.proj is never assigned and forward()
        # would raise AttributeError — confirm callers only pass
        # None/"linear"/"mlp".

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
        """Apply BERT-style MLM corruption to ``input_ids`` (in place).

        Mutates ``input_ids`` (and ``targets`` when given) in place and also
        returns them. When ``masked_indices`` is not provided, positions are
        drawn Bernoulli from ``probability_matrix`` (required in that case);
        pad and cls tokens are never masked.
        """
        if masked_indices is None:
            masked_indices = torch.bernoulli(probability_matrix).bool()

        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
        masked_indices[input_ids == self.tokenizer.cls_token_id] = False

        if targets is not None:
            targets[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, we replace masked input tokens with random word
        # (0.5 of the remaining 20% of positions).
        indices_random = (
            torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        )
        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
        input_ids[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged

        if targets is not None:
            return input_ids, targets
        else:
            return input_ids

    def forward(self, x: TensorType) -> TensorType:
        """Encode token ids ``x`` into pooled, projected text embeddings.

        The attention mask is derived from ``self.config.pad_token_id``.
        NOTE(review): some configs have ``pad_token_id`` of None — confirm the
        architectures used here always define it.
        """
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)

        return self.proj(pooled_out)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # The ``enable`` flag is ignored: gradient checkpointing is always
        # turned on; there is no corresponding disable call here.
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        # Intentionally a no-op: HF backbones come pre-initialized.
        pass
Prev
1
2
3
4
5
6
7
8
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment