fengzch-das / nunchaku · Commits

Commit 57e50f8d (unverified)
Authored May 01, 2025 by Muyang Li; committed by GitHub on May 01, 2025
Parent: b737368d

style: upgrade the linter (#339)

* style: reformated codes

Changes: 173 · Showing 20 changed files with 786 additions and 630 deletions (+786, -630)
Changed files on this page:

nunchaku/pipeline/pipeline_flux_pulid.py     +2    -4
pyproject.toml                               +20   -15
requirements.txt                             +1    -1
scripts/build_all_linux_wheels.sh            +1    -1
scripts/build_docker.sh                      +1    -1
scripts/build_docker_torch27.sh              +1    -1
scripts/build_docker_torch28.sh              +1    -1
scripts/build_linux_wheel.sh                 +1    -1
scripts/build_linux_wheel_cu128.sh           +1    -1
scripts/build_linux_wheel_torch2.7_cu128.sh  +1    -1
scripts/linux_cleanup.sh                     +1    -1
setup.py                                     +1    -1
src/FluxModel.cpp                            +288  -278
src/FluxModel.h                              +50   -32
src/Linear.cpp                               +216  -143
src/Linear.h                                 +20   -10
src/Module.cpp                               +2    -2
src/Module.h                                 +28   -31
src/SanaModel.cpp                            +121  -99
src/SanaModel.h                              +29   -6
nunchaku/pipeline/pipeline_flux_pulid.py

@@ -8,11 +8,9 @@ import numpy as np
 import torch
 from diffusers import FluxPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.pipelines.flux.pipeline_flux import calculate_shift, EXAMPLE_DOC_STRING, retrieve_timesteps
+from diffusers.pipelines.flux.pipeline_flux import EXAMPLE_DOC_STRING, calculate_shift, retrieve_timesteps
 from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
-from diffusers.utils import (
-    replace_example_docstring,
-)
+from diffusers.utils import replace_example_docstring
 from facexlib.parsing import init_parsing_model
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
 from huggingface_hub import hf_hub_download, snapshot_download
...
pyproject.toml

-[build-system]
-requires = [
-    "setuptools",
-    "torch>=2.5",
-    "wheel",
-    "ninja",
-]
-build-backend = "setuptools.build_meta"
-[tool.setuptools.packages.find]
-include = ["nunchaku"]
+[tool.isort]
+profile = "black"
+known_first_party = ["nunchaku"]
+line_length = 120
+[tool.black]
+line-length = 120
+target-version = ['py311']
 [tool.ruff]
-line-length = 140
+line-length = 120
+[tool.ruff.lint]
+select = ["E", "W", "F"]
+ignore = ["F401"]
 [project]
 dynamic = ["version"]
...
@@ -29,3 +22,15 @@ dependencies = [
     "huggingface_hub",
 ]
 requires-python = ">=3.10"
+
+[build-system]
+requires = [
+    "setuptools",
+    "torch>=2.5",
+    "wheel",
+    "ninja",
+]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+include = ["nunchaku"]
The following files each change a single line; their diffs are not expanded on this page:

requirements.txt
scripts/build_all_linux_wheels.sh
scripts/build_docker.sh
scripts/build_docker_torch27.sh
scripts/build_docker_torch28.sh
scripts/build_linux_wheel.sh
scripts/build_linux_wheel_cu128.sh
scripts/build_linux_wheel_torch2.7_cu128.sh
scripts/linux_cleanup.sh
setup.py

@@ -6,7 +6,7 @@ import sys
 import setuptools
 import torch
 from packaging import version as packaging_version
-from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
+from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension


 class CustomBuildExtension(BuildExtension):
...
src/FluxModel.cpp

(This diff is collapsed and not shown on this page.)
src/FluxModel.h
(formatting-only changes; the affected code is shown once)

@@ -7,7 +7,7 @@
#include "layernorm.h"
#include <pybind11/functional.h>

namespace pybind11 {
class function;
}

enum class AttentionImpl {
...
@@ -49,6 +49,7 @@ public:
        Tensor scale_mlp;
        Tensor gate_mlp;
    };

public:
    AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device);
    Output forward(Tensor x, Tensor emb);
...
@@ -87,7 +88,13 @@ public:
    static constexpr bool USE_4BIT = true;
    using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

    FluxSingleTransformerBlock(
        int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);

public:
...
@@ -113,8 +120,19 @@ public:
    static constexpr bool USE_4BIT = true;
    using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

    JointTransformerBlock(
        int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device);
    std::tuple<Tensor, Tensor> forward(
        Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);

public:
    const int dim;
...
@@ -143,8 +161,7 @@ private:
class FluxModel : public Module {
public:
    FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor hidden_states,
                   Tensor encoder_hidden_states,
                   Tensor temb,
                   Tensor rotary_emb_img,
...
@@ -153,8 +170,7 @@ public:
                   Tensor controlnet_block_samples,
                   Tensor controlnet_single_block_samples,
                   bool skip_first_layer = false);
    std::tuple<Tensor, Tensor> forward_layer(size_t layer,
                                             Tensor hidden_states,
                                             Tensor encoder_hidden_states,
                                             Tensor temb,
...
@@ -164,14 +180,16 @@ public:
                                             Tensor controlnet_single_block_samples);
    void setAttentionImpl(AttentionImpl impl);
    void set_residual_callback(std::function<Tensor(const Tensor &)> cb);

public:
    const Tensor::ScalarType dtype;
    std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
    std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
    std::function<Tensor(const Tensor &)> residual_callback;

private:
    bool offload;
};
...
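The `using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;` alias that recurs in FluxModel.h picks the GEMM backend at compile time. A minimal sketch of the idiom follows; the GEMM types here are dummy stand-ins, not nunchaku's real classes.

// Sketch only: GEMM_W4A4 / GEMM_W8A8 below are placeholder types. The alias is
// resolved by the compiler, so there is no runtime branch or virtual dispatch.
#include <iostream>
#include <type_traits>

struct GEMM_W4A4 { static constexpr const char *name = "w4a4"; };
struct GEMM_W8A8 { static constexpr const char *name = "w8a8"; };

struct Block {
    static constexpr bool USE_4BIT = true;
    using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
};

int main() {
    std::cout << Block::GEMM::name << "\n"; // prints "w4a4"
}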
src/Linear.cpp
(formatting-only changes; the affected code is shown once)

@@ -9,16 +9,12 @@
using namespace nunchaku;

GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features) {
    this->weight = Tensor::allocate({out_features, in_features}, dtype, device);
    this->bias   = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
    registerParams(weight, "weight", ParamFlags::LazyLoad)(bias, "bias");
}

Tensor GEMM_F16::forward(Tensor x) {
...
@@ -26,9 +22,9 @@ Tensor GEMM_F16::forward(Tensor x) {
    return out;
}

GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f), device(device) {
    this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
    this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
    this->wzeros  = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
...
@@ -38,14 +34,8 @@ GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::Sca
    this->lora_down = Tensor::allocate({lora_rank, in_features}, dtype, device, true);
    this->lora_up   = Tensor::allocate({out_features, lora_rank}, dtype, device, true);

    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(wzeros, "wzeros")(bias, "bias")(
        lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional);
}

void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
...
@@ -95,15 +85,12 @@ Tensor GEMV_AWQ::forward(Tensor x) {
    return out;
}

#define NO_LORA_FUSION 0

GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features),
      in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
      use_fp4(use_fp4), lora_rank(0), dtype(dtype), device(device) {
    this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
    if (use_fp4) {
        this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
...
@@ -125,16 +112,9 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
    this->wcscales = Tensor::allocate({0}, dtype, device, true);

    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias")(
        lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional)(smooth, "smooth")(
        wtscale, "wtscale", ParamFlags::Optional)(wcscales, "wcscales", ParamFlags::Optional);

#if NO_LORA_FUSION
    checkCUBLAS(cublasCreate(&handle));
...
@@ -181,11 +161,21 @@ Tensor GEMM_W4A4::forward_silu(Tensor x) {
    return std::get<Tensor>(this->forward(x, FuseOptions::SILU, nullptr));
}

std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
    return forward_quant(quantize(x, false), fuse, nextGEMM);
}

void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb,
                        Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
    QuantizedActivation qact = quantize(x, false);

#if !NO_LORA_FUSION
...
@@ -198,17 +188,59 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
    debug("gemm.nolora.out", out);
#endif

    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool,
                       qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias,
                       {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
                       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{},
                       out_q, out_k, out_v, numTokens);

    debug("gemm.out", out);
#else
    const int M = (int)qact.act.numel() / qact.act.shape[-1];
    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, {}, {}, {}, {},
                       norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned, this->lora_scales);

    nvtxRangePushA("LoraUp");
...
@@ -216,10 +248,12 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
    static const half zero = 0.0;

    // lora_up: [M, R] * [OC, R] => [M, OC]
    // cublas view: [OC, R] * [M, R]^T
    checkCUBLAS(cublasHgemm(handle,
                            CUBLAS_OP_T,
                            CUBLAS_OP_N,
                            this->out_features,
                            M,
                            this->lora_rank,
                            &one,
                            this->lora_up.data_ptr<half>(),
                            this->lora_rank,
...
@@ -233,7 +267,8 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
#endif
}

std::variant<Tensor, GEMM_W4A4::QuantizedActivation>
GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
    Tensor out;
    QuantizedActivation qout;
...
@@ -280,11 +315,35 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
    }
#endif

    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {},
                       qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias,
                       next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
                       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{},
                       {}, {}, {}, 0);

    if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
        debug("gemm.out", out);
...
@@ -294,7 +353,6 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
        debug("gemm.lora_act_out", qout.lora_act);
    }
#else
    if (!out.valid()) {
        auto shape = TensorShape(qact.act.shape.dataExtent);
...
@@ -302,7 +360,25 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
        out = Tensor::allocate(shape, Tensor::FP16, qweight.device());
    }

    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, {}, {}, {}, {},
                       {}, {}, {}, this->bias, next_smooth, qact.is_unsigned, this->lora_scales);

    nvtxRangePushA("LoraUp");
...
@@ -312,10 +388,12 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
    // lora_up: [M, R] * [OC, R]^T => [M, OC]
    // cublas view: [R, OC]^T * [R, M] => [OC, M]
    // lora_up layout wrong?
    checkCUBLAS(cublasHgemm(handle,
                            CUBLAS_OP_T,
                            CUBLAS_OP_N,
                            this->out_features,
                            M,
                            this->lora_rank,
                            &one,
                            this->lora_up.data_ptr<half>(),
                            this->lora_rank,
...
@@ -332,10 +410,12 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
    // IC is for next lora (OC of this layer)
    // lora_down: [M, IC] * [IC, R] => [M, R]
    // cublas view: [R, IC] * [IC, M] => [R, M]
    checkCUBLAS(cublasHgemm(handle,
                            CUBLAS_OP_N,
                            CUBLAS_OP_N,
                            this->lora_rank,
                            M,
                            this->out_features,
                            &one,
                            next_lora.data_ptr<half>(),
                            this->lora_rank,
...
@@ -383,7 +463,8 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
    debug("quantize.x", x);
    debug("quantize.smooth", this->smooth);

    kernels::quantize_w4a4_act_fuse_lora(
        x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);

    debug("quantize.qact", qact.act);
    debug("quantize.ascales", qact.ascales);
...
@@ -396,10 +477,12 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
    // lora_down: [M, IC] * [IC, R] => [M, R]
    // cublas view: [R, IC] * [IC, M]
    checkCUBLAS(cublasHgemm(handle,
                            CUBLAS_OP_N,
                            CUBLAS_OP_N,
                            this->lora_rank,
                            M,
                            this->in_features,
                            &one,
                            lora_down.data_ptr<half>(),
                            this->lora_rank,
...
@@ -418,18 +501,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
    return qact;
}

GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), dtype(dtype) {
    this->qweight = Tensor::allocate({out_features, in_features}, Tensor::INT8, device);
    this->wscales = Tensor::allocate({out_features}, dtype, device);
    this->bias    = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias");
}

GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
...
@@ -461,16 +539,11 @@ Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
    return out;
}

DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) : in_features(in_features) {
    this->weight = Tensor::allocate({in_features, 3, 3, 1}, dtype, device);
    this->bias   = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{};
    registerParams(this->weight, "weight")(this->bias, "bias");
}

Tensor DWCONV::forward(Tensor x) {
...
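The `registerParams(...)(...)(...)` chains that the reformat rewraps above come from a registration helper whose call operator returns itself. A minimal sketch of that chaining idiom, assuming a simplified parameter table; this is an illustration of the pattern, not nunchaku's actual Module/ParamFlags implementation.

// Sketch only: the Registrar proxy and its int-based parameter table are
// hypothetical; the point is that operator() registers one entry and returns
// *this, so additional (param, "name") pairs can be appended in one statement.
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>

class Module {
public:
    class Registrar {
    public:
        explicit Registrar(Module &m) : m_(m) {}
        Registrar &operator()(int &param, const std::string &name) {
            m_.params_[name] = &param; // record one parameter
            return *this;              // enable chaining
        }
    private:
        Module &m_;
    };

    Registrar registerParams(int &param, const std::string &name) {
        Registrar r(*this);
        r(param, name);
        return r;
    }

    std::size_t numParams() const { return params_.size(); }

private:
    std::unordered_map<std::string, int *> params_;
};

int main() {
    Module m;
    int weight = 0, bias = 0;
    m.registerParams(weight, "weight")(bias, "bias"); // reads like Linear.cpp
    std::cout << m.numParams() << "\n";               // prints 2
}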
src/Linear.h
(formatting-only changes; the affected code is shown once)

@@ -37,6 +37,7 @@ public:
    float lora_scale;

    const Device device;

public:
    Tensor qweight;
    Tensor wscales;
...
@@ -69,12 +70,18 @@ public:
    Tensor forward(Tensor x);
    Tensor forward_silu(Tensor x);
    std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
    void forward(Tensor x,
                 Tensor out,
                 Tensor pool = {},
                 Tensor norm_q = {},
                 Tensor norm_k = {},
                 Tensor rotary_emb = {},
                 Tensor out_q = {},
                 Tensor out_k = {},
                 Tensor out_v = {},
                 int numTokens = 0);
    std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
    Tensor forward_quant(QuantizedActivation qact);

public:
...
@@ -118,13 +125,16 @@ public:
        Tensor act;
        Tensor ascales;
    };

public:
    GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);

public:
    QuantizedActivation quantize(Tensor x, bool fuse_glu);
    Tensor forward_quant(QuantizedActivation qact);
    Tensor forward(Tensor x) { return forward_quant(quantize(x, false)); }

public:
    const int in_features;
...
src/Module.cpp

(Diff not expanded on this page.)
src/Module.h
(formatting-only changes; the affected code is shown once)

@@ -108,7 +108,8 @@ public:
            dst = Tensor::allocate(lazy.shape, lazy.type, lazy.device);
            if (!src.valid() && !checkFlag(param.flags, ParamFlags::Optional)) {
                throw std::runtime_error(spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
            }
            m->loadParam(key, dst, src);
        }
...
@@ -127,14 +128,10 @@ public:
        });
    }

    void setLazyLoad(bool val) {
        traverse([val](Module *m) { m->enabledLazyLoad = val; });
    }

    void setAutoCastFP16(bool val) {
        traverse([val](Module *m) { m->enabledAutoCastFP16 = val; });
    }

protected:
...
@@ -143,7 +140,8 @@ protected:
        Tensor::FP16,
        Tensor::BF16,
    };
    if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) &&
        whitelist.contains(src.scalar_type())) {
        copyWithCast(dst, src);
    } else {
        dst.copy_(src);
...
@@ -227,8 +225,7 @@ struct LayerOffloadHelper {
    std::unique_ptr<CUDAEventWrapper> eventLoadDone;

    LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
        : offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload) {
        if (offload) {
            streamCompute = std::make_unique<CUDAStreamWrapper>();
            streamLoad    = std::make_unique<CUDAStreamWrapper>();
...
@@ -305,11 +302,11 @@ private:
        }
    }
#ifdef _WIN32
        return true;
#else
        return false;
#endif
    }

    void workaroundFlush() {
        if (!needWorkaround) {
...
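Module.h's LayerOffloadHelper (rewrapped above) keeps separate compute and load streams plus a load-done event, which points at the usual layer-offloading pattern: prefetch layer i+1's weights on a copy stream while layer i runs on the compute stream. A rough sketch of that pattern with the plain CUDA runtime API; the callbacks and their signatures are hypothetical placeholders, not nunchaku's actual interface.

// Sketch only: illustrates compute/copy overlap with two streams and events.
// funcLoad / funcCompute / funcUnload are placeholder per-layer callbacks.
#include <cuda_runtime.h>
#include <functional>

void runOffloaded(int numLayers,
                  const std::function<void(int, cudaStream_t)> &funcLoad,
                  const std::function<void(int, cudaStream_t)> &funcCompute,
                  const std::function<void(int)> &funcUnload) {
    cudaStream_t streamCompute, streamLoad;
    cudaEvent_t loadDone, computeDone;
    cudaStreamCreate(&streamCompute);
    cudaStreamCreate(&streamLoad);
    cudaEventCreate(&loadDone);
    cudaEventCreate(&computeDone);

    funcLoad(0, streamLoad);                          // prefetch weights of layer 0
    cudaEventRecord(loadDone, streamLoad);

    for (int i = 0; i < numLayers; i++) {
        cudaStreamWaitEvent(streamCompute, loadDone, 0); // wait until layer i is resident
        funcCompute(i, streamCompute);
        cudaEventRecord(computeDone, streamCompute);

        if (i + 1 < numLayers) {
            funcLoad(i + 1, streamLoad);              // next H2D copy overlaps this compute
            cudaEventRecord(loadDone, streamLoad);
        }
        cudaEventSynchronize(computeDone);            // do not free weights still in use
        funcUnload(i);
    }
    cudaStreamSynchronize(streamCompute);
    cudaStreamDestroy(streamCompute);
    cudaStreamDestroy(streamLoad);
    cudaEventDestroy(loadDone);
    cudaEventDestroy(computeDone);
}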
src/SanaModel.cpp
(formatting-only changes; the affected code is shown once)

@@ -10,18 +10,11 @@
using spdlog::fmt_lib::format;
using namespace nunchaku;

SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : dim(dim),
      dim_pad(ceilDiv(dim, 128) * 128),
      qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
      out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
      pag_to_v(std::nullopt) {
    registerChildren(qkv_proj, "qkv_proj")(out_proj, "out_proj");

    if (pag) {
        pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
...
@@ -57,21 +50,35 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
    Tensor q  = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
    Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());

    kernels::gemm_w4a4(qact.act, qkv_proj.qweight, {}, {}, qact.ascales, qkv_proj.wscales, {}, {},
                       qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
                       vk, q, qact.is_unsigned, qkv_proj.lora_scales, false,
                       qkv_proj.use_fp4, *qkv_proj.wtscale.data_ptr<float>(),
                       qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
                       {}, {}, {}, 0);

    debug("vk", vk);
    debug("q", q);
...
@@ -88,7 +95,6 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
        q = q_unpad;
    }

    // kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales);
    // return out_proj.forward(q);
...
@@ -129,17 +135,13 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
    return out;
}

MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : num_heads(num_heads),
      head_dim(head_dim),
      q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
      kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
      out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device) {
    registerChildren(q_linear, "q_linear")(kv_linear, "kv_linear")(out_proj, "out_proj");
}

Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt) {
...
@@ -161,16 +163,22 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
    Tensor k = kv.slice(1, 0, num_heads);
    Tensor v = kv.slice(1, num_heads, num_heads * 2);

    Tensor attn_output = mha_varlen_fwd(q,
                                        k,
                                        v,
                                        cu_seqlens_img,
                                        cu_seqlens_txt,
                                        num_tokens_img,
                                        num_tokens_txt,
                                        0.0f,
                                        pow(q.shape[-1], (-0.5)),
                                        false,
                                        false,
                                        -1,
                                        -1,
                                        false)
                            .front()
                            .view({batch_size, num_tokens_img, num_heads * head_dim});

    // Tensor attn_output = mha_fwd(q, k, v,
    //                              0.0f,
...
@@ -181,17 +189,13 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
    return out_proj.forward(attn_output);
}

SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : in_features(in_features),
      hidden_features(hidden_features),
      inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
      depth_conv(hidden_features * 2, true, dtype, device),
      point_conv(hidden_features, in_features, false, use_fp4, dtype, device) {
    registerChildren(inverted_conv, "inverted_conv")(depth_conv, "depth_conv")(point_conv, "point_conv");
}

Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
...
@@ -208,28 +212,34 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
    return point_conv.forward_quant(qact);
}

SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size,
                                                       int intermediate_size,
                                                       int num_cross_attention_heads,
                                                       bool pag,
                                                       bool use_fp4,
                                                       Tensor::ScalarType dtype,
                                                       Device device)
    : hidden_size(hidden_size),
      num_cross_attention_heads(num_cross_attention_heads),
      attn(hidden_size, false, pag, use_fp4, dtype, device),
      cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
      ff(hidden_size, intermediate_size, use_fp4, dtype, device),
      norm1(hidden_size, 1e-6, false, dtype, device),
      norm2(hidden_size, 1e-6, false, dtype, device) {
    this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device);

    registerChildren(attn, "attn")(cross_attn, "cross_attn")(ff, "ff");
    registerParams(this->scale_shift_table, "scale_shift_table");
}

Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
    nvtxRangePushA("SanaLinearTransformerBlock");
...
@@ -311,9 +321,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
    return hidden_states;
}

SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) : config(config) {
    const int inner_dim = config.num_attention_heads * config.attention_head_dim;
    for (int i = 0; i < config.num_layers; i++) {
        transformer_blocks.push_back(std::make_unique<SanaLinearTransformerBlock>(
...
@@ -322,20 +330,34 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
            config.num_cross_attention_heads,
            std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
            config.use_fp4,
            dtype,
            device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
    }
}

Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
    for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
        auto &&block = transformer_blocks[i];
        hidden_states = block->forward(hidden_states,
                                       encoder_hidden_states,
                                       timestep,
                                       cu_seqlens_img,
                                       cu_seqlens_txt,
                                       H,
                                       W,
                                       pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
                                       cfg);
    }
    return hidden_states;
}
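SanaLinearAttention (above) only builds its pag_to_v projection when PAG is enabled, holding it in a std::optional and constructing it in place with emplace(). A small sketch of that pattern follows, with a dummy Proj type standing in for the real projection module.

// Sketch only: Proj is a placeholder, not nunchaku's GEMM class. The optional
// member stays empty unless PAG is requested, and emplace() constructs it in
// place with no copy or move.
#include <iostream>
#include <optional>

struct Proj {
    int in, out;
    Proj(int in_, int out_) : in(in_), out(out_) {}
};

struct Attention {
    std::optional<Proj> pag_to_v; // empty unless PAG is enabled
    Attention(int dim, bool pag) : pag_to_v(std::nullopt) {
        if (pag) {
            pag_to_v.emplace(dim, dim);
        }
    }
};

int main() {
    Attention a(64, /*pag=*/true), b(64, /*pag=*/false);
    std::cout << a.pag_to_v.has_value() << " " << b.pag_to_v.has_value() << "\n"; // prints "1 0"
}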
src/SanaModel.h
(formatting-only changes; the affected code is shown once)

@@ -57,9 +57,23 @@ private:
class SanaLinearTransformerBlock : public Module {
public:
    SanaLinearTransformerBlock(int hidden_size,
                               int intermediate_size,
                               int num_cross_attention_heads,
                               bool pag,
                               bool use_fp4,
                               Tensor::ScalarType dtype,
                               Device device);
    Tensor forward(Tensor hidden_states,
                   Tensor encoder_hidden_states,
                   Tensor timestep,
                   Tensor cu_seqlens_img,
                   Tensor cu_seqlens_txt,
                   int H,
                   int W,
                   bool pag,
                   bool cfg);

public:
    const int hidden_size;
...
@@ -89,7 +103,16 @@ struct SanaConfig {
class SanaModel : public Module {
public:
    SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor hidden_states,
                   Tensor encoder_hidden_states,
                   Tensor timestep,
                   Tensor cu_seqlens_img,
                   Tensor cu_seqlens_txt,
                   int H,
                   int W,
                   bool pag,
                   bool cfg,
                   bool skip_first_layer);

public:
    const SanaConfig config;
...