Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
57e50f8d
Unverified
Commit
57e50f8d
authored
May 01, 2025
by
Muyang Li
Committed by
GitHub
May 01, 2025
Browse files
style: upgrade the linter (#339)
* style: reformated codes * style: reformated codes
parent
b737368d
Changes
173
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
310 additions
and
393 deletions
+310
-393
app/flux.1/t2i/utils.py
app/flux.1/t2i/utils.py
+1
-1
app/sana/t2i/README.md
app/sana/t2i/README.md
+1
-1
app/sana/t2i/assets/common.css
app/sana/t2i/assets/common.css
+1
-1
app/sana/t2i/assets/description.html
app/sana/t2i/assets/description.html
+2
-2
app/sana/t2i/generate.py
app/sana/t2i/generate.py
+0
-1
app/sana/t2i/latency.py
app/sana/t2i/latency.py
+0
-1
app/sana/t2i/run_gradio.py
app/sana/t2i/run_gradio.py
+6
-6
assets/nunchaku.svg
assets/nunchaku.svg
+1
-1
assets/svdquant.svg
assets/svdquant.svg
+1
-1
docs/contribution_guide.md
docs/contribution_guide.md
+1
-1
docs/setup_windows.md
docs/setup_windows.md
+32
-32
examples/flux.1-dev-controlnet-union-pro.py
examples/flux.1-dev-controlnet-union-pro.py
+0
-6
examples/flux.1-dev-double_cache.py
examples/flux.1-dev-double_cache.py
+3
-10
nunchaku/__init__.py
nunchaku/__init__.py
+2
-0
nunchaku/csrc/flux.h
nunchaku/csrc/flux.h
+64
-74
nunchaku/csrc/gemm.h
nunchaku/csrc/gemm.h
+13
-12
nunchaku/csrc/gemm88.h
nunchaku/csrc/gemm88.h
+5
-4
nunchaku/csrc/module.h
nunchaku/csrc/module.h
+5
-5
nunchaku/csrc/ops.h
nunchaku/csrc/ops.h
+119
-162
nunchaku/csrc/pybind.cpp
nunchaku/csrc/pybind.cpp
+53
-72
No files found.
app/flux.1/t2i/utils.py
View file @
57e50f8d
import
torch
import
torch
from
diffusers
import
FluxPipeline
from
diffusers
import
FluxPipeline
from
peft.tuners
import
lora
from
peft.tuners
import
lora
from
vars
import
LORA_PATHS
,
SVDQ_LORA_PATHS
from
nunchaku
import
NunchakuFluxTransformer2dModel
from
nunchaku
import
NunchakuFluxTransformer2dModel
from
vars
import
LORA_PATHS
,
SVDQ_LORA_PATHS
def
hash_str_to_int
(
s
:
str
)
->
int
:
def
hash_str_to_int
(
s
:
str
)
->
int
:
...
...
app/sana/t2i/README.md
View file @
57e50f8d
app/sana/t2i/assets/common.css
View file @
57e50f8d
app/sana/t2i/assets/description.html
View file @
57e50f8d
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div>
<div>
<h1>
<h1>
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
logo
.svg"
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
svdquant
.svg"
alt=
"logo"
alt=
"logo"
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
<a
href=
'https://nvlabs.github.io/Sana/'
target=
"_blank"
>
SANA-1.6B
</a>
Demo
<a
href=
'https://nvlabs.github.io/Sana/'
target=
"_blank"
>
SANA-1.6B
</a>
Demo
...
...
app/sana/t2i/generate.py
View file @
57e50f8d
...
@@ -2,7 +2,6 @@ import argparse
...
@@ -2,7 +2,6 @@ import argparse
import
os
import
os
import
torch
import
torch
from
utils
import
get_pipeline
from
utils
import
get_pipeline
...
...
app/sana/t2i/latency.py
View file @
57e50f8d
...
@@ -4,7 +4,6 @@ import time
...
@@ -4,7 +4,6 @@ import time
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
tqdm
import
trange
from
tqdm
import
trange
from
utils
import
get_pipeline
from
utils
import
get_pipeline
...
...
app/sana/t2i/run_gradio.py
View file @
57e50f8d
...
@@ -8,13 +8,13 @@ from datetime import datetime
...
@@ -8,13 +8,13 @@ from datetime import datetime
import
GPUtil
import
GPUtil
import
spaces
import
spaces
import
torch
import
torch
from
nunchaku.models.safety_checker
import
SafetyChecker
from
utils
import
get_pipeline
from
utils
import
get_pipeline
from
vars
import
EXAMPLES
,
MAX_SEED
from
vars
import
EXAMPLES
,
MAX_SEED
from
nunchaku.models.safety_checker
import
SafetyChecker
# import gradio last to avoid conflicts with other imports
# import gradio last to avoid conflicts with other imports
import
gradio
as
gr
import
gradio
as
gr
# noqa: isort: skip
def
get_args
()
->
argparse
.
Namespace
:
def
get_args
()
->
argparse
.
Namespace
:
...
@@ -73,7 +73,7 @@ def generate(
...
@@ -73,7 +73,7 @@ def generate(
prompt
=
"A peaceful world."
prompt
=
"A peaceful world."
images
,
latency_strs
=
[],
[]
images
,
latency_strs
=
[],
[]
for
i
,
pipeline
in
enumerate
(
pipelines
):
for
i
,
pipeline
in
enumerate
(
pipelines
):
progress
=
gr
.
Progress
(
track_tqdm
=
True
)
gr
.
Progress
(
track_tqdm
=
True
)
start_time
=
time
.
time
()
start_time
=
time
.
time
()
image
=
pipeline
(
image
=
pipeline
(
prompt
=
prompt
,
prompt
=
prompt
,
...
@@ -124,11 +124,11 @@ if len(gpus) > 0:
...
@@ -124,11 +124,11 @@ if len(gpus) > 0:
device_info
=
f
"Running on
{
gpu
.
name
}
with
{
memory
:.
0
f
}
GiB memory."
device_info
=
f
"Running on
{
gpu
.
name
}
with
{
memory
:.
0
f
}
GiB memory."
else
:
else
:
device_info
=
"Running on CPU 🥶 This demo does not work on CPU."
device_info
=
"Running on CPU 🥶 This demo does not work on CPU."
notice
=
f
'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
notice
=
'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
with
gr
.
Blocks
(
with
gr
.
Blocks
(
css_paths
=
[
f
"assets/frame
{
len
(
args
.
precisions
)
}
.css"
,
"assets/common.css"
],
css_paths
=
[
f
"assets/frame
{
len
(
args
.
precisions
)
}
.css"
,
"assets/common.css"
],
title
=
f
"SVDQuant SANA-1600M Demo"
,
title
=
"SVDQuant SANA-1600M Demo"
,
)
as
demo
:
)
as
demo
:
def
get_header_str
():
def
get_header_str
():
...
...
assets/nunchaku.svg
View file @
57e50f8d
assets/svdquant.svg
View file @
57e50f8d
docs/contribution_guide.md
View file @
57e50f8d
docs/setup_windows.md
View file @
57e50f8d
examples/flux.1-dev-controlnet-union-pro.py
View file @
57e50f8d
...
@@ -4,7 +4,6 @@ from diffusers.models import FluxMultiControlNetModel
...
@@ -4,7 +4,6 @@ from diffusers.models import FluxMultiControlNetModel
from
diffusers.utils
import
load_image
from
diffusers.utils
import
load_image
from
nunchaku
import
NunchakuFluxTransformer2dModel
from
nunchaku
import
NunchakuFluxTransformer2dModel
from
nunchaku.caching.diffusers_adapters.flux
import
apply_cache_on_pipe
from
nunchaku.utils
import
get_gpu_memory
,
get_precision
from
nunchaku.utils
import
get_gpu_memory
,
get_precision
base_model
=
"black-forest-labs/FLUX.1-dev"
base_model
=
"black-forest-labs/FLUX.1-dev"
...
@@ -29,11 +28,6 @@ if need_offload:
...
@@ -29,11 +28,6 @@ if need_offload:
else
:
else
:
pipeline
=
pipeline
.
to
(
"cuda"
)
pipeline
=
pipeline
.
to
(
"cuda"
)
# apply_cache_on_pipe(
# pipeline, residual_diff_threshold=0.1
# ) # Uncomment this line to enable first-block cache to speedup generation
prompt
=
"A anime style girl with messy beach waves."
prompt
=
"A anime style girl with messy beach waves."
control_image_depth
=
load_image
(
control_image_depth
=
load_image
(
"https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
"https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
...
...
examples/flux.1-dev-double_cache.py
View file @
57e50f8d
...
@@ -7,14 +7,10 @@ from nunchaku.utils import get_precision
...
@@ -7,14 +7,10 @@ from nunchaku.utils import get_precision
precision
=
get_precision
()
precision
=
get_precision
()
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
f
"mit-han-lab/svdq-
{
precision
}
-flux.1-dev"
)
f
"mit-han-lab/svdq-
{
precision
}
-flux.1-dev"
)
pipeline
=
FluxPipeline
.
from_pretrained
(
pipeline
=
FluxPipeline
.
from_pretrained
(
"black-forest-labs/FLUX.1-dev"
,
"black-forest-labs/FLUX.1-dev"
,
transformer
=
transformer
,
torch_dtype
=
torch
.
bfloat16
transformer
=
transformer
,
torch_dtype
=
torch
.
bfloat16
).
to
(
"cuda"
)
).
to
(
"cuda"
)
apply_cache_on_pipe
(
apply_cache_on_pipe
(
...
@@ -24,9 +20,6 @@ apply_cache_on_pipe(
...
@@ -24,9 +20,6 @@ apply_cache_on_pipe(
residual_diff_threshold_single
=
0.12
,
residual_diff_threshold_single
=
0.12
,
)
)
image
=
pipeline
(
image
=
pipeline
([
"A cat holding a sign that says hello world"
],
num_inference_steps
=
50
).
images
[
0
]
[
"A cat holding a sign that says hello world"
],
num_inference_steps
=
50
).
images
[
0
]
image
.
save
(
f
"flux.1-dev-cache-
{
precision
}
.png"
)
image
.
save
(
f
"flux.1-dev-cache-
{
precision
}
.png"
)
nunchaku/__init__.py
View file @
57e50f8d
from
.models
import
NunchakuFluxTransformer2dModel
,
NunchakuSanaTransformer2DModel
,
NunchakuT5EncoderModel
from
.models
import
NunchakuFluxTransformer2dModel
,
NunchakuSanaTransformer2DModel
,
NunchakuT5EncoderModel
__all__
=
[
"NunchakuFluxTransformer2dModel"
,
"NunchakuSanaTransformer2DModel"
,
"NunchakuT5EncoderModel"
]
nunchaku/csrc/flux.h
View file @
57e50f8d
...
@@ -20,7 +20,8 @@ public:
...
@@ -20,7 +20,8 @@ public:
ModuleWrapper
::
init
(
deviceId
);
ModuleWrapper
::
init
(
deviceId
);
CUDADeviceContext
ctx
(
this
->
deviceId
);
CUDADeviceContext
ctx
(
this
->
deviceId
);
net
=
std
::
make_unique
<
FluxModel
>
(
use_fp4
,
offload
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
net
=
std
::
make_unique
<
FluxModel
>
(
use_fp4
,
offload
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
}
}
bool
isBF16
()
{
bool
isBF16
()
{
...
@@ -32,7 +33,7 @@ public:
...
@@ -32,7 +33,7 @@ public:
pybind11
::
gil_scoped_acquire
gil
;
pybind11
::
gil_scoped_acquire
gil
;
if
(
!
callback
||
callback
.
is_none
())
{
if
(
!
callback
||
callback
.
is_none
())
{
residual_callback
=
pybind11
::
function
();
residual_callback
=
pybind11
::
function
();
if
(
net
){
if
(
net
)
{
net
->
set_residual_callback
(
nullptr
);
net
->
set_residual_callback
(
nullptr
);
}
}
return
;
return
;
...
@@ -52,8 +53,7 @@ public:
...
@@ -52,8 +53,7 @@ public:
}
}
}
}
torch
::
Tensor
forward
(
torch
::
Tensor
forward
(
torch
::
Tensor
hidden_states
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_img
,
...
@@ -61,8 +61,7 @@ public:
...
@@ -61,8 +61,7 @@ public:
torch
::
Tensor
rotary_emb_single
,
torch
::
Tensor
rotary_emb_single
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
,
bool
skip_first_layer
=
false
)
bool
skip_first_layer
=
false
)
{
{
checkModel
();
checkModel
();
CUDADeviceContext
ctx
(
deviceId
);
CUDADeviceContext
ctx
(
deviceId
);
...
@@ -83,9 +82,10 @@ public:
...
@@ -83,9 +82,10 @@ public:
from_torch
(
rotary_emb_context
),
from_torch
(
rotary_emb_context
),
from_torch
(
rotary_emb_single
),
from_torch
(
rotary_emb_single
),
controlnet_block_samples
.
has_value
()
?
from_torch
(
controlnet_block_samples
.
value
().
contiguous
())
:
Tensor
{},
controlnet_block_samples
.
has_value
()
?
from_torch
(
controlnet_block_samples
.
value
().
contiguous
())
:
Tensor
{},
controlnet_single_block_samples
.
has_value
()
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
:
Tensor
{},
controlnet_single_block_samples
.
has_value
()
skip_first_layer
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
);
:
Tensor
{},
skip_first_layer
);
torch
::
Tensor
output
=
to_torch
(
result
);
torch
::
Tensor
output
=
to_torch
(
result
);
Tensor
::
synchronizeDevice
();
Tensor
::
synchronizeDevice
();
...
@@ -93,16 +93,15 @@ public:
...
@@ -93,16 +93,15 @@ public:
return
output
;
return
output
;
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
forward_layer
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
int64_t
idx
,
forward_layer
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
encoder_hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_img
,
torch
::
Tensor
rotary_emb_context
,
torch
::
Tensor
rotary_emb_context
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_block_samples
=
std
::
nullopt
,
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
)
std
::
optional
<
torch
::
Tensor
>
controlnet_single_block_samples
=
std
::
nullopt
)
{
{
CUDADeviceContext
ctx
(
deviceId
);
CUDADeviceContext
ctx
(
deviceId
);
spdlog
::
debug
(
"QuantizedFluxModel forward_layer {}"
,
idx
);
spdlog
::
debug
(
"QuantizedFluxModel forward_layer {}"
,
idx
);
...
@@ -121,22 +120,21 @@ public:
...
@@ -121,22 +120,21 @@ public:
from_torch
(
rotary_emb_img
),
from_torch
(
rotary_emb_img
),
from_torch
(
rotary_emb_context
),
from_torch
(
rotary_emb_context
),
controlnet_block_samples
.
has_value
()
?
from_torch
(
controlnet_block_samples
.
value
().
contiguous
())
:
Tensor
{},
controlnet_block_samples
.
has_value
()
?
from_torch
(
controlnet_block_samples
.
value
().
contiguous
())
:
Tensor
{},
controlnet_single_block_samples
.
has_value
()
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
:
Tensor
{}
controlnet_single_block_samples
.
has_value
()
);
?
from_torch
(
controlnet_single_block_samples
.
value
().
contiguous
())
:
Tensor
{});
hidden_states
=
to_torch
(
hidden_states_
);
hidden_states
=
to_torch
(
hidden_states_
);
encoder_hidden_states
=
to_torch
(
encoder_hidden_states_
);
encoder_hidden_states
=
to_torch
(
encoder_hidden_states_
);
Tensor
::
synchronizeDevice
();
Tensor
::
synchronizeDevice
();
return
{
hidden_states
,
encoder_hidden_states
};
return
{
hidden_states
,
encoder_hidden_states
};
}
}
torch
::
Tensor
forward_single_layer
(
torch
::
Tensor
forward_single_layer
(
int64_t
idx
,
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
temb
,
torch
::
Tensor
temb
,
torch
::
Tensor
rotary_emb_single
)
torch
::
Tensor
rotary_emb_single
)
{
{
CUDADeviceContext
ctx
(
deviceId
);
CUDADeviceContext
ctx
(
deviceId
);
spdlog
::
debug
(
"QuantizedFluxModel forward_single_layer {}"
,
idx
);
spdlog
::
debug
(
"QuantizedFluxModel forward_single_layer {}"
,
idx
);
...
@@ -146,10 +144,7 @@ public:
...
@@ -146,10 +144,7 @@ public:
rotary_emb_single
=
rotary_emb_single
.
contiguous
();
rotary_emb_single
=
rotary_emb_single
.
contiguous
();
Tensor
result
=
net
->
single_transformer_blocks
.
at
(
idx
)
->
forward
(
Tensor
result
=
net
->
single_transformer_blocks
.
at
(
idx
)
->
forward
(
from_torch
(
hidden_states
),
from_torch
(
hidden_states
),
from_torch
(
temb
),
from_torch
(
rotary_emb_single
));
from_torch
(
temb
),
from_torch
(
rotary_emb_single
)
);
hidden_states
=
to_torch
(
result
);
hidden_states
=
to_torch
(
result
);
Tensor
::
synchronizeDevice
();
Tensor
::
synchronizeDevice
();
...
@@ -159,19 +154,15 @@ public:
...
@@ -159,19 +154,15 @@ public:
// expose the norm1 forward method of the transformer blocks
// expose the norm1 forward method of the transformer blocks
// this is used by TeaCache to get the norm1 output
// this is used by TeaCache to get the norm1 output
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
norm_one_forward
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
int64_t
idx
,
norm_one_forward
(
int64_t
idx
,
torch
::
Tensor
hidden_states
,
torch
::
Tensor
temb
)
{
torch
::
Tensor
hidden_states
,
AdaLayerNormZero
::
Output
result
=
torch
::
Tensor
temb
net
->
transformer_blocks
.
at
(
idx
)
->
norm1
.
forward
(
from_torch
(
hidden_states
),
from_torch
(
temb
));
)
{
return
{
to_torch
(
result
.
x
),
AdaLayerNormZero
::
Output
result
=
net
->
transformer_blocks
.
at
(
idx
)
->
norm1
.
forward
(
from_torch
(
hidden_states
),
from_torch
(
temb
));
return
{
to_torch
(
result
.
x
),
to_torch
(
result
.
gate_msa
),
to_torch
(
result
.
gate_msa
),
to_torch
(
result
.
shift_mlp
),
to_torch
(
result
.
shift_mlp
),
to_torch
(
result
.
scale_mlp
),
to_torch
(
result
.
scale_mlp
),
to_torch
(
result
.
gate_mlp
)
to_torch
(
result
.
gate_mlp
)};
};
}
}
// must be called after loading lora
// must be called after loading lora
...
@@ -214,5 +205,4 @@ public:
...
@@ -214,5 +205,4 @@ public:
throw
std
::
invalid_argument
(
spdlog
::
fmt_lib
::
format
(
"Invalid attention implementation {}"
,
name
));
throw
std
::
invalid_argument
(
spdlog
::
fmt_lib
::
format
(
"Invalid attention implementation {}"
,
name
));
}
}
}
}
};
};
nunchaku/csrc/gemm.h
View file @
57e50f8d
...
@@ -16,7 +16,12 @@ public:
...
@@ -16,7 +16,12 @@ public:
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
spdlog
::
debug
(
"Stack={}"
,
val
);
spdlog
::
debug
(
"Stack={}"
,
val
);
net
=
std
::
make_unique
<
GEMM_W4A4
>
((
int
)
in_features
,
(
int
)
out_features
,
bias
,
use_fp4
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
net
=
std
::
make_unique
<
GEMM_W4A4
>
((
int
)
in_features
,
(
int
)
out_features
,
bias
,
use_fp4
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
}
}
torch
::
Tensor
forward
(
torch
::
Tensor
x
)
{
torch
::
Tensor
forward
(
torch
::
Tensor
x
)
{
...
@@ -95,10 +100,7 @@ public:
...
@@ -95,10 +100,7 @@ public:
x
=
x
.
contiguous
();
x
=
x
.
contiguous
();
auto
qout
=
net
->
quantize
(
auto
qout
=
net
->
quantize
(
from_torch
(
x
),
fuse_glu
);
from_torch
(
x
),
fuse_glu
);
Tensor
act
=
qout
.
act
.
copy
(
Device
::
cpu
());
Tensor
act
=
qout
.
act
.
copy
(
Device
::
cpu
());
Tensor
ascales
=
qout
.
ascales
.
copy
(
Device
::
cpu
());
Tensor
ascales
=
qout
.
ascales
.
copy
(
Device
::
cpu
());
...
@@ -109,5 +111,4 @@ public:
...
@@ -109,5 +111,4 @@ public:
spdlog
::
debug
(
"act = {}"
,
dumpTensorINT4
(
act
));
spdlog
::
debug
(
"act = {}"
,
dumpTensorINT4
(
act
));
spdlog
::
debug
(
"ascales = {}"
,
dumpTensorBF16
(
ascales
));
spdlog
::
debug
(
"ascales = {}"
,
dumpTensorBF16
(
ascales
));
}
}
};
};
nunchaku/csrc/gemm88.h
View file @
57e50f8d
...
@@ -16,7 +16,8 @@ public:
...
@@ -16,7 +16,8 @@ public:
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
checkCUDA
(
cudaDeviceGetLimit
(
&
val
,
cudaLimitStackSize
));
spdlog
::
debug
(
"Stack={}"
,
val
);
spdlog
::
debug
(
"Stack={}"
,
val
);
net
=
std
::
make_unique
<
GEMM_W8A8
>
((
int
)
in_features
,
(
int
)
out_features
,
bias
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
net
=
std
::
make_unique
<
GEMM_W8A8
>
(
(
int
)
in_features
,
(
int
)
out_features
,
bias
,
bf16
?
Tensor
::
BF16
:
Tensor
::
FP16
,
Device
::
cuda
((
int
)
deviceId
));
}
}
torch
::
Tensor
forward
(
torch
::
Tensor
x
)
{
torch
::
Tensor
forward
(
torch
::
Tensor
x
)
{
...
...
nunchaku/csrc/module.h
View file @
57e50f8d
nunchaku/csrc/ops.h
View file @
57e50f8d
...
@@ -7,8 +7,7 @@
...
@@ -7,8 +7,7 @@
namespace
nunchaku
::
ops
{
namespace
nunchaku
::
ops
{
void
gemm_w4a4
(
void
gemm_w4a4
(
std
::
optional
<
torch
::
Tensor
>
act
,
// packed act [M, K / 2]
std
::
optional
<
torch
::
Tensor
>
act
,
// packed act [M, K / 2]
std
::
optional
<
torch
::
Tensor
>
wgt
,
// packed act [N, K / 2]
std
::
optional
<
torch
::
Tensor
>
wgt
,
// packed act [N, K / 2]
std
::
optional
<
torch
::
Tensor
>
out
,
// linear [M, N]
std
::
optional
<
torch
::
Tensor
>
out
,
// linear [M, N]
std
::
optional
<
torch
::
Tensor
>
qout
,
// packed act [M, N / 2]
std
::
optional
<
torch
::
Tensor
>
qout
,
// packed act [M, N / 2]
...
@@ -26,7 +25,7 @@ namespace nunchaku::ops {
...
@@ -26,7 +25,7 @@ namespace nunchaku::ops {
std
::
optional
<
torch
::
Tensor
>
bias
,
// packed ws [N]
std
::
optional
<
torch
::
Tensor
>
bias
,
// packed ws [N]
std
::
optional
<
torch
::
Tensor
>
smooth_factor
,
// packed ws [N], for quantization of the next layer
std
::
optional
<
torch
::
Tensor
>
smooth_factor
,
// packed ws [N], for quantization of the next layer
std
::
optional
<
torch
::
Tensor
>
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
std
::
optional
<
torch
::
Tensor
>
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
std
::
optional
<
torch
::
Tensor
>
out_linearattn
,
// linear [B, (M), N / 3]
std
::
optional
<
torch
::
Tensor
>
out_linearattn
,
// linear [B, (M), N / 3]
bool
act_unsigned
,
bool
act_unsigned
,
std
::
vector
<
float
>
lora_scales
,
std
::
vector
<
float
>
lora_scales
,
bool
fuse_silu
,
bool
fuse_silu
,
...
@@ -36,8 +35,7 @@ namespace nunchaku::ops {
...
@@ -36,8 +35,7 @@ namespace nunchaku::ops {
std
::
optional
<
torch
::
Tensor
>
out_q
,
// packed attention [B, H, M, D]
std
::
optional
<
torch
::
Tensor
>
out_q
,
// packed attention [B, H, M, D]
std
::
optional
<
torch
::
Tensor
>
out_k
,
// packed attention [B, H, M, D]
std
::
optional
<
torch
::
Tensor
>
out_k
,
// packed attention [B, H, M, D]
std
::
optional
<
torch
::
Tensor
>
out_v
,
// packed attention [B, H, M, D]
std
::
optional
<
torch
::
Tensor
>
out_v
,
// packed attention [B, H, M, D]
int
attn_tokens
int
attn_tokens
)
{
)
{
spdlog
::
trace
(
"running gemm_w4a4: "
);
spdlog
::
trace
(
"running gemm_w4a4: "
);
auto
getTensor
=
[](
std
::
optional
<
torch
::
Tensor
>
&
t
)
{
auto
getTensor
=
[](
std
::
optional
<
torch
::
Tensor
>
&
t
)
{
...
@@ -49,25 +47,24 @@ namespace nunchaku::ops {
...
@@ -49,25 +47,24 @@ namespace nunchaku::ops {
}
}
return
ret
;
return
ret
;
};
};
nunchaku
::
kernels
::
gemm_w4a4
(
nunchaku
::
kernels
::
gemm_w4a4
(
getTensor
(
act
),
getTensor
(
act
),
getTensor
(
wgt
),
getTensor
(
wgt
),
getTensor
(
out
),
getTensor
(
out
),
getTensor
(
qout
),
getTensor
(
qout
),
getTensor
(
ascales
),
getTensor
(
ascales
),
getTensor
(
wscales
),
getTensor
(
wscales
),
getTensor
(
oscales
),
getTensor
(
oscales
),
getTensor
(
poolout
),
getTensor
(
poolout
),
getTensor
(
lora_act_in
),
getTensor
(
lora_act_in
),
getTensor
(
lora_up
),
getTensor
(
lora_up
),
getTensor
(
lora_down
),
getTensor
(
lora_down
),
getTensor
(
lora_act_out
),
getTensor
(
lora_act_out
),
getTensor
(
norm_q
),
getTensor
(
norm_q
),
getTensor
(
norm_k
),
getTensor
(
norm_k
),
getTensor
(
rotary_emb
),
getTensor
(
rotary_emb
),
getTensor
(
bias
),
getTensor
(
bias
),
getTensor
(
smooth_factor
),
getTensor
(
smooth_factor
),
getTensor
(
out_vk
),
getTensor
(
out_vk
),
getTensor
(
out_linearattn
),
getTensor
(
out_linearattn
),
act_unsigned
,
act_unsigned
,
lora_scales
,
lora_scales
,
...
@@ -78,104 +75,64 @@ namespace nunchaku::ops {
...
@@ -78,104 +75,64 @@ namespace nunchaku::ops {
getTensor
(
out_q
),
getTensor
(
out_q
),
getTensor
(
out_k
),
getTensor
(
out_k
),
getTensor
(
out_v
),
getTensor
(
out_v
),
attn_tokens
attn_tokens
);
);
// Tensor::synchronizeDevice();
// Tensor::synchronizeDevice();
}
}
void
attention_fp16
(
void
attention_fp16
(
torch
::
Tensor
q
,
// packed [Batch, Head, TokensQ, HEAD_DIM]
torch
::
Tensor
q
,
// packed [Batch, Head, TokensQ, HEAD_DIM]
torch
::
Tensor
k
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
torch
::
Tensor
k
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
torch
::
Tensor
v
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
torch
::
Tensor
v
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
torch
::
Tensor
o
,
// linear [Batch, TokensQ, Head * HEAD_DIM]
torch
::
Tensor
o
,
// linear [Batch, TokensQ, Head * HEAD_DIM]
float
scale
float
scale
)
{
)
{
nunchaku
::
kernels
::
attention_fp16
(
from_torch
(
q
),
from_torch
(
k
),
from_torch
(
v
),
from_torch
(
o
),
scale
);
nunchaku
::
kernels
::
attention_fp16
(
}
from_torch
(
q
),
from_torch
(
k
),
from_torch
(
v
),
from_torch
(
o
),
scale
);
}
torch
::
Tensor
gemv_awq
(
torch
::
Tensor
gemv_awq
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
torch
::
Tensor
_zeros
,
int64_t
m
,
int64_t
m
,
int64_t
n
,
int64_t
n
,
int64_t
k
,
int64_t
k
,
int64_t
group_size
)
int64_t
group_size
)
{
{
Tensor
result
=
::
gemv_awq
(
from_torch
(
_in_feats
.
contiguous
()),
Tensor
result
=
::
gemv_awq
(
from_torch
(
_in_feats
.
contiguous
()),
from_torch
(
_kernel
.
contiguous
()),
from_torch
(
_kernel
.
contiguous
()),
from_torch
(
_scaling_factors
.
contiguous
()),
from_torch
(
_scaling_factors
.
contiguous
()),
from_torch
(
_zeros
.
contiguous
()),
from_torch
(
_zeros
.
contiguous
()),
(
int
)
m
,
(
int
)
m
,
(
int
)
n
,
(
int
)
n
,
(
int
)
k
,
(
int
)
k
,
(
int
)
group_size
(
int
)
group_size
);
);
torch
::
Tensor
output
=
to_torch
(
result
);
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
// Tensor::synchronizeDevice();
return
output
;
return
output
;
}
}
torch
::
Tensor
gemm_awq
(
torch
::
Tensor
torch
::
Tensor
_in_feats
,
gemm_awq
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
)
{
torch
::
Tensor
_kernel
,
Tensor
result
=
::
awq_gemm_forward_cuda
(
from_torch
(
_in_feats
.
contiguous
()),
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
)
{
Tensor
result
=
::
awq_gemm_forward_cuda
(
from_torch
(
_in_feats
.
contiguous
()),
from_torch
(
_kernel
.
contiguous
()),
from_torch
(
_kernel
.
contiguous
()),
from_torch
(
_scaling_factors
.
contiguous
()),
from_torch
(
_scaling_factors
.
contiguous
()),
from_torch
(
_zeros
.
contiguous
())
from_torch
(
_zeros
.
contiguous
()));
);
// TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
// TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
torch
::
Tensor
output
=
to_torch
(
result
);
torch
::
Tensor
output
=
to_torch
(
result
);
// Tensor::synchronizeDevice();
// Tensor::synchronizeDevice();
return
output
;
return
output
;
}
}
void
test_rmsnorm_rope
(
void
test_rmsnorm_rope
(
torch
::
Tensor
input
,
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
norm_q
,
torch
::
Tensor
norm_k
,
torch
::
Tensor
rotary_emb
)
{
torch
::
Tensor
output
,
torch
::
Tensor
norm_q
,
torch
::
Tensor
norm_k
,
torch
::
Tensor
rotary_emb
)
{
nunchaku
::
kernels
::
test_rmsnorm_rope
(
nunchaku
::
kernels
::
test_rmsnorm_rope
(
from_torch
(
input
),
from_torch
(
input
),
from_torch
(
output
),
from_torch
(
norm_q
),
from_torch
(
norm_k
),
from_torch
(
rotary_emb
));
from_torch
(
output
),
}
from_torch
(
norm_q
),
from_torch
(
norm_k
),
from_torch
(
rotary_emb
)
);
}
void
test_pack_qkv
(
void
test_pack_qkv
(
torch
::
Tensor
input
,
torch
::
Tensor
out_q
,
torch
::
Tensor
out_k
,
torch
::
Tensor
out_v
,
int
numTokens
)
{
torch
::
Tensor
input
,
torch
::
Tensor
out_q
,
torch
::
Tensor
out_k
,
torch
::
Tensor
out_v
,
int
numTokens
)
{
nunchaku
::
kernels
::
test_pack_qkv
(
nunchaku
::
kernels
::
test_pack_qkv
(
from_torch
(
input
),
from_torch
(
input
),
from_torch
(
out_q
),
from_torch
(
out_k
),
from_torch
(
out_v
),
numTokens
);
from_torch
(
out_q
),
}
from_torch
(
out_k
),
from_torch
(
out_v
),
numTokens
);
}
};
};
// namespace nunchaku::ops
\ No newline at end of file
nunchaku/csrc/pybind.cpp
View file @
57e50f8d
...
@@ -11,13 +11,14 @@
...
@@ -11,13 +11,14 @@
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
py
::
class_
<
QuantizedFluxModel
>
(
m
,
"QuantizedFluxModel"
)
py
::
class_
<
QuantizedFluxModel
>
(
m
,
"QuantizedFluxModel"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
"init"
,
&
QuantizedFluxModel
::
init
,
.
def
(
"init"
,
&
QuantizedFluxModel
::
init
,
py
::
arg
(
"use_fp4"
),
py
::
arg
(
"use_fp4"
),
py
::
arg
(
"offload"
),
py
::
arg
(
"offload"
),
py
::
arg
(
"bf16"
),
py
::
arg
(
"bf16"
),
py
::
arg
(
"deviceId"
)
py
::
arg
(
"deviceId"
)
)
)
.
def
(
"set_residual_callback"
,
.
def
(
"set_residual_callback"
,
[](
QuantizedFluxModel
&
self
,
pybind11
::
object
call_back
)
{
[](
QuantizedFluxModel
&
self
,
pybind11
::
object
call_back
)
{
if
(
call_back
.
is_none
())
{
if
(
call_back
.
is_none
())
{
self
.
set_residual_callback
(
pybind11
::
function
());
self
.
set_residual_callback
(
pybind11
::
function
());
}
else
{
}
else
{
...
@@ -25,15 +26,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -25,15 +26,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
}
}
})
})
.
def
(
"reset"
,
&
QuantizedFluxModel
::
reset
)
.
def
(
"reset"
,
&
QuantizedFluxModel
::
reset
)
.
def
(
"load"
,
&
QuantizedFluxModel
::
load
,
.
def
(
"load"
,
&
QuantizedFluxModel
::
load
,
py
::
arg
(
"path"
),
py
::
arg
(
"partial"
)
=
false
)
py
::
arg
(
"path"
),
.
def
(
"loadDict"
,
&
QuantizedFluxModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
)
py
::
arg
(
"partial"
)
=
false
.
def
(
"forward"
,
)
&
QuantizedFluxModel
::
forward
,
.
def
(
"loadDict"
,
&
QuantizedFluxModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"forward"
,
&
QuantizedFluxModel
::
forward
,
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
py
::
arg
(
"temb"
),
py
::
arg
(
"temb"
),
...
@@ -42,9 +38,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -42,9 +38,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py
::
arg
(
"rotary_emb_single"
),
py
::
arg
(
"rotary_emb_single"
),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"skip_first_layer"
)
=
false
py
::
arg
(
"skip_first_layer"
)
=
false
)
)
.
def
(
"forward_layer"
,
.
def
(
"forward_layer"
,
&
QuantizedFluxModel
::
forward_layer
,
&
QuantizedFluxModel
::
forward_layer
,
py
::
arg
(
"idx"
),
py
::
arg
(
"idx"
),
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
py
::
arg
(
"encoder_hidden_states"
),
...
@@ -52,8 +48,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -52,8 +48,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py
::
arg
(
"rotary_emb_img"
),
py
::
arg
(
"rotary_emb_img"
),
py
::
arg
(
"rotary_emb_context"
),
py
::
arg
(
"rotary_emb_context"
),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_block_samples"
)
=
py
::
none
(),
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
()
py
::
arg
(
"controlnet_single_block_samples"
)
=
py
::
none
())
)
.
def
(
"forward_single_layer"
,
&
QuantizedFluxModel
::
forward_single_layer
)
.
def
(
"forward_single_layer"
,
&
QuantizedFluxModel
::
forward_single_layer
)
.
def
(
"norm_one_forward"
,
&
QuantizedFluxModel
::
norm_one_forward
)
.
def
(
"norm_one_forward"
,
&
QuantizedFluxModel
::
norm_one_forward
)
.
def
(
"startDebug"
,
&
QuantizedFluxModel
::
startDebug
)
.
def
(
"startDebug"
,
&
QuantizedFluxModel
::
startDebug
)
...
@@ -61,32 +56,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -61,32 +56,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"getDebugResults"
,
&
QuantizedFluxModel
::
getDebugResults
)
.
def
(
"getDebugResults"
,
&
QuantizedFluxModel
::
getDebugResults
)
.
def
(
"setLoraScale"
,
&
QuantizedFluxModel
::
setLoraScale
)
.
def
(
"setLoraScale"
,
&
QuantizedFluxModel
::
setLoraScale
)
.
def
(
"setAttentionImpl"
,
&
QuantizedFluxModel
::
setAttentionImpl
)
.
def
(
"setAttentionImpl"
,
&
QuantizedFluxModel
::
setAttentionImpl
)
.
def
(
"isBF16"
,
&
QuantizedFluxModel
::
isBF16
)
.
def
(
"isBF16"
,
&
QuantizedFluxModel
::
isBF16
);
;
py
::
class_
<
QuantizedSanaModel
>
(
m
,
"QuantizedSanaModel"
)
py
::
class_
<
QuantizedSanaModel
>
(
m
,
"QuantizedSanaModel"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
"init"
,
&
QuantizedSanaModel
::
init
,
.
def
(
"init"
,
&
QuantizedSanaModel
::
init
,
py
::
arg
(
"config"
),
py
::
arg
(
"config"
),
py
::
arg
(
"pag_layers"
),
py
::
arg
(
"pag_layers"
),
py
::
arg
(
"use_fp4"
),
py
::
arg
(
"use_fp4"
),
py
::
arg
(
"bf16"
),
py
::
arg
(
"bf16"
),
py
::
arg
(
"deviceId"
)
py
::
arg
(
"deviceId"
))
)
.
def
(
"reset"
,
&
QuantizedSanaModel
::
reset
)
.
def
(
"reset"
,
&
QuantizedSanaModel
::
reset
)
.
def
(
"load"
,
&
QuantizedSanaModel
::
load
,
.
def
(
"load"
,
&
QuantizedSanaModel
::
load
,
py
::
arg
(
"path"
),
py
::
arg
(
"partial"
)
=
false
)
py
::
arg
(
"path"
),
.
def
(
"loadDict"
,
&
QuantizedSanaModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
)
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"loadDict"
,
&
QuantizedSanaModel
::
loadDict
,
py
::
arg
(
"dict"
),
py
::
arg
(
"partial"
)
=
false
)
.
def
(
"forward"
,
&
QuantizedSanaModel
::
forward
)
.
def
(
"forward"
,
&
QuantizedSanaModel
::
forward
)
.
def
(
"forward_layer"
,
&
QuantizedSanaModel
::
forward_layer
)
.
def
(
"forward_layer"
,
&
QuantizedSanaModel
::
forward_layer
)
.
def
(
"startDebug"
,
&
QuantizedSanaModel
::
startDebug
)
.
def
(
"startDebug"
,
&
QuantizedSanaModel
::
startDebug
)
.
def
(
"stopDebug"
,
&
QuantizedSanaModel
::
stopDebug
)
.
def
(
"stopDebug"
,
&
QuantizedSanaModel
::
stopDebug
)
.
def
(
"getDebugResults"
,
&
QuantizedSanaModel
::
getDebugResults
)
.
def
(
"getDebugResults"
,
&
QuantizedSanaModel
::
getDebugResults
);
;
py
::
class_
<
QuantizedGEMM
>
(
m
,
"QuantizedGEMM"
)
py
::
class_
<
QuantizedGEMM
>
(
m
,
"QuantizedGEMM"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
"init"
,
&
QuantizedGEMM
::
init
)
.
def
(
"init"
,
&
QuantizedGEMM
::
init
)
...
@@ -96,8 +83,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -96,8 +83,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"quantize"
,
&
QuantizedGEMM
::
quantize
)
.
def
(
"quantize"
,
&
QuantizedGEMM
::
quantize
)
.
def
(
"startDebug"
,
&
QuantizedGEMM
::
startDebug
)
.
def
(
"startDebug"
,
&
QuantizedGEMM
::
startDebug
)
.
def
(
"stopDebug"
,
&
QuantizedGEMM
::
stopDebug
)
.
def
(
"stopDebug"
,
&
QuantizedGEMM
::
stopDebug
)
.
def
(
"getDebugResults"
,
&
QuantizedGEMM
::
getDebugResults
)
.
def
(
"getDebugResults"
,
&
QuantizedGEMM
::
getDebugResults
);
;
py
::
class_
<
Tensor
>
(
m
,
"Tensor"
);
py
::
class_
<
Tensor
>
(
m
,
"Tensor"
);
py
::
class_
<
QuantizedGEMM88
>
(
m
,
"QuantizedGEMM88"
)
py
::
class_
<
QuantizedGEMM88
>
(
m
,
"QuantizedGEMM88"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
...
@@ -107,8 +93,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -107,8 +93,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"forward"
,
&
QuantizedGEMM88
::
forward
)
.
def
(
"forward"
,
&
QuantizedGEMM88
::
forward
)
.
def
(
"startDebug"
,
&
QuantizedGEMM88
::
startDebug
)
.
def
(
"startDebug"
,
&
QuantizedGEMM88
::
startDebug
)
.
def
(
"stopDebug"
,
&
QuantizedGEMM88
::
stopDebug
)
.
def
(
"stopDebug"
,
&
QuantizedGEMM88
::
stopDebug
)
.
def
(
"getDebugResults"
,
&
QuantizedGEMM88
::
getDebugResults
)
.
def
(
"getDebugResults"
,
&
QuantizedGEMM88
::
getDebugResults
);
;
m
.
def_submodule
(
"ops"
)
m
.
def_submodule
(
"ops"
)
.
def
(
"gemm_w4a4"
,
nunchaku
::
ops
::
gemm_w4a4
)
.
def
(
"gemm_w4a4"
,
nunchaku
::
ops
::
gemm_w4a4
)
...
@@ -117,16 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -117,16 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"gemv_awq"
,
nunchaku
::
ops
::
gemv_awq
)
.
def
(
"gemv_awq"
,
nunchaku
::
ops
::
gemv_awq
)
.
def
(
"test_rmsnorm_rope"
,
nunchaku
::
ops
::
test_rmsnorm_rope
)
.
def
(
"test_rmsnorm_rope"
,
nunchaku
::
ops
::
test_rmsnorm_rope
)
.
def
(
"test_pack_qkv"
,
nunchaku
::
ops
::
test_pack_qkv
)
.
def
(
"test_pack_qkv"
,
nunchaku
::
ops
::
test_pack_qkv
);
;
m
.
def_submodule
(
"utils"
)
m
.
def_submodule
(
"utils"
)
.
def
(
"set_log_level"
,
[](
const
std
::
string
&
level
)
{
.
def
(
"set_log_level"
,
[](
const
std
::
string
&
level
)
{
spdlog
::
set_level
(
spdlog
::
level
::
from_str
(
level
));
})
spdlog
::
set_level
(
spdlog
::
level
::
from_str
(
level
));
})
.
def
(
"set_cuda_stack_limit"
,
nunchaku
::
utils
::
set_cuda_stack_limit
)
.
def
(
"set_cuda_stack_limit"
,
nunchaku
::
utils
::
set_cuda_stack_limit
)
.
def
(
"disable_memory_auto_release"
,
nunchaku
::
utils
::
disable_memory_auto_release
)
.
def
(
"disable_memory_auto_release"
,
nunchaku
::
utils
::
disable_memory_auto_release
)
.
def
(
"trim_memory"
,
nunchaku
::
utils
::
trim_memory
)
.
def
(
"trim_memory"
,
nunchaku
::
utils
::
trim_memory
)
.
def
(
"set_faster_i2f_mode"
,
nunchaku
::
utils
::
set_faster_i2f_mode
)
.
def
(
"set_faster_i2f_mode"
,
nunchaku
::
utils
::
set_faster_i2f_mode
);
;
}
}
Prev
1
2
3
4
5
6
7
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment