Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
57e50f8d
Unverified
Commit
57e50f8d
authored
May 01, 2025
by
Muyang Li
Committed by
GitHub
May 01, 2025
Browse files
style: upgrade the linter (#339)
* style: reformated codes * style: reformated codes
parent
b737368d
Changes
174
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
121 additions
and
114 deletions
+121
-114
nunchaku/csrc/sana.h
nunchaku/csrc/sana.h
+54
-58
nunchaku/csrc/utils.h
nunchaku/csrc/utils.h
+27
-27
nunchaku/lora/flux/__init__.py
nunchaku/lora/flux/__init__.py
+2
-0
nunchaku/lora/flux/nunchaku_converter.py
nunchaku/lora/flux/nunchaku_converter.py
+1
-1
nunchaku/lora/flux/packer.py
nunchaku/lora/flux/packer.py
+1
-1
nunchaku/models/__init__.py
nunchaku/models/__init__.py
+2
-0
nunchaku/models/pulid/eva_clip/__init__.py
nunchaku/models/pulid/eva_clip/__init__.py
+2
-0
nunchaku/models/pulid/eva_clip/eva_vit_model.py
nunchaku/models/pulid/eva_clip/eva_vit_model.py
+2
-2
nunchaku/models/pulid/eva_clip/factory.py
nunchaku/models/pulid/eva_clip/factory.py
+1
-1
nunchaku/models/pulid/eva_clip/hf_model.py
nunchaku/models/pulid/eva_clip/hf_model.py
+1
-1
nunchaku/models/pulid/eva_clip/model.py
nunchaku/models/pulid/eva_clip/model.py
+2
-2
nunchaku/models/pulid/eva_clip/model_configs/EVA02-CLIP-L-14-336.json
...els/pulid/eva_clip/model_configs/EVA02-CLIP-L-14-336.json
+1
-1
nunchaku/models/pulid/eva_clip/modified_resnet.py
nunchaku/models/pulid/eva_clip/modified_resnet.py
+0
-2
nunchaku/models/pulid/eva_clip/transformer.py
nunchaku/models/pulid/eva_clip/transformer.py
+2
-2
nunchaku/models/pulid/eva_clip/utils.py
nunchaku/models/pulid/eva_clip/utils.py
+0
-1
nunchaku/models/pulid/pulid_forward.py
nunchaku/models/pulid/pulid_forward.py
+3
-2
nunchaku/models/text_encoders/linear.py
nunchaku/models/text_encoders/linear.py
+1
-1
nunchaku/models/transformers/__init__.py
nunchaku/models/transformers/__init__.py
+2
-0
nunchaku/models/transformers/transformer_flux.py
nunchaku/models/transformers/transformer_flux.py
+12
-10
nunchaku/models/transformers/transformer_sana.py
nunchaku/models/transformers/transformer_sana.py
+5
-2
No files found.
nunchaku/csrc/sana.h
View file @
57e50f8d
...
...
@@ -11,13 +11,13 @@ public:
void
init
(
pybind11
::
dict
config
,
std
::
vector
<
int
>
pag_layers
,
bool
use_fp4
,
bool
bf16
,
int8_t
deviceId
)
{
spdlog
::
info
(
"Initializing QuantizedSanaModel on device {}"
,
deviceId
);
SanaConfig
cfg
{
.
num_layers
=
config
[
"num_layers"
].
cast
<
int
>
(),
.
num_attention_heads
=
config
[
"num_attention_heads"
].
cast
<
int
>
(),
.
attention_head_dim
=
config
[
"attention_head_dim"
].
cast
<
int
>
(),
.
num_layers
=
config
[
"num_layers"
].
cast
<
int
>
(),
.
num_attention_heads
=
config
[
"num_attention_heads"
].
cast
<
int
>
(),
.
attention_head_dim
=
config
[
"attention_head_dim"
].
cast
<
int
>
(),
.
num_cross_attention_heads
=
config
[
"num_cross_attention_heads"
].
cast
<
int
>
(),
.
expand_ratio
=
config
[
"mlp_ratio"
].
cast
<
double
>
(),
.
pag_layers
=
pag_layers
,
.
use_fp4
=
use_fp4
,
.
expand_ratio
=
config
[
"mlp_ratio"
].
cast
<
double
>
(),
.
pag_layers
=
pag_layers
,
.
use_fp4
=
use_fp4
,
};
ModuleWrapper
::
init
(
deviceId
);
...
...
@@ -25,39 +25,37 @@ public:
net
=
std
::
make_unique
<
SanaModel
>
(
cfg
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
}
torch
::
Tensor
forward
(
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
timestep
,
torch
::
Tensor
cu_seqlens_img
,
torch
::
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
,
bool
skip_first_layer
=
false
)
{
torch
::
Tensor
forward
(
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
timestep
,
torch
::
Tensor
cu_seqlens_img
,
torch
::
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
,
bool
skip_first_layer
=
false
)
{
checkModel
();
CUDADeviceContext
ctx
(
deviceId
);
spdlog
::
debug
(
"QuantizedSanaModel forward"
);
hidden_states
=
hidden_states
.
contiguous
();
hidden_states
=
hidden_states
.
contiguous
();
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
();
timestep
=
timestep
.
contiguous
();
cu_seqlens_img
=
cu_seqlens_img
.
contiguous
();
cu_seqlens_txt
=
cu_seqlens_txt
.
contiguous
();
timestep
=
timestep
.
contiguous
();
cu_seqlens_img
=
cu_seqlens_img
.
contiguous
();
cu_seqlens_txt
=
cu_seqlens_txt
.
contiguous
();
Tensor
result
=
net
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
encoder_hidden_sta
te
s
),
from_torch
(
timestep
),
from_torch
(
cu_seqlens_
img
),
from_torch
(
cu_seqlens_txt
)
,
H
,
W
,
pag
,
cfg
,
skip_first_layer
);
Tensor
result
=
net
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
encoder_
hidden_states
),
from_torch
(
times
te
p
),
from_torch
(
cu_seqlens_img
),
from_torch
(
cu_seqlens_
txt
),
H
,
W
,
pag
,
cfg
,
skip_first_layer
);
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
...
...
@@ -65,42 +63,40 @@ public:
return
output
;
}
torch
::
Tensor
forward_layer
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
timestep
,
torch
::
Tensor
cu_seqlens_img
,
torch
::
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
)
{
torch
::
Tensor
forward_layer
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
timestep
,
torch
::
Tensor
cu_seqlens_img
,
torch
::
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
)
{
checkModel
();
CUDADeviceContext
ctx
(
deviceId
);
spdlog
::
debug
(
"QuantizedSanaModel forward_layer {}"
,
idx
);
hidden_states
=
hidden_states
.
contiguous
();
hidden_states
=
hidden_states
.
contiguous
();
encoder_hidden_states
=
encoder_hidden_states
.
contiguous
();
timestep
=
timestep
.
contiguous
();
cu_seqlens_img
=
cu_seqlens_img
.
contiguous
();
cu_seqlens_txt
=
cu_seqlens_txt
.
contiguous
();
timestep
=
timestep
.
contiguous
();
cu_seqlens_img
=
cu_seqlens_img
.
contiguous
();
cu_seqlens_txt
=
cu_seqlens_txt
.
contiguous
();
Tensor
result
=
net
->
transformer_blocks
.
at
(
idx
)
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
encoder_hidden_states
),
from_torch
(
timestep
),
from_torch
(
cu_seqlens_
img
),
from_torch
(
cu_seqlens_txt
)
,
H
,
W
,
pag
,
cfg
);
Tensor
result
=
net
->
transformer_blocks
.
at
(
idx
)
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
encoder_
hidden_states
),
from_torch
(
timestep
),
from_torch
(
cu_seqlens_img
),
from_torch
(
cu_seqlens_
txt
),
H
,
W
,
pag
,
cfg
);
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
return
output
;
}
};
\ No newline at end of file
};
nunchaku/csrc/utils.h
View file @
57e50f8d
...
...
@@ -6,34 +6,34 @@
namespace
nunchaku
::
utils
{
void
set_cuda_stack_limit
(
int64_t
newval
)
{
size_t
val
=
0
;
checkCUDA
(
cudaDeviceSetLimit
(
cudaLimitStackSize
,
(
size_t
)
newval
));
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
spdlog
::
debug
(
"Stack={}"
,
val
);
}
void
set_cuda_stack_limit
(
int64_t
newval
)
{
size_t
val
=
0
;
checkCUDA
(
cudaDeviceSetLimit
(
cudaLimitStackSize
,
(
size_t
)
newval
));
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
spdlog
::
debug
(
"Stack={}"
,
val
);
}
void
disable_memory_auto_release
()
{
int
device
;
checkCUDA
(
cudaGetDevice
(
&
device
));
cudaMemPool_t
mempool
;
checkCUDA
(
cudaDeviceGetDefaultMemPool
(
&
mempool
,
device
));
uint64_t
threshold
=
UINT64_MAX
;
checkCUDA
(
cudaMemPoolSetAttribute
(
mempool
,
cudaMemPoolAttrReleaseThreshold
,
&
threshold
));
}
void
disable_memory_auto_release
()
{
int
device
;
checkCUDA
(
cudaGetDevice
(
&
device
));
cudaMemPool_t
mempool
;
checkCUDA
(
cudaDeviceGetDefaultMemPool
(
&
mempool
,
device
));
uint64_t
threshold
=
UINT64_MAX
;
checkCUDA
(
cudaMemPoolSetAttribute
(
mempool
,
cudaMemPoolAttrReleaseThreshold
,
&
threshold
));
}
void
trim_memory
()
{
int
device
;
checkCUDA
(
cudaGetDevice
(
&
device
));
cudaMemPool_t
mempool
;
checkCUDA
(
cudaDeviceGetDefaultMemPool
(
&
mempool
,
device
));
size_t
bytesToKeep
=
0
;
checkCUDA
(
cudaMemPoolTrimTo
(
mempool
,
bytesToKeep
));
}
void
trim_memory
()
{
int
device
;
checkCUDA
(
cudaGetDevice
(
&
device
));
cudaMemPool_t
mempool
;
checkCUDA
(
cudaDeviceGetDefaultMemPool
(
&
mempool
,
device
));
size_t
bytesToKeep
=
0
;
checkCUDA
(
cudaMemPoolTrimTo
(
mempool
,
bytesToKeep
));
}
void
set_faster_i2f_mode
(
std
::
string
mode
)
{
spdlog
::
info
(
"Set fasteri2f mode to {}"
,
mode
);
kernels
::
set_faster_i2f_mode
(
mode
);
}
void
set_faster_i2f_mode
(
std
::
string
mode
)
{
spdlog
::
info
(
"Set fasteri2f mode to {}"
,
mode
);
kernels
::
set_faster_i2f_mode
(
mode
);
}
};
\ No newline at end of file
};
// namespace nunchaku::utils
nunchaku/lora/flux/__init__.py
View file @
57e50f8d
from
.diffusers_converter
import
to_diffusers
from
.nunchaku_converter
import
convert_to_nunchaku_flux_lowrank_dict
,
to_nunchaku
from
.utils
import
is_nunchaku_format
__all__
=
[
"to_diffusers"
,
"to_nunchaku"
,
"convert_to_nunchaku_flux_lowrank_dict"
,
"is_nunchaku_format"
]
nunchaku/lora/flux/nunchaku_converter.py
View file @
57e50f8d
...
...
@@ -7,10 +7,10 @@ import torch
from
safetensors.torch
import
save_file
from
tqdm
import
tqdm
from
...utils
import
filter_state_dict
,
load_state_dict_in_safetensors
from
.diffusers_converter
import
to_diffusers
from
.packer
import
NunchakuWeightPacker
from
.utils
import
is_nunchaku_format
,
pad
from
...utils
import
filter_state_dict
,
load_state_dict_in_safetensors
logger
=
logging
.
getLogger
(
__name__
)
...
...
nunchaku/lora/flux/packer.py
View file @
57e50f8d
# Copy the packer from https://github.com/mit-han-lab/deepcompressor/
import
torch
from
.utils
import
pad
from
...utils
import
ceil_divide
from
.utils
import
pad
class
MmaWeightPackerBase
:
...
...
nunchaku/models/__init__.py
View file @
57e50f8d
from
.text_encoders.t5_encoder
import
NunchakuT5EncoderModel
from
.transformers
import
NunchakuFluxTransformer2dModel
,
NunchakuSanaTransformer2DModel
__all__
=
[
"NunchakuFluxTransformer2dModel"
,
"NunchakuSanaTransformer2DModel"
,
"NunchakuT5EncoderModel"
]
nunchaku/models/pulid/eva_clip/__init__.py
View file @
57e50f8d
from
.constants
import
OPENAI_DATASET_MEAN
,
OPENAI_DATASET_STD
from
.factory
import
create_model_and_transforms
__all__
=
[
"create_model_and_transforms"
,
"OPENAI_DATASET_MEAN"
,
"OPENAI_DATASET_STD"
]
nunchaku/models/pulid/eva_clip/eva_vit_model.py
View file @
57e50f8d
...
...
@@ -14,8 +14,8 @@ try:
except
ImportError
:
from
timm.layers
import
drop_path
,
to_2tuple
,
trunc_normal_
from
.transformer
import
PatchDropout
from
.rope
import
VisionRotaryEmbeddingFast
from
.transformer
import
PatchDropout
if
os
.
getenv
(
"ENV_TYPE"
)
==
"deepspeed"
:
try
:
...
...
@@ -26,7 +26,7 @@ else:
from
torch.utils.checkpoint
import
checkpoint
try
:
import
xformers
import
xformers
# noqa: F401
import
xformers.ops
as
xops
XFORMERS_IS_AVAILBLE
=
True
...
...
nunchaku/models/pulid/eva_clip/factory.py
View file @
57e50f8d
...
...
@@ -9,7 +9,7 @@ from typing import Optional, Tuple, Union
import
torch
from
.constants
import
OPENAI_DATASET_MEAN
,
OPENAI_DATASET_STD
from
.model
import
CLIP
,
convert_to_custom_text_state_dict
,
CustomCLIP
,
get_cast_dtype
from
.model
import
CLIP
,
CustomCLIP
,
convert_to_custom_text_state_dict
,
get_cast_dtype
from
.pretrained
import
download_pretrained
,
get_pretrained_cfg
,
list_pretrained_tags_by_model
from
.transform
import
image_transform
from
.utils
import
resize_clip_pos_embed
,
resize_eva_pos_embed
,
resize_evaclip_pos_embed
,
resize_visual_pos_embed
...
...
nunchaku/models/pulid/eva_clip/hf_model.py
View file @
57e50f8d
...
...
@@ -11,7 +11,7 @@ from torch import TensorType
try
:
import
transformers
from
transformers
import
AutoModel
,
AutoModelForMaskedLM
,
AutoTokenizer
,
AutoConfig
,
PretrainedConfig
from
transformers
import
AutoConfig
,
AutoModel
,
AutoModelForMaskedLM
,
AutoTokenizer
,
PretrainedConfig
except
ImportError
:
transformers
=
None
...
...
nunchaku/models/pulid/eva_clip/model.py
View file @
57e50f8d
...
...
@@ -16,9 +16,9 @@ try:
from
.hf_model
import
HFTextEncoder
except
ImportError
:
HFTextEncoder
=
None
from
.modified_resnet
import
ModifiedResNet
from
.eva_vit_model
import
EVAVisionTransformer
from
.transformer
import
LayerNorm
,
QuickGELU
,
VisionTransformer
,
TextTransformer
from
.modified_resnet
import
ModifiedResNet
from
.transformer
import
LayerNorm
,
QuickGELU
,
TextTransformer
,
VisionTransformer
try
:
from
apex.normalization
import
FusedLayerNorm
...
...
nunchaku/models/pulid/eva_clip/model_configs/EVA02-CLIP-L-14-336.json
View file @
57e50f8d
...
...
@@ -26,4 +26,4 @@
"xattn"
:
false
,
"fusedLN"
:
true
}
}
\ No newline at end of file
}
nunchaku/models/pulid/eva_clip/modified_resnet.py
View file @
57e50f8d
...
...
@@ -4,8 +4,6 @@ import torch
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
.utils
import
freeze_batch_norm_2d
class
Bottleneck
(
nn
.
Module
):
expansion
=
4
...
...
nunchaku/models/pulid/eva_clip/transformer.py
View file @
57e50f8d
...
...
@@ -2,7 +2,7 @@ import logging
import
math
import
os
from
collections
import
OrderedDict
from
typing
import
Callable
,
Optional
,
Sequence
from
typing
import
Callable
,
Optional
import
torch
from
torch
import
nn
...
...
@@ -11,7 +11,7 @@ from torch.nn import functional as F
try
:
from
timm.models.layers
import
trunc_normal_
except
ImportError
:
from
timm.layers
import
trunc_normal_
from
timm.layers
import
trunc_normal_
# noqa: F401
from
.utils
import
to_2tuple
...
...
nunchaku/models/pulid/eva_clip/utils.py
View file @
57e50f8d
...
...
@@ -3,7 +3,6 @@ import logging
import
math
from
itertools
import
repeat
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
as
nn
...
...
nunchaku/models/pulid/pulid_forward.py
View file @
57e50f8d
# Adapted from https://github.com/ToTheBeginning/PuLID
import
torch
import
logging
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
torch
from
diffusers.models.modeling_outputs
import
Transformer2DModelOutput
import
logging
logger
=
logging
.
getLogger
(
__name__
)
...
...
nunchaku/models/text_encoders/linear.py
View file @
57e50f8d
...
...
@@ -4,8 +4,8 @@
import
torch
import
torch.nn
as
nn
from
.tinychat_utils
import
ceil_num_groups
,
convert_to_tinychat_w4x16y16_linear_weight
from
..._C.ops
import
gemm_awq
,
gemv_awq
from
.tinychat_utils
import
ceil_num_groups
,
convert_to_tinychat_w4x16y16_linear_weight
__all__
=
[
"W4Linear"
]
...
...
nunchaku/models/transformers/__init__.py
View file @
57e50f8d
from
.transformer_flux
import
NunchakuFluxTransformer2dModel
from
.transformer_sana
import
NunchakuSanaTransformer2DModel
__all__
=
[
"NunchakuFluxTransformer2dModel"
,
"NunchakuSanaTransformer2DModel"
]
nunchaku/models/transformers/transformer_flux.py
View file @
57e50f8d
...
...
@@ -12,11 +12,12 @@ from packaging.version import Version
from
safetensors.torch
import
load_file
from
torch
import
nn
from
.
utils
import
NunchakuModelLoaderMixin
,
pad_tensor
from
..._C
import
QuantizedFluxModel
,
utils
as
cutils
from
.
.._C
import
QuantizedFluxModel
from
..._C
import
utils
as
cutils
from
...lora.flux.nunchaku_converter
import
fuse_vectors
,
to_nunchaku
from
...lora.flux.utils
import
is_nunchaku_format
from
...utils
import
get_precision
,
load_state_dict_in_safetensors
from
.utils
import
NunchakuModelLoaderMixin
,
pad_tensor
SVD_RANK
=
32
...
...
@@ -77,7 +78,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
self
.
id_embeddings
=
id_embeddings
self
.
id_weight
=
id_weight
self
.
pulid_ca_idx
=
0
if
self
.
id_embeddings
is
not
None
:
if
self
.
id_embeddings
is
not
None
:
self
.
set_residual_callback
()
original_dtype
=
hidden_states
.
dtype
...
...
@@ -122,13 +123,12 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_single
,
controlnet_block_samples
,
controlnet_single_block_samples
,
skip_first_layer
skip_first_layer
,
)
if
self
.
id_embeddings
is
not
None
:
if
self
.
id_embeddings
is
not
None
:
self
.
reset_residual_callback
()
hidden_states
=
hidden_states
.
to
(
original_dtype
).
to
(
original_device
)
encoder_hidden_states
=
hidden_states
[:,
:
txt_tokens
,
...]
...
...
@@ -191,20 +191,25 @@ class NunchakuFluxTransformerBlocks(nn.Module):
encoder_hidden_states
=
encoder_hidden_states
.
to
(
original_dtype
).
to
(
original_device
)
return
encoder_hidden_states
,
hidden_states
def
set_residual_callback
(
self
):
id_embeddings
=
self
.
id_embeddings
pulid_ca
=
self
.
pulid_ca
pulid_ca_idx
=
[
self
.
pulid_ca_idx
]
id_weight
=
self
.
id_weight
def
callback
(
hidden_states
):
ip
=
id_weight
*
pulid_ca
[
pulid_ca_idx
[
0
]](
id_embeddings
,
hidden_states
.
to
(
"cuda"
))
pulid_ca_idx
[
0
]
+=
1
return
ip
self
.
callback_holder
=
callback
self
.
m
.
set_residual_callback
(
callback
)
def
reset_residual_callback
(
self
):
self
.
callback_holder
=
None
self
.
m
.
set_residual_callback
(
None
)
def
__del__
(
self
):
self
.
m
.
reset
()
...
...
@@ -477,10 +482,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
if
len
(
self
.
_unquantized_part_loras
)
>
0
or
len
(
unquantized_part_loras
)
>
0
:
self
.
_unquantized_part_loras
=
unquantized_part_loras
self
.
_unquantized_part_sd
=
{
k
:
v
for
k
,
v
in
self
.
_unquantized_part_sd
.
items
()
if
"pulid_ca"
not
in
k
}
self
.
_unquantized_part_sd
=
{
k
:
v
for
k
,
v
in
self
.
_unquantized_part_sd
.
items
()
if
"pulid_ca"
not
in
k
}
self
.
_update_unquantized_part_lora_params
(
1
)
quantized_part_vectors
=
{}
...
...
nunchaku/models/transformers/transformer_sana.py
View file @
57e50f8d
...
...
@@ -8,9 +8,10 @@ from safetensors.torch import load_file
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
.utils
import
NunchakuModelLoaderMixin
from
..._C
import
QuantizedSanaModel
from
..._C
import
utils
as
cutils
from
...utils
import
get_precision
from
.
.._C
import
QuantizedSanaModel
,
utils
as
cutils
from
.
utils
import
NunchakuModelLoaderMixin
SVD_RANK
=
32
...
...
@@ -130,9 +131,11 @@ class NunchakuSanaTransformerBlocks(nn.Module):
.
to
(
original_dtype
)
.
to
(
original_device
)
)
def
__del__
(
self
):
self
.
m
.
reset
()
class
NunchakuSanaTransformer2DModel
(
SanaTransformer2DModel
,
NunchakuModelLoaderMixin
):
@
classmethod
@
utils
.
validate_hf_hub_args
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment