fengzch-das / nunchaku / Commits / 37a27712

Unverified commit 37a27712, authored May 01, 2025 by Muyang Li; committed by GitHub on May 01, 2025.

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter

Parents: c1d6fc84, 760ab022
Changes (192): showing 20 changed files with 964 additions and 726 deletions (+964 −726).
scripts/build_docker.sh                      +1   −1
scripts/build_docker_torch27.sh              +1   −1
scripts/build_docker_torch28.sh              +1   −1
scripts/build_linux_wheel.sh                 +1   −1
scripts/build_linux_wheel_cu128.sh           +1   −1
scripts/build_linux_wheel_torch2.7_cu128.sh  +1   −1
scripts/linux_cleanup.sh                     +1   −1
setup.py                                     +1   −1
src/FluxModel.cpp                            +300 −268
src/FluxModel.h                              +54  −30
src/Linear.cpp                               +216 −143
src/Linear.h                                 +20  −10
src/Module.cpp                               +2   −2
src/Module.h                                 +28  −31
src/SanaModel.cpp                            +121 −99
src/SanaModel.h                              +29  −6
src/Serialization.cpp                        +37  −36
src/Serialization.h                          +7   −6
src/Tensor.h                                 +138 −85
src/activation.cpp                           +4   −2
scripts/build_docker.sh

@@ -39,4 +39,4 @@ docker build -f docker/Dockerfile --no-cache \
     -t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
 docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
-docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
+docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
scripts/build_docker_torch27.sh

@@ -27,4 +27,4 @@ docker build -f docker/Dockerfile.torch27 --no-cache \
     -t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
 docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
-docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
+docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
scripts/build_docker_torch28.sh

@@ -39,4 +39,4 @@ docker build -f docker/Dockerfile.torch28 --no-cache \
     -t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
 docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
-docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
+docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
scripts/build_linux_wheel.sh

@@ -35,4 +35,4 @@ docker run --rm \
     export NUNCHAKU_BUILD_WHEELS=1 && \
     export MAX_JOBS=${MAX_JOBS} && \
     ${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
-"
\ No newline at end of file
+"
scripts/build_linux_wheel_cu128.sh

@@ -33,4 +33,4 @@ docker run --rm \
     export NUNCHAKU_BUILD_WHEELS=1 && \
     export MAX_JOBS=${MAX_JOBS} && \
     ${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
-"
\ No newline at end of file
+"
scripts/build_linux_wheel_torch2.7_cu128.sh

@@ -33,4 +33,4 @@ docker run --rm \
     export NUNCHAKU_BUILD_WHEELS=1 && \
     export MAX_JOBS=${MAX_JOBS} && \
     ${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
-"
\ No newline at end of file
+"
scripts/linux_cleanup.sh

@@ -4,4 +4,4 @@ set -ex
 docker run --rm \
     -v "$(pwd)":/nunchaku \
     pytorch/manylinux-builder:cuda12.4 \
-    bash -c "cd /nunchaku && rm -rf *"
\ No newline at end of file
+    bash -c "cd /nunchaku && rm -rf *"
setup.py

@@ -6,7 +6,7 @@ import sys
 import setuptools
 import torch
 from packaging import version as packaging_version
-from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
+from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension


 class CustomBuildExtension(BuildExtension):
src/FluxModel.cpp

@@ -4,20 +4,18 @@
 #include "kernels/zgemm/zgemm.h"
 #include "flash_api.h"
 #include "activation.h"
 #include <nvtx3/nvToolsExt.h>
 #include <pybind11/functional.h>
 #include <iostream>

 using spdlog::fmt_lib::format;
 using namespace nunchaku;

 Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
-    Tensor ff_output = fc2.forward_quant(
-        std::get<GEMM_W4A4::QuantizedActivation>(fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2))
-    );
+    Tensor ff_output = fc2.forward_quant(
+        std::get<GEMM_W4A4::QuantizedActivation>(fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)));
     return ff_output;
 }
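A note on why forward_mlp looks the way it does: fc1's W4A4 GEMM fuses the GELU activation and the re-quantization for the next layer into its epilogue (FuseOptions::GELU_QUANT takes &fc2), so fc2's forward_quant consumes an already-quantized activation rather than a full-precision tensor. A minimal standalone sketch of that data flow with toy stand-in types — nothing below is the real GEMM_W4A4 API:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy stand-in; the real QuantizedActivation lives in Linear.h.
struct QuantizedActivation {
    std::vector<int8_t> act;
    float scale;
};

// What a fused fc1 epilogue conceptually does: apply GELU, then quantize
// directly in the format the next GEMM expects, skipping a full-precision
// round trip through memory.
QuantizedActivation gelu_quant(const std::vector<float> &x) {
    std::vector<float> g(x.size());
    for (size_t i = 0; i < x.size(); i++)
        g[i] = 0.5f * x[i] * (1.0f + std::erf(x[i] / std::sqrt(2.0f))); // exact GELU
    float amax = 1e-8f;
    for (float v : g)
        amax = std::max(amax, std::abs(v));
    QuantizedActivation q{std::vector<int8_t>(g.size()), amax / 127.0f};
    for (size_t i = 0; i < g.size(); i++)
        q.act[i] = (int8_t)std::lround(g[i] / q.scale);
    return q;
}

int main() {
    QuantizedActivation q = gelu_quant({-1.0f, 0.0f, 2.0f});
    // A fc2.forward_quant(q) would consume q.act/q.scale directly.
    printf("scale=%g act=[%d %d %d]\n", q.scale, q.act[0], q.act[1], q.act[2]);
}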
@@ -26,7 +24,6 @@ Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
 //     return ff_output;
 // }

 Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
     return fc.forward(x);
     // return std::get<Tensor>(fc.forward(x));
...

@@ -36,16 +33,9 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
 //     return fc.forward(x);
 // }

-AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device) :
-    dim(dim),
-    linear(dim, 3 * dim, true, dtype, device),
-    norm(dim, 1e-6, false, dtype, device)
-{
-    registerChildren
-        (linear, "linear")
-        (norm, "norm")
-    ;
+AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device)
+    : dim(dim), linear(dim, 3 * dim, true, dtype, device), norm(dim, 1e-6, false, dtype, device) {
+    registerChildren(linear, "linear")(norm, "norm");
 }

 AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
...
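The registerChildren(...)(...) pattern that this and the following constructors keep reformatting is a chained registrar: each call registers one child module under a name and returns the registrar, so an arbitrary number of (child, name) pairs can follow. A small self-contained sketch of the idiom — the real signature in Module.h may differ:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

struct Module; // forward declaration for the sketch

// Each operator() records one (child, name) pair and returns *this, which is
// what makes registerChildren(a, "a")(b, "b")(c, "c"); legal.
struct ChildRegistrar {
    std::vector<std::pair<Module *, std::string>> &children;
    ChildRegistrar &operator()(Module &child, const std::string &name) {
        children.push_back({&child, name});
        return *this;
    }
};

struct Module {
    std::vector<std::pair<Module *, std::string>> children;
    ChildRegistrar registerChildren(Module &child, const std::string &name) {
        ChildRegistrar reg{children};
        reg(child, name);
        return reg;
    }
};

int main() {
    Module block, linear, norm;
    block.registerChildren(linear, "linear")(norm, "norm");
    for (auto &c : block.children)
        printf("%s\n", c.second.c_str()); // linear, norm
}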
@@ -65,15 +55,10 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
     return Output{norm_x, gate_msa};
 }

-AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device) :
-    dim(dim),
-    pre_only(pre_only),
-    linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
-    norm(dim, 1e-6, false, dtype, device)
-{
-    registerChildren
-        (linear, "linear")
-        (norm, "norm")
-    ;
+AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device)
+    : dim(dim), pre_only(pre_only), linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
+      norm(dim, 1e-6, false, dtype, device) {
+    registerChildren(linear, "linear")(norm, "norm");
 }

 AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
...

@@ -110,10 +95,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
     }
 }

-Attention::Attention(int num_heads, int dim_head, Device device) :
-    num_heads(num_heads),
-    dim_head(dim_head),
-    force_fp16(false)
-{
+Attention::Attention(int num_heads, int dim_head, Device device)
+    : num_heads(num_heads), dim_head(dim_head), force_fp16(false) {
     headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
     for (int i = 0; i < num_heads; i++) {
         headmask_type.data_ptr<int32_t>()[i] = i + 1;
...

@@ -124,27 +107,23 @@ Attention::Attention(int num_heads, int dim_head, Device device) :
 Tensor Attention::forward(Tensor qkv) {
     assert(qkv.ndims() == 3);

-    const Device device = qkv.device();
+    const Device device  = qkv.device();
     const int batch_size = qkv.shape[0];
     const int num_tokens = qkv.shape[1];
     assert(qkv.shape[2] == num_heads * dim_head * 3);

     Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
-    Tensor q = reshaped.slice(2, 0, num_heads);
-    Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
-    Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);
+    Tensor q        = reshaped.slice(2, 0, num_heads);
+    Tensor k        = reshaped.slice(2, num_heads, num_heads * 2);
+    Tensor v        = reshaped.slice(2, num_heads * 2, num_heads * 3);

-    Tensor raw_attn_output = mha_fwd(
-        q, k, v,
-        0.0f, pow(q.shape[-1], (-0.5)),
-        false, -1, -1, false
-    ).front();
+    Tensor raw_attn_output =
+        mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false).front();

     assert(raw_attn_output.shape[0] == batch_size);
     assert(raw_attn_output.shape[1] == num_tokens);
     assert(raw_attn_output.shape[2] == num_heads);
     assert(raw_attn_output.shape[3] == dim_head);

     return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
 }
...

@@ -153,13 +132,13 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
     assert(qkv.ndims() == 3);

-    const Device device = qkv.device();
+    const Device device  = qkv.device();
     const int batch_size = qkv.shape[0];
     const int num_tokens = qkv.shape[1];
     assert(qkv.shape[2] == num_heads * dim_head * 3);

     constexpr int POOL_SIZE = 128;
-    const int pool_tokens = ceilDiv(num_tokens, POOL_SIZE);
+    const int pool_tokens   = ceilDiv(num_tokens, POOL_SIZE);
     Tensor blockmask;
...

@@ -173,11 +152,11 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
     if (pool_qkv.valid() && sparsityRatio > 0) {
         pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
-        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3); // [batch_size, 3, num_heads, poolTokens, dim_head]
+        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3); // [batch_size, 3, num_heads, poolTokens, dim_head]
         for (int i = 0; i < batch_size; i++) {
-            Tensor pool_q = pool_qkv.slice(0, i, i + 1).slice(1, 0, 1);
-            Tensor pool_k = pool_qkv.slice(0, i, i + 1).slice(1, 1, 2);
-            Tensor pool_s = pool_score.slice(0, i, i + 1);
+            Tensor pool_q = pool_qkv.slice(0, i, i + 1).slice(1, 0, 1);
+            Tensor pool_k = pool_qkv.slice(0, i, i + 1).slice(1, 1, 2);
+            Tensor pool_s = pool_score.slice(0, i, i + 1);
             gemm_batched_fp16(pool_q, pool_k, pool_s);
         }
     }
...
@@ -197,7 +176,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
         }
     }

     if (!cu_seqlens_cpu.valid()) {
-        cu_seqlens_cpu = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
+        cu_seqlens_cpu                        = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
         cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
         for (int i = 1; i <= batch_size; i++) {
             cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
...
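The cu_seqlens tensor built above follows the standard varlen-attention convention: entry i is the cumulative token count of the first i sequences, so sequence i occupies the range [cu_seqlens[i], cu_seqlens[i+1]). Since every sequence in this batch has the same num_tokens, the host-side loop reduces to the arithmetic below (a standalone sketch with hypothetical sizes):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical: batch of 3 equal-length sequences, 4 tokens each.
    const int batch_size = 3, num_tokens = 4;
    std::vector<int32_t> cu_seqlens(batch_size + 1);
    cu_seqlens[0] = 0;
    for (int i = 1; i <= batch_size; i++)
        cu_seqlens[i] = cu_seqlens[i - 1] + num_tokens; // same recurrence as the diff
    for (int32_t v : cu_seqlens)
        printf("%d ", v); // prints: 0 4 8 12
    printf("\n");
}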
@@ -215,25 +194,32 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
     Tensor cu_seqlens = cu_seqlens_cpu.copy(device);

     Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
-    Tensor q = reshaped.slice(1, 0, num_heads);
-    Tensor k = reshaped.slice(1, num_heads, num_heads * 2);
-    Tensor v = reshaped.slice(1, num_heads * 2, num_heads * 3);
+    Tensor q        = reshaped.slice(1, 0, num_heads);
+    Tensor k        = reshaped.slice(1, num_heads, num_heads * 2);
+    Tensor v        = reshaped.slice(1, num_heads * 2, num_heads * 3);

     spdlog::debug("q,k,v={}", q.shape.str());

-    Tensor raw_attn_output = mha_fwd_block(q, k, v, cu_seqlens, cu_seqlens, POOL_SIZE, POOL_SIZE, headmask_type, {}, blockmask, num_tokens, num_tokens, 0.0f, pow(q.shape[-1], (-0.5)), false, false, false, -1, -1).front();
+    Tensor raw_attn_output = mha_fwd_block(q,
+                                           k,
+                                           v,
+                                           cu_seqlens,
+                                           cu_seqlens,
+                                           POOL_SIZE,
+                                           POOL_SIZE,
+                                           headmask_type,
+                                           {},
+                                           blockmask,
+                                           num_tokens,
+                                           num_tokens,
+                                           0.0f,
+                                           pow(q.shape[-1], (-0.5)),
+                                           false,
+                                           false,
+                                           false,
+                                           -1,
+                                           -1)
+                                 .front();

     debug("raw_attn_output", raw_attn_output);
...

@@ -290,30 +276,22 @@ void Attention::setForceFP16(Module *module, bool value) {
     });
 }

-FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
-    dim(dim),
-    dim_head(attention_head_dim / num_attention_heads),
-    num_heads(num_attention_heads),
-    mlp_hidden_dim(dim * mlp_ratio),
-    norm(dim, dtype, device),
-    mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
-    mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
-    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
-    norm_q(dim_head, 1e-6, false, dtype, device),
-    norm_k(dim_head, 1e-6, false, dtype, device),
-    attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, use_fp4, dtype, device)
-{
-    registerChildren
-        (norm, "norm")
-        (mlp_fc1, "mlp_fc1")
-        (mlp_fc2, "mlp_fc2")
-        (qkv_proj, "qkv_proj")
-        (norm_q, "norm_q")
-        (norm_k, "norm_k")
-        (attn, "attn")
-        (out_proj, "out_proj")
-    ;
+FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim,
+                                                       int num_attention_heads,
+                                                       int attention_head_dim,
+                                                       int mlp_ratio,
+                                                       bool use_fp4,
+                                                       Tensor::ScalarType dtype,
+                                                       Device device)
+    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
+      mlp_hidden_dim(dim * mlp_ratio), norm(dim, dtype, device),
+      mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
+      mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
+      qkv_proj(dim, dim * 3, true, use_fp4, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
+      norm_k(dim_head, 1e-6, false, dtype, device),
+      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
+      out_proj(dim, dim, true, use_fp4, dtype, device) {
+    registerChildren(norm, "norm")(mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(qkv_proj, "qkv_proj")(norm_q, "norm_q")(
+        norm_k, "norm_k")(attn, "attn")(out_proj, "out_proj");
 }

 Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {
...

@@ -334,12 +312,18 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     debug("rotary_emb", rotary_emb);

     if (attnImpl == AttentionImpl::FlashAttention2) {
-        Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
+        Tensor qkv = Tensor::allocate(
+            {batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());

         // qkv_proj.forward(norm_hidden_states, qkv, {});
         // debug("qkv_raw", qkv);

         for (int i = 0; i < batch_size; i++) {
-            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1), qkv.slice(0, i, i + 1), {}, norm_q.weight, norm_k.weight, rotary_emb);
+            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
+                             qkv.slice(0, i, i + 1),
+                             {},
+                             norm_q.weight,
+                             norm_k.weight,
+                             rotary_emb);
         }
         debug("qkv", qkv);

         // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
...

@@ -352,24 +336,33 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
         const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

-        Tensor q = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
-        Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
-        Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
+        Tensor q = Tensor::allocate(
+            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
+        Tensor k = Tensor::allocate(
+            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
+        Tensor v = Tensor::allocate(
+            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

         for (int i = 0; i < batch_size; i++) {
-            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb, q.slice(0, i, i + 1), k.slice(0, i, i + 1), v.slice(0, i, i + 1), num_tokens);
+            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
+                             {},
+                             {},
+                             norm_q.weight,
+                             norm_k.weight,
+                             rotary_emb,
+                             q.slice(0, i, i + 1),
+                             k.slice(0, i, i + 1),
+                             v.slice(0, i, i + 1),
+                             num_tokens);
         }

         debug("packed_q", q);
         debug("packed_k", k);
         debug("packed_v", v);

-        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
+        Tensor o = Tensor::allocate(
+            {batch_size, num_tokens_pad, num_heads * dim_head}, norm_hidden_states.scalar_type(), norm_hidden_states.device());

         kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));
...

@@ -377,16 +370,14 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
             attn_output = o.slice(1, 0, num_tokens);
         } else {
-            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
-            checkCUDA(cudaMemcpy2DAsync(
-                attn_output.data_ptr(), attn_output.stride(0) * attn_output.scalar_size(),
-                o.data_ptr(), o.stride(0) * o.scalar_size(),
-                attn_output.stride(0) * attn_output.scalar_size(),
-                batch_size,
-                cudaMemcpyDeviceToDevice, getCurrentCUDAStream()
-            ));
+            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
+            checkCUDA(cudaMemcpy2DAsync(attn_output.data_ptr(),
+                                        attn_output.stride(0) * attn_output.scalar_size(),
+                                        o.data_ptr(),
+                                        o.stride(0) * o.scalar_size(),
+                                        attn_output.stride(0) * attn_output.scalar_size(),
+                                        batch_size,
+                                        cudaMemcpyDeviceToDevice,
+                                        getCurrentCUDAStream()));
         }
     } else {
         assert(false);
...

@@ -394,8 +385,6 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     debug("raw_attn_output", attn_output);

     attn_output = forward_fc(out_proj, attn_output);
     debug("attn_output", attn_output);
...
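Two bits of arithmetic recur throughout these blocks: the per-head width dim_head = attention_head_dim / num_attention_heads, and token counts padded up to a multiple of 256 before the fused FP16 attention path allocates q/k/v. A standalone check using the block sizes FluxModel actually constructs (3072 width, 24 heads; the token count is hypothetical):

#include <cstdio>

static int ceilDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Constructor arguments FluxModel passes to every block.
    const int dim = 3072, num_attention_heads = 24, attention_head_dim = 3072, mlp_ratio = 4;
    printf("dim_head = %d\n", attention_head_dim / num_attention_heads); // 128
    printf("mlp_hidden_dim = %d\n", dim * mlp_ratio);                    // 12288

    // Padding used before allocating q/k/v on the fused attention path.
    const int num_tokens = 1000; // hypothetical
    printf("num_tokens_pad = %d\n", ceilDiv(num_tokens, 256) * 256);     // 1024
}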
@@ -413,54 +402,40 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     return hidden_states;
 }

-JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
-    dim(dim),
-    dim_head(attention_head_dim / num_attention_heads),
-    num_heads(num_attention_heads),
-    context_pre_only(context_pre_only),
-    norm1(dim, false, dtype, device),
-    norm1_context(dim, context_pre_only, dtype, device),
-    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
-    qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device),
-    norm_q(dim_head, 1e-6, false, dtype, device),
-    norm_k(dim_head, 1e-6, false, dtype, device),
-    norm_added_q(dim_head, 1e-6, false, dtype, device),
-    norm_added_k(dim_head, 1e-6, false, dtype, device),
-    attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, use_fp4, dtype, device),
-    out_proj_context(dim, dim, true, use_fp4, dtype, device),
-    norm2(dim, 1e-6, false, dtype, device),
-    norm2_context(dim, 1e-6, false, dtype, device),
-    mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device),
-    mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
-    mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
-    mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device)
-{
-    registerChildren
-        (norm1, "norm1")
-        (norm1_context, "norm1_context")
-        (qkv_proj, "qkv_proj")
-        (qkv_proj_context, "qkv_proj_context")
-        (norm_q, "norm_q")
-        (norm_k, "norm_k")
-        (norm_added_q, "norm_added_q")
-        (norm_added_k, "norm_added_k")
-        (attn, "attn")
-        (out_proj, "out_proj")
-        (out_proj_context, "out_proj_context")
-        (norm2, "norm2")
-        (norm2_context, "norm2_context")
-        (mlp_fc1, "mlp_fc1")
-        (mlp_fc2, "mlp_fc2")
-        (mlp_context_fc1, "mlp_context_fc1")
-        (mlp_context_fc2, "mlp_context_fc2")
-    ;
+JointTransformerBlock::JointTransformerBlock(int dim,
+                                             int num_attention_heads,
+                                             int attention_head_dim,
+                                             bool context_pre_only,
+                                             bool use_fp4,
+                                             Tensor::ScalarType dtype,
+                                             Device device)
+    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
+      context_pre_only(context_pre_only), norm1(dim, false, dtype, device),
+      norm1_context(dim, context_pre_only, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
+      qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
+      norm_k(dim_head, 1e-6, false, dtype, device), norm_added_q(dim_head, 1e-6, false, dtype, device),
+      norm_added_k(dim_head, 1e-6, false, dtype, device),
+      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
+      out_proj(dim, dim, true, use_fp4, dtype, device), out_proj_context(dim, dim, true, use_fp4, dtype, device),
+      norm2(dim, 1e-6, false, dtype, device), norm2_context(dim, 1e-6, false, dtype, device),
+      mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device), mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
+      mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
+      mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device) {
+    registerChildren(norm1, "norm1")(norm1_context, "norm1_context")(qkv_proj, "qkv_proj")(
+        qkv_proj_context, "qkv_proj_context")(norm_q, "norm_q")(norm_k, "norm_k")(norm_added_q, "norm_added_q")(
+        norm_added_k, "norm_added_k")(attn, "attn")(out_proj, "out_proj")(out_proj_context, "out_proj_context")(
+        norm2, "norm2")(norm2_context, "norm2_context")(mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(
+        mlp_context_fc1, "mlp_context_fc1")(mlp_context_fc2, "mlp_context_fc2");
 }

 // hidden_states: [Batch, Width * Height, dim]
 // encoder_hidden_states: [Batch, Token, dim]
-std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio) {
+std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
+                                                          Tensor encoder_hidden_states,
+                                                          Tensor temb,
+                                                          Tensor rotary_emb,
+                                                          Tensor rotary_emb_context,
+                                                          float sparsityRatio) {
     int batch_size = hidden_states.shape[0];
     assert(encoder_hidden_states.shape[0] == batch_size);
...

@@ -468,17 +443,19 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     nvtxRangePushA("AdaNorm");

     int num_tokens_img = hidden_states.shape[1];
     int num_tokens_txt = encoder_hidden_states.shape[1];
     assert(hidden_states.shape[2] == dim);
     assert(encoder_hidden_states.shape[2] == dim);

-    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}", hidden_states.shape.str(), encoder_hidden_states.shape.str(), temb.shape.str());
+    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
+                  hidden_states.shape.str(),
+                  encoder_hidden_states.shape.str(),
+                  temb.shape.str());
     spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

-    auto norm1_output = norm1.forward(hidden_states, temb);
+    auto norm1_output         = norm1.forward(hidden_states, temb);
     auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

 #if 0
...

@@ -511,30 +488,37 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
         const bool blockSparse = sparsityRatio > 0;

         const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
-        concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
-        pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device()) : Tensor{};
+        concat = Tensor::allocate(
+            {batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
+        pool   = blockSparse
+                     ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
+                     : Tensor{};

         for (int i = 0; i < batch_size; i++) {
             // img first
             Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
-            Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
+            Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

-            Tensor pool_qkv = pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
-            Tensor pool_qkv_context = pool.valid() ? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE) : Tensor{};
+            Tensor pool_qkv = pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
+            Tensor pool_qkv_context =
+                pool.valid()
+                    ? pool.slice(0, i, i + 1)
+                          .slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
+                    : Tensor{};

             // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
             // debug("qkv_raw", qkv);

             debug("rotary_emb", rotary_emb);
-            qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
+            qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
             debug("qkv", qkv);

             // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
...

@@ -542,7 +526,12 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
             debug("rotary_emb_context", rotary_emb_context);
-            qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
+            qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
+                                     qkv_context,
+                                     pool_qkv_context,
+                                     norm_added_q.weight,
+                                     norm_added_k.weight,
+                                     rotary_emb_context);
             debug("qkv_context", qkv_context);
         }
...

@@ -577,28 +566,40 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     {
         nvtxRangePushA("qkv_proj");

-        concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
+        concat_q = Tensor::allocate(
+            {batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
         concat_k = Tensor::empty_like(concat_q);
         concat_v = Tensor::empty_like(concat_q);

         for (int i = 0; i < batch_size; i++) {
             // img first
-            auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
-            auto sliceTxt = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad); };
+            auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
+            auto sliceTxt = [&](Tensor x) {
+                return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
+            };

-            qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb, sliceImg(concat_q), sliceImg(concat_k), sliceImg(concat_v), num_tokens_img);
-            qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), {}, {}, norm_added_q.weight, norm_added_k.weight, rotary_emb_context, sliceTxt(concat_q), sliceTxt(concat_k), sliceTxt(concat_v), num_tokens_txt);
+            qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
+                             {},
+                             {},
+                             norm_q.weight,
+                             norm_k.weight,
+                             rotary_emb,
+                             sliceImg(concat_q),
+                             sliceImg(concat_k),
+                             sliceImg(concat_v),
+                             num_tokens_img);
+            qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
+                                     {},
+                                     {},
+                                     norm_added_q.weight,
+                                     norm_added_k.weight,
+                                     rotary_emb_context,
+                                     sliceTxt(concat_q),
+                                     sliceTxt(concat_k),
+                                     sliceTxt(concat_v),
+                                     num_tokens_txt);
         }

         debug("concat_q", concat_q);
...

@@ -608,7 +609,9 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
         nvtxRangePop();
     }

-    raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head}, norm1_output.x.scalar_type(), norm1_output.x.device());
+    raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
+                                       norm1_output.x.scalar_type(),
+                                       norm1_output.x.device());

     nvtxRangePushA("Attention");
...

@@ -616,7 +619,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     nvtxRangePop();

-    raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
+    raw_attn_output =
+        raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
     } else {
         assert(false);
     }
...
@@ -632,25 +636,28 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
         Tensor raw_attn_output_split;
         if (batch_size == 1) {
-            raw_attn_output_split = raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
+            raw_attn_output_split =
+                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
         } else {
-            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
-            checkCUDA(cudaMemcpy2DAsync(
-                raw_attn_output_split.data_ptr(), num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
-                raw_attn_output.data_ptr(), (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
-                num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
-                batch_size,
-                cudaMemcpyDeviceToDevice, stream));
+            raw_attn_output_split = Tensor::allocate(
+                {batch_size, num_tokens_img, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
+            checkCUDA(cudaMemcpy2DAsync(
+                raw_attn_output_split.data_ptr(),
+                num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                raw_attn_output.data_ptr(),
+                (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
+                num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                batch_size,
+                cudaMemcpyDeviceToDevice,
+                stream));
         }

         spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
         debug("img.raw_attn_output_split", raw_attn_output_split);

-        Tensor attn_output = forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
+        Tensor attn_output =
+            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
         debug("img.attn_output", attn_output);

 #if 1
...
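The cudaMemcpy2DAsync calls in this hunk are doing a batched strided gather, not an image copy: the "height" argument is the batch size, the source pitch is one padded row (img plus txt tokens), and only the leading img-token bytes of each row are kept. A CPU analogue makes the pitch arithmetic easy to check (a sketch; the real call runs asynchronously on the CUDA stream):

#include <cstddef>
#include <cstdio>
#include <cstring>

// CPU analogue of cudaMemcpy2DAsync: copy `height` rows of `width` bytes,
// where source rows are `spitch` bytes apart and destination rows `dpitch`.
static void memcpy2d(void *dst, size_t dpitch, const void *src, size_t spitch, size_t width, size_t height) {
    for (size_t r = 0; r < height; r++)
        std::memcpy((char *)dst + r * dpitch, (const char *)src + r * spitch, width);
}

int main() {
    // Hypothetical: batch of 2; each padded row holds 8 floats, keep the first 5.
    float src[2][8] = {{0, 1, 2, 3, 4, 5, 6, 7}, {10, 11, 12, 13, 14, 15, 16, 17}};
    float dst[2][5];
    memcpy2d(dst, sizeof(dst[0]), src, sizeof(src[0]), sizeof(dst[0]), 2);
    printf("%g %g\n", dst[1][0], dst[1][4]); // prints: 10 14
}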
@@ -690,7 +697,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     }

     if (context_pre_only) {
-        return {hidden_states, encoder_hidden_states};
+        return {hidden_states, encoder_hidden_states};
     }

     {
...

@@ -700,25 +707,30 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
         Tensor raw_attn_output_split;
         if (batch_size == 1) {
-            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt).reshape({batch_size, num_tokens_txt, num_heads * dim_head});
+            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
+                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
         } else {
-            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
-            checkCUDA(cudaMemcpy2DAsync(
-                raw_attn_output_split.data_ptr(), num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
-                raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head * raw_attn_output_split.scalar_size(),
-                (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
-                num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
-                batch_size,
-                cudaMemcpyDeviceToDevice, stream));
+            raw_attn_output_split = Tensor::allocate(
+                {batch_size, num_tokens_txt, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
+            checkCUDA(cudaMemcpy2DAsync(
+                raw_attn_output_split.data_ptr(),
+                num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                raw_attn_output.data_ptr<char>() +
+                    num_tokens_img_pad * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
+                num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                batch_size,
+                cudaMemcpyDeviceToDevice,
+                stream));
         }
         spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
         debug("context.raw_attn_output_split", raw_attn_output_split);

-        Tensor attn_output = forward_fc(out_proj_context, raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
+        Tensor attn_output = forward_fc(
+            out_proj_context, raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
         debug("context.attn_output", attn_output);

 #if 1
...

@@ -742,9 +754,9 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
         auto norm_hidden_states = encoder_hidden_states;
 #endif

         // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
-        // Tensor ff_output = mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
+        // Tensor ff_output =
+        //     mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
         debug("context.ff_input", norm_hidden_states);
         Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
         debug("context.ff_output", ff_output);
...

@@ -761,12 +773,14 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     nvtxRangePop();

-    return {hidden_states, encoder_hidden_states};
+    return {hidden_states, encoder_hidden_states};
 }

-FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) :
-    dtype(dtype), offload(offload)
-{
+FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
+    : dtype(dtype), offload(offload) {
     for (int i = 0; i < 19; i++) {
-        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
+        transformer_blocks.push_back(
+            std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
         if (offload && i > 0) { // don't offload first block
             transformer_blocks.back()->setLazyLoad(true);
...

@@ -774,7 +788,8 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
         }
     }
     for (int i = 0; i < 38; i++) {
-        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device));
+        single_transformer_blocks.push_back(
+            std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device));
         registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
         if (offload) {
             single_transformer_blocks.back()->setLazyLoad(true);
...

@@ -783,19 +798,18 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
     }
 }

-Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, Tensor controlnet_block_samples, Tensor controlnet_single_block_samples, bool skip_first_layer) {
-    const int batch_size = hidden_states.shape[0];
+Tensor FluxModel::forward(Tensor hidden_states,
+                          Tensor encoder_hidden_states,
+                          Tensor temb,
+                          Tensor rotary_emb_img,
+                          Tensor rotary_emb_context,
+                          Tensor rotary_emb_single,
+                          Tensor controlnet_block_samples,
+                          Tensor controlnet_single_block_samples,
+                          bool skip_first_layer) {
+    const int batch_size           = hidden_states.shape[0];
     const Tensor::ScalarType dtype = hidden_states.dtype();
-    const Device device = hidden_states.device();
+    const Device device            = hidden_states.device();

     const int txt_tokens = encoder_hidden_states.shape[1];
     const int img_tokens = hidden_states.shape[1];
...

@@ -805,45 +819,68 @@ Tensor FluxModel::forward(
     Tensor concat;

     auto compute = [&](int layer) {
-        if (skip_first_layer && size_t(layer) == 0) return;
+        if (skip_first_layer && size_t(layer) == 0)
+            return;
         if (size_t(layer) < transformer_blocks.size()) {
             auto &block = transformer_blocks.at(layer);
-            std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+            std::tie(hidden_states, encoder_hidden_states) =
+                block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+
             if (controlnet_block_samples.valid()) {
                 const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
-                int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
+                int interval_control =
+                    ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
                 int block_index = layer / interval_control;
                 // Xlabs ControlNet
                 // block_index = layer % num_controlnet_block_samples;
                 hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
             }
+
+            if (residual_callback && layer % 2 == 0) {
+                Tensor cpu_input = hidden_states.copy(Device::cpu());
+                pybind11::gil_scoped_acquire gil;
+                Tensor cpu_output = residual_callback(cpu_input);
+                Tensor residual   = cpu_output.copy(Device::cuda());
+                hidden_states     = kernels::add(hidden_states, residual);
+            }
         } else {
             if (size_t(layer) == transformer_blocks.size()) {
                 // txt first, same as diffusers
                 concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
                 for (int i = 0; i < batch_size; i++) {
                     concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
-                    concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states.slice(0, i, i + 1));
+                    concat.slice(0, i, i + 1)
+                        .slice(1, txt_tokens, txt_tokens + img_tokens)
+                        .copy_(hidden_states.slice(0, i, i + 1));
                 }
-                hidden_states = concat;
+                hidden_states         = concat;
                 encoder_hidden_states = {};
             }

-            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
+            auto &block   = single_transformer_blocks.at(layer - transformer_blocks.size());
             hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
+
             if (controlnet_single_block_samples.valid()) {
                 const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
-                int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
+                int interval_control =
+                    ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                 int block_index = (layer - transformer_blocks.size()) / interval_control;
                 // Xlabs ControlNet
                 // block_index = layer % num_controlnet_single_block_samples
                 auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
-                slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
+                slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
                 hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
             }
+
+            size_t local_layer_idx = layer - transformer_blocks.size();
+            if (residual_callback && local_layer_idx % 4 == 0) {
+                Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
+                Tensor cpu_input      = callback_input.copy(Device::cpu());
+                pybind11::gil_scoped_acquire gil;
+                Tensor cpu_output = residual_callback(cpu_input);
+                Tensor residual   = cpu_output.copy(Device::cuda());
+                auto slice        = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
+                slice             = kernels::add(slice, residual);
+                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
+            }
         }
...

@@ -873,31 +910,22 @@ Tensor FluxModel::forward(
     return hidden_states;
 }

-std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer, Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor controlnet_block_samples, Tensor controlnet_single_block_samples) {
-    if (layer < transformer_blocks.size()){
-        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
-            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
-    } else {
-        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer - transformer_blocks.size())->forward(
-            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
-    }
+std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
+                                                    Tensor hidden_states,
+                                                    Tensor encoder_hidden_states,
+                                                    Tensor temb,
+                                                    Tensor rotary_emb_img,
+                                                    Tensor rotary_emb_context,
+                                                    Tensor controlnet_block_samples,
+                                                    Tensor controlnet_single_block_samples) {
+    if (layer < transformer_blocks.size()) {
+        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
+            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+    } else {
+        std::tie(hidden_states, encoder_hidden_states) =
+            transformer_blocks.at(layer - transformer_blocks.size())
+                ->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+    }

     const int txt_tokens = encoder_hidden_states.shape[1];
...
@@ -907,7 +935,7 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(
         const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
         int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
-        int block_index = layer / interval_control;
+        int block_index      = layer / interval_control;
         // Xlabs ControlNet
         // block_index = layer % num_controlnet_block_samples;
...
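The interval mapping above stretches a short list of ControlNet residuals over a longer stack of blocks: with the 19 joint blocks FluxModel builds and, say, 4 controlnet_block_samples, ceilDiv gives an interval of 5, so layers 0-4 reuse sample 0, layers 5-9 sample 1, and so on. A standalone check (the sample count is hypothetical):

#include <cstddef>
#include <cstdio>

// Same round-up division the diff uses.
static size_t ceilDiv(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
    const size_t num_blocks  = 19; // FluxModel's joint transformer blocks
    const size_t num_samples = 4;  // hypothetical controlnet_block_samples.shape[0]
    const size_t interval    = ceilDiv(num_blocks, num_samples); // = 5
    for (size_t layer = 0; layer < num_blocks; layer++)
        printf("layer %zu -> controlnet sample %zu\n", layer, layer / interval);
}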
@@ -915,17 +943,18 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(
     } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
         const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
-        int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
+        int interval_control =
+            ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
         int block_index = (layer - transformer_blocks.size()) / interval_control;
         // Xlabs ControlNet
         // block_index = layer % num_controlnet_single_block_samples
         auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
-        slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
+        slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
         hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
     }

-    return {hidden_states, encoder_hidden_states};
+    return {hidden_states, encoder_hidden_states};
 }

 void FluxModel::setAttentionImpl(AttentionImpl impl) {
...

@@ -936,3 +965,6 @@ void FluxModel::setAttentionImpl(AttentionImpl impl) {
         block->attnImpl = impl;
     }
 }
+
+void FluxModel::set_residual_callback(std::function<Tensor(const Tensor &)> cb) {
+    residual_callback = std::move(cb);
+}
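set_residual_callback is the hook that makes the PuLID-style injection added in this commit possible: the caller registers a std::function<Tensor(const Tensor &)>, and the forward pass adds its output to the hidden states every second joint block (and every fourth single block, image tokens only), round-tripping through the CPU under the GIL. A dependency-free sketch of the mechanism with a toy Tensor — the real code copies between Device::cpu() and Device::cuda() and takes pybind11's GIL:

#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

using Tensor = std::vector<float>; // toy stand-in for the sketch

struct Model {
    std::function<Tensor(const Tensor &)> residual_callback;
    void set_residual_callback(std::function<Tensor(const Tensor &)> cb) { residual_callback = std::move(cb); }

    Tensor forward(Tensor hidden, int num_layers) {
        for (int layer = 0; layer < num_layers; layer++) {
            // ... block forward would run here ...
            if (residual_callback && layer % 2 == 0) {       // same cadence as the joint blocks
                Tensor residual = residual_callback(hidden); // real code: copy to CPU, take GIL, copy back
                for (size_t i = 0; i < hidden.size(); i++)
                    hidden[i] += residual[i];
            }
        }
        return hidden;
    }
};

int main() {
    Model m;
    m.set_residual_callback([](const Tensor &x) { return Tensor(x.size(), 0.5f); });
    Tensor out = m.forward({1.0f, 2.0f}, 4); // layers 0 and 2 fire the callback
    printf("%g %g\n", out[0], out[1]);       // prints: 2 3
}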
src/FluxModel.h

@@ -5,6 +5,10 @@
 #include "Module.h"
 #include "Linear.h"
 #include "layernorm.h"
+#include <pybind11/functional.h>
+
+namespace pybind11 {
+class function;
+}

 enum class AttentionImpl {
     FlashAttention2 = 0,
...

@@ -14,7 +18,7 @@ enum class AttentionImpl {
 class AdaLayerNormZeroSingle : public Module {
 public:
     static constexpr bool USE_4BIT = true;
-    using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;
+    using GEMM                     = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;

     struct Output {
         Tensor x;
...

@@ -36,7 +40,7 @@ private:
 class AdaLayerNormZero : public Module {
 public:
     static constexpr bool USE_4BIT = true;
-    using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;
+    using GEMM                     = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;

     struct Output {
         Tensor x;
...

@@ -45,6 +49,7 @@ public:
         Tensor scale_mlp;
         Tensor gate_mlp;
     };
+
 public:
     AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device);
     Output forward(Tensor x, Tensor emb);
...

@@ -81,9 +86,15 @@ private:
 class FluxSingleTransformerBlock : public Module {
 public:
     static constexpr bool USE_4BIT = true;
-    using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
+    using GEMM                     = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

-    FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device);
+    FluxSingleTransformerBlock(int dim,
+                               int num_attention_heads,
+                               int attention_head_dim,
+                               int mlp_ratio,
+                               bool use_fp4,
+                               Tensor::ScalarType dtype,
+                               Device device);
     Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);

 public:
...

@@ -107,21 +118,32 @@ private:
 class JointTransformerBlock : public Module {
 public:
     static constexpr bool USE_4BIT = true;
-    using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
+    using GEMM                     = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

-    JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device);
-    std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);
+    JointTransformerBlock(int dim,
+                          int num_attention_heads,
+                          int attention_head_dim,
+                          bool context_pre_only,
+                          bool use_fp4,
+                          Tensor::ScalarType dtype,
+                          Device device);
+    std::tuple<Tensor, Tensor> forward(Tensor hidden_states,
+                                       Tensor encoder_hidden_states,
+                                       Tensor temb,
+                                       Tensor rotary_emb,
+                                       Tensor rotary_emb_context,
+                                       float sparsityRatio);

 public:
     const int dim;
     const int dim_head;
     const int num_heads;
     const bool context_pre_only;

-    AdaLayerNormZero norm1;
+    AttentionImpl attnImpl = AttentionImpl::FlashAttention2;

 private:
+    AdaLayerNormZero norm1;
     AdaLayerNormZero norm1_context;
     GEMM qkv_proj;
     GEMM qkv_proj_context;
...

@@ -139,33 +161,35 @@ private:
 class FluxModel : public Module {
 public:
     FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
-    Tensor forward(
-        Tensor hidden_states,
-        Tensor encoder_hidden_states,
-        Tensor temb,
-        Tensor rotary_emb_img,
-        Tensor rotary_emb_context,
-        Tensor rotary_emb_single,
-        Tensor controlnet_block_samples,
-        Tensor controlnet_single_block_samples,
-        bool skip_first_layer = false);
-    std::tuple<Tensor, Tensor> forward_layer(
-        size_t layer,
-        Tensor hidden_states,
-        Tensor encoder_hidden_states,
-        Tensor temb,
-        Tensor rotary_emb_img,
-        Tensor rotary_emb_context,
-        Tensor controlnet_block_samples,
-        Tensor controlnet_single_block_samples);
+    Tensor forward(Tensor hidden_states,
+                   Tensor encoder_hidden_states,
+                   Tensor temb,
+                   Tensor rotary_emb_img,
+                   Tensor rotary_emb_context,
+                   Tensor rotary_emb_single,
+                   Tensor controlnet_block_samples,
+                   Tensor controlnet_single_block_samples,
+                   bool skip_first_layer = false);
+    std::tuple<Tensor, Tensor> forward_layer(size_t layer,
+                                             Tensor hidden_states,
+                                             Tensor encoder_hidden_states,
+                                             Tensor temb,
+                                             Tensor rotary_emb_img,
+                                             Tensor rotary_emb_context,
+                                             Tensor controlnet_block_samples,
+                                             Tensor controlnet_single_block_samples);
     void setAttentionImpl(AttentionImpl impl);

+    void set_residual_callback(std::function<Tensor(const Tensor &)> cb);
+
 public:
     const Tensor::ScalarType dtype;

     std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
     std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
+    std::function<Tensor(const Tensor &)> residual_callback;

 private:
     bool offload;
-};
\ No newline at end of file
+};
src/Linear.cpp

@@ -9,16 +9,12 @@
 using namespace nunchaku;

-GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
-    in_features(in_features), out_features(out_features)
-{
+GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
+    : in_features(in_features), out_features(out_features) {
     this->weight = Tensor::allocate({out_features, in_features}, dtype, device);
-    this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
+    this->bias   = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
-    registerParams
-        (weight, "weight", ParamFlags::LazyLoad)
-        (bias, "bias")
-    ;
+    registerParams(weight, "weight", ParamFlags::LazyLoad)(bias, "bias");
 }

 Tensor GEMM_F16::forward(Tensor x) {
...

@@ -26,26 +22,20 @@ Tensor GEMM_F16::forward(Tensor x) {
     return out;
 }

-GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
-    in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f), device(device)
-{
+GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
+    : in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f),
+      device(device) {
     this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
     this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
     this->wzeros  = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
-    this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
+    this->bias    = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};

     // !!! lora layout is different from w4a4 !!!
     this->lora_down = Tensor::allocate({lora_rank, in_features}, dtype, device, true);
-    this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
+    this->lora_up   = Tensor::allocate({out_features, lora_rank}, dtype, device, true);

-    registerParams
-        (qweight, "qweight", ParamFlags::LazyLoad)
-        (wscales, "wscales")
-        (wzeros, "wzeros")
-        (bias, "bias")
-        (lora_down, "lora_down", ParamFlags::Optional)
-        (lora_up, "lora_up", ParamFlags::Optional)
-    ;
+    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(wzeros, "wzeros")(bias, "bias")(
+        lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional);
 }

 void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
...
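The allocations in the GEMV_AWQ constructor encode the AWQ storage layout: 4-bit weights packed eight per 32-bit word in an interleaved {out_features / 4, ceilDiv(in_features, 8) * 4} arrangement, plus one scale row and one zero-point row per group of 64 input channels. The shape arithmetic, checked standalone with a hypothetical layer size:

#include <cstdio>

static int ceilDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Hypothetical 3072x3072 layer; group_size is fixed to 64 in the constructor.
    const int in_features = 3072, out_features = 3072, group_size = 64;
    // qweight: {out_features / 4, ceilDiv(in_features, 8) * 4} int32 words
    printf("qweight: {%d, %d} int32\n", out_features / 4, ceilDiv(in_features, 8) * 4);
    // wscales / wzeros: one row per 64-channel group
    printf("wscales/wzeros: {%d, %d}\n", ceilDiv(in_features, group_size), out_features);
}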
@@ -56,7 +46,7 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
Module
::
loadParam
(
key
,
dst
,
src
);
if
(
key
==
"lora_down"
)
{
const
int
new_rank
=
dst
.
shape
[
0
];
this
->
lora_rank
=
new_rank
;
this
->
lora_rank
=
new_rank
;
}
}
else
{
Module
::
loadParam
(
key
,
dst
,
src
);
...
...
@@ -70,7 +60,7 @@ Tensor GEMV_AWQ::forward(Tensor x) {
debug
(
"x"
,
x
);
const
int
M
=
(
int
)
x
.
numel
()
/
x
.
shape
[
-
1
];
Tensor
out
=
gemv_awq
(
x
,
this
->
qweight
,
this
->
wscales
,
this
->
wzeros
,
M
,
out_features
,
in_features
,
group_size
);
Tensor
out
=
gemv_awq
(
x
,
this
->
qweight
,
this
->
wscales
,
this
->
wzeros
,
M
,
out_features
,
in_features
,
group_size
);
if
(
bias
.
valid
())
{
// TODO: batch
// assert(out.numel() == bias.numel());
...
...
@@ -91,19 +81,16 @@ Tensor GEMV_AWQ::forward(Tensor x) {
}
debug
(
"out"
,
out
);
return
out
;
}
#define NO_LORA_FUSION 0
GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), in_features_pad(ceilDiv(in_features, 128) * 128),
      out_features_pad(ceilDiv(out_features, 128) * 128), use_fp4(use_fp4), lora_rank(0), dtype(dtype), device(device) {
GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), in_features_pad(ceilDiv(in_features, 128) * 128),
      out_features_pad(ceilDiv(out_features, 128) * 128), use_fp4(use_fp4), lora_rank(0), dtype(dtype), device(device) {
    this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
    if (use_fp4) {
        this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
...
...
@@ -114,27 +101,20 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
    this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};
    this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
    this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
    this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
    // TODO: smooth factor in non-Lora fusion
    this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);
    // FIXME: reset wtscale and wcscales to default values when reloading the weights
    this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
    this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
    *this->wtscale.data_ptr<float>() = 1.0f;
    this->wcscales = Tensor::allocate({0}, dtype, device, true);
    registerParams
        (qweight, "qweight", ParamFlags::LazyLoad)
        (wscales, "wscales")
        (this->bias, "bias")
        (lora_down, "lora_down", ParamFlags::Optional)
        (lora_up, "lora_up", ParamFlags::Optional)
        (smooth, "smooth")
        (wtscale, "wtscale", ParamFlags::Optional)
        (wcscales, "wcscales", ParamFlags::Optional)
    ;
    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias")(lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional)(smooth, "smooth")(wtscale, "wtscale", ParamFlags::Optional)(wcscales, "wcscales", ParamFlags::Optional);
#if NO_LORA_FUSION
    checkCUBLAS(cublasCreate(&handle));
...
...
@@ -181,11 +161,21 @@ Tensor GEMM_W4A4::forward_silu(Tensor x) {
    return std::get<Tensor>(this->forward(x, FuseOptions::SILU, nullptr));
}
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
    return forward_quant(quantize(x, false), fuse, nextGEMM);
}
void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
    QuantizedActivation qact = quantize(x, false);
#if !NO_LORA_FUSION
...
...
@@ -198,42 +188,87 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
    debug("gemm.nolora.out", out);
#endif
    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false, use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{}, out_q, out_k, out_v, numTokens);
    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false, use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{}, out_q, out_k, out_v, numTokens);
    debug("gemm.out", out);
#else
    const int M = (int)qact.act.numel() / qact.act.shape[-1];
    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, {}, {}, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned, this->lora_scales);
    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, {}, {}, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned, this->lora_scales);
    nvtxRangePushA("LoraUp");
    static const half one = 1.0;
    static const half one = 1.0;
    static const half zero = 0.0;
    // lora_up: [M, R] * [OC, R] => [M, OC]
    // cublas view: [OC, R] * [M, R]^T
    checkCUBLAS(cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, this->out_features, M, this->lora_rank, &one, this->lora_up.data_ptr<half>(), this->lora_rank, qact.lora_act.data_ptr<half>(), this->lora_rank, &one, out.data_ptr<half>(), this->out_features));
    checkCUBLAS(cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, this->out_features, M, this->lora_rank, &one, this->lora_up.data_ptr<half>(), this->lora_rank, qact.lora_act.data_ptr<half>(), this->lora_rank, &one, out.data_ptr<half>(), this->out_features));
    nvtxRangePop();
#endif
}
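The `cublasHgemm` calls in the `NO_LORA_FUSION` path lean on the standard row-major/column-major duality: row-major `lora_up` of shape [OC, R] is, to column-major cuBLAS, an R-by-OC matrix, so `CUBLAS_OP_T` recovers the intended orientation and `beta = 1` accumulates onto the GEMM output. A reference sketch of the accumulation those calls perform (plain C++ with `float` in place of `half`; illustration only):

// Sketch of the LoRA "up" accumulation behind the cuBLAS call:
//   out[m][oc] += sum_r lora_act[m][r] * lora_up[oc][r]
// All buffers assumed row-major: lora_act is [M, R], lora_up is [OC, R],
// out is [M, OC].
void lora_up_reference(const float *lora_act, const float *lora_up,
                       float *out, int M, int OC, int R) {
    for (int m = 0; m < M; ++m) {
        for (int oc = 0; oc < OC; ++oc) {
            float acc = 0.0f;
            for (int r = 0; r < R; ++r)
                acc += lora_act[m * R + r] * lora_up[oc * R + r];
            out[m * OC + oc] += acc; // beta = 1: add onto the W4A4 GEMM result
        }
    }
}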
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
    Tensor out;
    QuantizedActivation qout;
...
...
@@ -246,8 +281,8 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
        // auto shape = TensorShape(qact.act.shape.dataExtent);
        // shape[-1] = out_features;
        auto shape = TensorShape(qact.actShape.dataExtent);
        shape[-1] = out_features;
        out = Tensor::allocate(shape, dtype, device);
        shape[-1] = out_features;
        out = Tensor::allocate(shape, dtype, device);
    } else {
        qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, device);
        if (use_fp4) {
...
...
@@ -255,11 +290,11 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
        } else {
            qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, device);
        }
        qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
        qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
        qout.is_unsigned = !use_fp4;
        qout.actShape = qact.actShape;
        qout.actShape = qact.actShape;
        next_lora = nextGEMM->lora_down;
        next_lora = nextGEMM->lora_down;
        next_smooth = nextGEMM->smooth;
    }
...
...
@@ -280,11 +315,35 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
    }
#endif
    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU, use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{}, {}, {}, {}, 0);
    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU, use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{}, {}, {}, {}, 0);
    if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
        debug("gemm.out", out);
...
...
@@ -294,36 +353,55 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
        debug("gemm.lora_act_out", qout.lora_act);
    }
#else
    if (!out.valid()) {
        auto shape = TensorShape(qact.act.shape.dataExtent);
        shape[-1] = out_features;
        out = Tensor::allocate(shape, Tensor::FP16, qweight.device());
        shape[-1] = out_features;
        out = Tensor::allocate(shape, Tensor::FP16, qweight.device());
    }
    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, {}, {}, {}, {}, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned, this->lora_scales);
    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, {}, {}, {}, {}, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned, this->lora_scales);
    nvtxRangePushA("LoraUp");
    static const half one = 1.0;
    static const half one = 1.0;
    static const half zero = 0.0;
    // lora_up: [M, R] * [OC, R]^T => [M, OC]
    // cublas view: [R, OC]^T * [R, M] => [OC, M]
    // lora_up layout wrong?
    checkCUBLAS(cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, this->out_features, M, this->lora_rank, &one, this->lora_up.data_ptr<half>(), this->lora_rank, qact.lora_act.data_ptr<half>(), this->lora_rank, &one, out.data_ptr<half>(), this->out_features));
    checkCUBLAS(cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, this->out_features, M, this->lora_rank, &one, this->lora_up.data_ptr<half>(), this->lora_rank, qact.lora_act.data_ptr<half>(), this->lora_rank, &one, out.data_ptr<half>(), this->out_features));
    nvtxRangePop();
...
...
@@ -332,18 +410,20 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
    // IC is for next lora (OC of this layer)
    // lora_down: [M, IC] * [IC, R] => [M, R]
    // cublas view: [R, IC] * [IC, M] => [R, M]
    checkCUBLAS(cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, this->lora_rank, M, this->out_features, &one, next_lora.data_ptr<half>(), this->lora_rank, out.data_ptr<half>(), this->out_features, &zero, qout.lora_act.data_ptr<half>(), this->lora_rank));
    checkCUBLAS(cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, this->lora_rank, M, this->out_features, &one, next_lora.data_ptr<half>(), this->lora_rank, out.data_ptr<half>(), this->out_features, &zero, qout.lora_act.data_ptr<half>(), this->lora_rank));
    out = {};
...
...
@@ -363,7 +443,7 @@ Tensor GEMM_W4A4::forward_quant(QuantizedActivation qact) {
GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
    const int actualM = x.numel() / x.shape[-1];
    const int M = ceilDiv(actualM, 256) * 256;
    const int M = ceilDiv(actualM, 256) * 256;
    // auto shape = TensorShape(x.shape.dataExtent);
    // shape[-1] = in_features / 2;
...
...
@@ -375,39 +455,42 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
    } else {
        qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, device);
    }
    qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
    qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
    qact.is_unsigned = false;
    qact.actShape = x.shape.dataExtent;
    qact.actShape = x.shape.dataExtent;
#if !NO_LORA_FUSION
    debug("quantize.x", x);
    debug("quantize.smooth", this->smooth);
    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);
    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);
    debug("quantize.qact", qact.act);
    debug("quantize.ascales", qact.ascales);
    debug("quantize.lora_act", qact.lora_act);
#else
    static const half one = 1.0;
#else
    static const half one = 1.0;
    static const half zero = 0.0;
    nvtxRangePushA("LoraDown");
    // lora_down: [M, IC] * [IC, R] => [M, R]
    // cublas view: [R, IC] * [IC, M]
    checkCUBLAS(cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, this->lora_rank, M, this->in_features, &one, lora_down.data_ptr<half>(), this->lora_rank, x.data_ptr<half>(), this->in_features, &zero, qact.lora_act.data_ptr<half>(), this->lora_rank));
    checkCUBLAS(cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, this->lora_rank, M, this->in_features, &one, lora_down.data_ptr<half>(), this->lora_rank, x.data_ptr<half>(), this->in_features, &zero, qact.lora_act.data_ptr<half>(), this->lora_rank));
    nvtxRangePop();
...
...
@@ -418,18 +501,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
    return qact;
}
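Note how `quantize()` produces two things at once: the int4 activation plus `lora_act = x @ lora_down` in FP32, which `forward_quant` later folds back in through `lora_up`. A reference sketch of the fused down-projection, matching the `[M, IC] * [IC, R] => [M, R]` comment above (plain C++, row-major, illustration only):

// Sketch: the low-rank activation emitted alongside the quantized tensor.
//   lora_act[m][r] = sum_ic x[m][ic] * lora_down[ic][r]
// x is [M, IC], lora_down is [IC, R], lora_act is [M, R], all row-major.
void lora_down_reference(const float *x, const float *lora_down,
                         float *lora_act, int M, int IC, int R) {
    for (int m = 0; m < M; ++m) {
        for (int r = 0; r < R; ++r) {
            float acc = 0.0f;
            for (int ic = 0; ic < IC; ++ic)
                acc += x[m * IC + ic] * lora_down[ic * R + r];
            lora_act[m * R + r] = acc; // beta = 0: overwrite
        }
    }
}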
GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), dtype(dtype) {
GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), dtype(dtype) {
    this->qweight = Tensor::allocate({out_features, in_features}, Tensor::INT8, device);
    this->wscales = Tensor::allocate({out_features}, dtype, device);
    this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
    this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
    registerParams
        (qweight, "qweight", ParamFlags::LazyLoad)
        (wscales, "wscales")
        (this->bias, "bias")
    ;
    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias");
}
GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
...
...
@@ -438,7 +516,7 @@ GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
    if (fuse_glu) {
        qshape[-1] /= 2;
    }
    qact.act = Tensor::allocate(qshape, Tensor::INT8, x.device());
    qact.act = Tensor::allocate(qshape, Tensor::INT8, x.device());
    qact.ascales = Tensor::allocate({(int)x.numel() / x.shape[-1]}, this->dtype, x.device());
    debug("quantize.x", x);
...
...
@@ -453,7 +531,7 @@ GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
    auto shape = TensorShape(qact.act.shape.dataExtent);
    shape[-1] = out_features;
    shape[-1] = out_features;
    Tensor out = Tensor::allocate(shape, this->dtype, qact.act.device());
    kernels::gemm_w8a8(qact.act, this->qweight, out, qact.ascales, this->wscales, this->bias);
...
...
@@ -461,18 +539,13 @@ Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
    return out;
}
DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) : in_features(in_features) {
DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) : in_features(in_features) {
    this->weight = Tensor::allocate({in_features, 3, 3, 1}, dtype, device);
    this->bias = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{};
    this->bias = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{};
    registerParams
        (this->weight, "weight")
        (this->bias, "bias")
    ;
    registerParams(this->weight, "weight")(this->bias, "bias");
}
Tensor DWCONV::forward(Tensor x) {
    return dwconv_f16(x, this->weight, {}, this->bias);
}
\ No newline at end of file
}
src/Linear.h
View file @ 37a27712
...
...
@@ -37,6 +37,7 @@ public:
    float lora_scale;
    const Device device;
public:
    Tensor qweight;
    Tensor wscales;
...
...
@@ -69,12 +70,18 @@ public:
    Tensor forward(Tensor x);
    Tensor forward_silu(Tensor x);
    std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
    void forward(Tensor x, Tensor out, Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {}, Tensor out_q = {}, Tensor out_k = {}, Tensor out_v = {}, int numTokens = 0);
    std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
    void forward(Tensor x, Tensor out, Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {}, Tensor out_q = {}, Tensor out_k = {}, Tensor out_v = {}, int numTokens = 0);
    std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
    Tensor forward_quant(QuantizedActivation qact);
public:
...
...
@@ -86,7 +93,7 @@ public:
    const int in_features_pad;
    const int out_features_pad;
    const bool use_fp4;
    int lora_rank;
    std::vector<float> lora_scales; // every 16 ranks share a scale
...
...
@@ -118,13 +125,16 @@ public:
        Tensor act;
        Tensor ascales;
    };
public:
    GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
public:
    QuantizedActivation quantize(Tensor x, bool fuse_glu);
    QuantizedActivation quantize(Tensor x, bool fuse_glu);
    Tensor forward_quant(QuantizedActivation qact);
    Tensor forward(Tensor x) { return forward_quant(quantize(x, false)); }
    Tensor forward(Tensor x) { return forward_quant(quantize(x, false)); }
public:
    const int in_features;
...
...
@@ -149,4 +159,4 @@ public:
public:
    Tensor weight;
    Tensor bias;
};
\ No newline at end of file
};
src/Module.cpp
View file @ 37a27712
...
...
@@ -10,8 +10,8 @@ void Module::copyWithCast(Tensor dst, Tensor src) {
        nunchaku::kernels::cast(src, dst);
    } else {
        Tensor tmp;
        tmp.buffer = dst.buffer;
        tmp.shape = dst.shape;
        tmp.buffer = dst.buffer;
        tmp.shape = dst.shape;
        tmp.scalarType = src.scalarType;
        tmp.copy_(src);
        nunchaku::kernels::cast(tmp, dst);
...
...
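The fallback branch of `copyWithCast` stages the raw source bytes inside `dst`'s own buffer, viewed with `src`'s dtype, then casts in place. That aliasing trick appears safe here because the autocast whitelist only pairs FP16 with BF16, which share a 2-byte element size; a rough sketch of the idea (hypothetical `convert` callback, not the repo's kernel):

#include <cstdint>
#include <cstring>

// Sketch: stage the foreign-dtype bits in the destination storage, then
// convert element-wise in place. Safe only when both dtypes have equal size.
void copy_with_cast_sketch(uint16_t *dst_storage, const uint16_t *src_bits,
                           size_t n, uint16_t (*convert)(uint16_t)) {
    std::memcpy(dst_storage, src_bits, n * sizeof(uint16_t)); // raw copy
    for (size_t i = 0; i < n; ++i)
        dst_storage[i] = convert(dst_storage[i]); // in-place dtype cast
}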
src/Module.h
View file @ 37a27712
...
...
@@ -7,7 +7,7 @@
class Module {
protected:
    enum class ParamFlags : int {
        None = 0,
        None = 0,
        Optional = 1,
        LazyLoad = 2,
    };
...
...
@@ -19,7 +19,7 @@ protected:
        Tensor src;
    };
    struct Param {
        Tensor *tensor = nullptr;
        Tensor *tensor = nullptr;
        ParamFlags flags = ParamFlags::None;
        TensorLazyLoadInfo lazyInfo;
...
...
@@ -50,7 +50,7 @@ public:
    std::string getPrefix() const {
        std::string fullName = getFullName();
        std::string prefix = fullName.empty() ? "" : fullName + ".";
        std::string prefix = fullName.empty() ? "" : fullName + ".";
        return prefix;
    }
...
...
@@ -80,7 +80,7 @@ public:
                continue;
            }
            // keep loading params if param is not released
            }
        }
        this->loadParam(key, *param.tensor, src);
        // tensor->copy_(src);
    }
...
...
@@ -99,8 +99,8 @@ public:
            }
            TensorLazyLoadInfo &lazy = param.lazyInfo;
            Tensor &dst = *param.tensor;
            Tensor src = lazy.src;
            Tensor &dst = *param.tensor;
            Tensor src = lazy.src;
            if (dst.valid()) {
                continue;
...
...
@@ -108,7 +108,8 @@ public:
            dst = Tensor::allocate(lazy.shape, lazy.type, lazy.device);
            if (!src.valid() && !checkFlag(param.flags, ParamFlags::Optional)) {
                throw std::runtime_error(spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
                throw std::runtime_error(spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
            }
            m->loadParam(key, dst, src);
        }
...
...
@@ -127,14 +128,10 @@ public:
        });
    }
    void setLazyLoad(bool val) {
        traverse([val](Module *m) { m->enabledLazyLoad = val; });
        traverse([val](Module *m) { m->enabledLazyLoad = val; });
    }
    void setAutoCastFP16(bool val) {
        traverse([val](Module *m) { m->enabledAutoCastFP16 = val; });
        traverse([val](Module *m) { m->enabledAutoCastFP16 = val; });
    }
protected:
...
...
@@ -143,7 +140,8 @@ protected:
            Tensor::FP16,
            Tensor::BF16,
        };
        if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) && whitelist.contains(src.scalar_type())) {
        if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) && whitelist.contains(src.scalar_type())) {
            copyWithCast(dst, src);
        } else {
            dst.copy_(src);
...
...
@@ -159,7 +157,7 @@ protected:
    };
    ChildrenRegisterHelper registerChildren(Module &module, std::string name) {
        module.parent = this;
        module.name = name;
        module.name = name;
        children.push_back(&module);
        return ChildrenRegisterHelper(*this);
    }
...
...
@@ -174,13 +172,13 @@ protected:
    ParamsRegisterHelper registerParams(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) {
        if (param.valid()) {
            params[name].tensor = &param;
            params[name].flags = flags;
            params[name].flags = flags;
            if (checkFlag(flags, ParamFlags::LazyLoad) && param.valid()) {
                TensorLazyLoadInfo &lazy = params[name].lazyInfo;
                lazy.shape = param.shape;
                lazy.type = param.dtype();
                lazy.device = param.device();
                lazy.shape = param.shape;
                lazy.type = param.dtype();
                lazy.device = param.device();
            }
        }
        return ParamsRegisterHelper(*this);
...
...
@@ -204,12 +202,12 @@ private:
    void copyWithCast(Tensor dst, Tensor src);
public:
    Module *parent = nullptr;
    Module *parent = nullptr;
    std::string name = "";
    std::vector<Module *> children;
    std::map<std::string, Param> params;
    bool enabledLazyLoad = false;
    bool enabledLazyLoad = false;
    bool enabledAutoCastFP16 = true;
};
...
...
@@ -226,12 +224,11 @@ struct LayerOffloadHelper {
    std::unique_ptr<CUDAEventWrapper> eventComputeDone;
    std::unique_ptr<CUDAEventWrapper> eventLoadDone;
    LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
        : offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload) {
    LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
        : offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload) {
        if (offload) {
            streamCompute = std::make_unique<CUDAStreamWrapper>();
            streamLoad = std::make_unique<CUDAStreamWrapper>();
            streamLoad = std::make_unique<CUDAStreamWrapper>();
            needWorkaround = checkWorkaround();
            if (needWorkaround) {
...
...
@@ -280,7 +277,7 @@ private:
        }
        eventComputeDone = std::move(nextComputeDone);
        eventLoadDone = std::move(nextLoadDone);
        eventLoadDone = std::move(nextLoadDone);
        workaroundSynchronize();
    }
...
...
@@ -304,12 +301,12 @@ private:
            return false;
        }
    }
#ifdef _WIN32
#ifdef _WIN32
        return true;
#else
#else
        return false;
#endif
#endif
    }
    void workaroundFlush() {
        if (!needWorkaround) {
...
...
@@ -323,4 +320,4 @@ private:
        }
        checkCUDA(cudaEventSynchronize(eventComputeDone->event));
    }
};
\ No newline at end of file
};
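The `registerParams(...)(...)(...)` and `registerChildren(...)(...)` chains reformatted throughout this commit work because both helpers return a proxy whose `operator()` forwards to the module and returns itself. A stripped-down sketch of the idiom (toy `Registry` type, not the actual `Module`):

#include <map>
#include <string>

// Sketch of the chained-registration idiom: operator() returns the helper,
// so an arbitrary number of registrations fit into one statement.
struct Registry {
    std::map<std::string, int *> params;

    struct Helper {
        Registry &owner;
        Helper operator()(int &param, const std::string &name) {
            owner.params[name] = &param;
            return *this;
        }
    };

    Helper registerParams(int &param, const std::string &name) {
        params[name] = &param;
        return Helper{*this};
    }
};

// Usage: reg.registerParams(a, "a")(b, "b")(c, "c");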
src/SanaModel.cpp
View file @ 37a27712
...
...
@@ -10,18 +10,11 @@
using spdlog::fmt_lib::format;
using namespace nunchaku;
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : dim(dim), dim_pad(ceilDiv(dim, 128) * 128),
      qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
      out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
      pag_to_v(std::nullopt) {
    registerChildren
        (qkv_proj, "qkv_proj")
        (out_proj, "out_proj")
    ;
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : dim(dim), dim_pad(ceilDiv(dim, 128) * 128),
      qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
      out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
      pag_to_v(std::nullopt) {
    registerChildren(qkv_proj, "qkv_proj")(out_proj, "out_proj");
    if (pag) {
        pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
...
...
@@ -33,8 +26,8 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
    constexpr int HEAD_DIM = 32;
    assert(x.ndims() == 3);
    const int batch_size = x.shape[0];
    const int num_tokens = x.shape[1];
    const int batch_size = x.shape[0];
    const int num_tokens = x.shape[1];
    const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
    assert(x.shape[2] == dim);
...
...
@@ -54,24 +47,38 @@
    auto qact = qkv_proj.quantize(x, false);
    Tensor q = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
    Tensor q = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
    Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());
    kernels::gemm_w4a4(qact.act, qkv_proj.qweight, {}, {}, qact.ascales, qkv_proj.wscales, {}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {}, vk, q, qact.is_unsigned, qkv_proj.lora_scales, false, qkv_proj.use_fp4, *qkv_proj.wtscale.data_ptr<float>(), qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{}, {}, {}, {}, 0);
    kernels::gemm_w4a4(qact.act, qkv_proj.qweight, {}, {}, qact.ascales, qkv_proj.wscales, {}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {}, vk, q, qact.is_unsigned, qkv_proj.lora_scales, false, qkv_proj.use_fp4, *qkv_proj.wtscale.data_ptr<float>(), qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{}, {}, {}, {}, 0);
    debug("vk", vk);
    debug("q", q);
...
...
@@ -88,7 +95,6 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
        q = q_unpad;
    }
    // kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales);
    // return out_proj.forward(q);
...
...
@@ -109,14 +115,14 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
    if (cfg) {
        assert(batch_size % 3 == 0);
        x_org = x.slice(0, 0, batch_size * 2 / 3);
        x_ptb = x.slice(0, batch_size * 2 / 3, batch_size);
        x_org = x.slice(0, 0, batch_size * 2 / 3);
        x_ptb = x.slice(0, batch_size * 2 / 3, batch_size);
        out_org = out.slice(0, 0, batch_size * 2 / 3);
        out_ptb = out.slice(0, batch_size * 2 / 3, batch_size);
    } else {
        assert(batch_size % 2 == 0);
        x_org = x.slice(0, 0, batch_size / 2);
        x_ptb = x.slice(0, batch_size / 2, batch_size);
        x_org = x.slice(0, 0, batch_size / 2);
        x_ptb = x.slice(0, batch_size / 2, batch_size);
        out_org = out.slice(0, 0, batch_size / 2);
        out_ptb = out.slice(0, batch_size / 2, batch_size);
    }
...
...
@@ -129,17 +135,13 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
    return out;
}
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : num_heads(num_heads), head_dim(head_dim),
      q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
      kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
      out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device) {
    registerChildren
        (q_linear, "q_linear")
        (kv_linear, "kv_linear")
        (out_proj, "out_proj")
    ;
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : num_heads(num_heads), head_dim(head_dim),
      q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
      kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
      out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device) {
    registerChildren(q_linear, "q_linear")(kv_linear, "kv_linear")(out_proj, "out_proj");
}
Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt) {
...
...
@@ -155,22 +157,28 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
    assert(cu_seqlens_img.shape[0] == batch_size + 1);
    assert(cu_seqlens_txt.shape[0] == batch_size + 1);
    Tensor q = q_linear.forward(x).view({batch_size * num_tokens_img, num_heads, head_dim});
    Tensor q = q_linear.forward(x).view({batch_size * num_tokens_img, num_heads, head_dim});
    Tensor kv = kv_linear.forward(cond).view({num_tokens_txt, num_heads * 2, head_dim});
    Tensor k = kv.slice(1, 0, num_heads);
    Tensor v = kv.slice(1, num_heads, num_heads * 2);
    Tensor attn_output = mha_varlen_fwd(q, k, v, cu_seqlens_img, cu_seqlens_txt, num_tokens_img, num_tokens_txt, 0.0f, pow(q.shape[-1], (-0.5)), false, false, -1, -1, false).front().view({batch_size, num_tokens_img, num_heads * head_dim});
    Tensor attn_output = mha_varlen_fwd(q, k, v, cu_seqlens_img, cu_seqlens_txt, num_tokens_img, num_tokens_txt, 0.0f, pow(q.shape[-1], (-0.5)), false, false, -1, -1, false).front().view({batch_size, num_tokens_img, num_heads * head_dim});
    // Tensor attn_output = mha_fwd(q, k, v,
    //     0.0f,
...
...
@@ -181,17 +189,13 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
    return out_proj.forward(attn_output);
}
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), hidden_features(hidden_features),
      inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
      depth_conv(hidden_features * 2, true, dtype, device),
      point_conv(hidden_features, in_features, false, use_fp4, dtype, device) {
    registerChildren
        (inverted_conv, "inverted_conv")
        (depth_conv, "depth_conv")
        (point_conv, "point_conv")
    ;
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), hidden_features(hidden_features),
      inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
      depth_conv(hidden_features * 2, true, dtype, device),
      point_conv(hidden_features, in_features, false, use_fp4, dtype, device) {
    registerChildren(inverted_conv, "inverted_conv")(depth_conv, "depth_conv")(point_conv, "point_conv");
}
Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
...
...
@@ -203,33 +207,39 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
    debug("inverted_conv_output", x);
    x = depth_conv.forward(x);
    debug("depth_conv_output", x);
    x = x.view({x.shape[0], H * W, x.shape[-1]});
    x = x.view({x.shape[0], H * W, x.shape[-1]});
    auto qact = point_conv.quantize(x, true);
    return point_conv.forward_quant(qact);
}
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
      attn(hidden_size, false, pag, use_fp4, dtype, device),
      cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
      ff(hidden_size, intermediate_size, use_fp4, dtype, device),
      norm1(hidden_size, 1e-6, false, dtype, device),
      norm2(hidden_size, 1e-6, false, dtype, device) {
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
      attn(hidden_size, false, pag, use_fp4, dtype, device),
      cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
      ff(hidden_size, intermediate_size, use_fp4, dtype, device),
      norm1(hidden_size, 1e-6, false, dtype, device),
      norm2(hidden_size, 1e-6, false, dtype, device) {
    this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device);
    registerChildren
        (attn, "attn")
        (cross_attn, "cross_attn")
        (ff, "ff")
    ;
    registerChildren(attn, "attn")(cross_attn, "cross_attn")(ff, "ff");
    registerParams
        (this->scale_shift_table, "scale_shift_table")
    ;
    registerParams(this->scale_shift_table, "scale_shift_table");
}
Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
    nvtxRangePushA("SanaLinearTransformerBlock");
...
...
@@ -257,7 +267,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
    {
        nvtxRangePushA("LinearAttention");
        Tensor residual = hidden_states;
        Tensor residual = hidden_states;
        Tensor norm_hidden_states = norm1.forward(hidden_states);
        kernels::mul_add_batch(norm_hidden_states, scale_msa, true, 1, shift_msa, true);
        debug("norm_hidden_states_la", norm_hidden_states);
...
...
@@ -311,9 +321,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
    return hidden_states;
}
SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) : config(config) {
SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) : config(config) {
    const int inner_dim = config.num_attention_heads * config.attention_head_dim;
    for (int i = 0; i < config.num_layers; i++) {
        transformer_blocks.push_back(std::make_unique<SanaLinearTransformerBlock>(
...
...
@@ -322,20 +330,34 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
            config.num_cross_attention_heads,
            std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
            config.use_fp4,
            dtype, device));
            dtype, device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
    }
}
Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
    for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
        auto &&block = transformer_blocks[i];
        hidden_states = block->forward(hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W, pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(), cfg);
        auto &&block = transformer_blocks[i];
        hidden_states = block->forward(hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W, pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(), cfg);
    }
    return hidden_states;
}
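`forward_pag` relies on a fixed batch layout: with CFG the batch is assumed packed as [cond | uncond | perturbed], so the first two thirds take the ordinary path; without CFG it is [cond | perturbed], split in half. A sketch of that index arithmetic (hypothetical helper returning [begin, end) row ranges):

#include <cassert>
#include <utility>

// Sketch: row ranges for the ordinary and the perturbed sub-batches.
std::pair<std::pair<int, int>, std::pair<int, int>> pag_split(int batch_size, bool cfg) {
    int cut;
    if (cfg) {
        assert(batch_size % 3 == 0); // [cond | uncond | perturbed]
        cut = batch_size * 2 / 3;
    } else {
        assert(batch_size % 2 == 0); // [cond | perturbed]
        cut = batch_size / 2;
    }
    return {{0, cut}, {cut, batch_size}};
}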
src/SanaModel.h
View file @ 37a27712
...
...
@@ -35,7 +35,7 @@ public:
private:
    GEMM_W4A4 q_linear;
    GEMM_F16 kv_linear;
    GEMM_F16 kv_linear;
    GEMM_W4A4 out_proj;
};
...
...
@@ -57,9 +57,23 @@ private:
class SanaLinearTransformerBlock : public Module {
public:
    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
public:
    const int hidden_size;
...
...
@@ -89,11 +103,20 @@ struct SanaConfig {
class SanaModel : public Module {
public:
    SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer);
    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer);
public:
    const SanaConfig config;
public:
    std::vector<std::unique_ptr<SanaLinearTransformerBlock>> transformer_blocks;
};
\ No newline at end of file
};
src/Serialization.cpp
View file @ 37a27712
...
...
@@ -3,14 +3,13 @@
#include <nlohmann/json.hpp>
#include <mio/mmap.hpp>
using json = nlohmann::json;
using spdlog::fmt_lib::format;
class SafeTensors::MMapImpl {
public:
    virtual ~MMapImpl() {}
    virtual size_t size() = 0;
    virtual size_t size() = 0;
    virtual const char *data() = 0;
};
...
...
@@ -55,7 +54,7 @@ private:
    std::unique_ptr<Buffer> buffer;
};
#ifdef __linux__
#ifdef __linux__
#include <unistd.h>
#include <fcntl.h>
...
...
@@ -97,7 +96,7 @@ private:
    void *ptr;
};
#else
#else
class SafeTensors::MMapImplPrivate : public SafeTensors::MMapImpl {
public:
...
...
@@ -117,33 +116,34 @@ public:
SafeTensors::SafeTensors(const std::string &filename) {
    this->hostRegistered = false;
    this->memoryPinned = false;
    this->memoryPinned = false;
    auto methodPrivate = [&]() {
        this->mapped = std::make_unique<MMapImplPrivate>(filename);
        checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
        checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
        this->hostRegistered = true;
        this->memoryPinned = true;
        this->memoryPinned = true;
    };
    auto methodMio = [&]() {
        this->mapped = std::make_unique<MMapImplMio>(filename);
        checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable | cudaHostRegisterReadOnly));
        checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable | cudaHostRegisterReadOnly));
        this->hostRegistered = true;
        this->memoryPinned = true;
        this->memoryPinned = true;
    };
    auto methodRead = [&]() {
        this->mapped = std::make_unique<MMapImplRead>(filename, true);
        this->mapped = std::make_unique<MMapImplRead>(filename, true);
        this->memoryPinned = true;
    };
    auto methodReadNopin = [&]() {
        this->mapped = std::make_unique<MMapImplRead>(filename, false);
    };
    auto methodReadNopin = [&]() { this->mapped = std::make_unique<MMapImplRead>(filename, false); };
    const std::map<std::string, std::function<void()>> methods = {
        {"PRIVATE", methodPrivate}, {"MIO", methodMio}, {"READ", methodRead}, {"READNOPIN", methodReadNopin},
        {"PRIVATE", methodPrivate}, {"MIO", methodMio}, {"READ", methodRead}, {"READNOPIN", methodReadNopin},
    };
    auto tryMethod = [&](std::string name) {
...
...
@@ -168,7 +168,6 @@ SafeTensors::SafeTensors(const std::string &filename) {
#else
        tryMethod("MIO") || tryMethod("READ") || tryMethod("READNOPIN");
#endif
    }
    if (!this->mapped) {
...
...
@@ -192,19 +191,20 @@ SafeTensors::~SafeTensors() {
void SafeTensors::parseHeader() {
    static const std::unordered_map<std::string, Tensor::ScalarType> mapDType = {
        {"BF16", Tensor::BF16}, {"F16", Tensor::FP16}, {"F32", Tensor::FP32}, {"I8", Tensor::INT8}, {"I32", Tensor::INT32}, {"I64", Tensor::INT64}, {"F8_E4M3", Tensor::FP8_E4M3}, {"F8_E5M2", Tensor::FP8_E5M2},
        {"BF16", Tensor::BF16}, {"F16", Tensor::FP16}, {"F32", Tensor::FP32}, {"I8", Tensor::INT8}, {"I32", Tensor::INT32}, {"I64", Tensor::INT64}, {"F8_E4M3", Tensor::FP8_E4M3}, {"F8_E5M2", Tensor::FP8_E5M2},
    };
    auto check = [](bool cond, std::source_location location = std::source_location::current()) {
        if (!cond) {
            throw std::runtime_error(format("Safetensors check failed at {}:{}", location.file_name(), location.line()));
            throw std::runtime_error(format("Safetensors check failed at {}:{}", location.file_name(), location.line()));
        }
    };
...
...
@@ -222,8 +222,9 @@ void SafeTensors::parseHeader() {
            continue;
        }
        auto dtype = mapDType.at(info["dtype"].get<std::string>());;
        auto shape = info["shape"].get<std::vector<int>>();
        auto dtype = mapDType.at(info["dtype"].get<std::string>());
        ;
        auto shape = info["shape"].get<std::vector<int>>();
        auto data_offsets = info["data_offsets"].get<std::vector<uint64_t>>();
        check(data_offsets.size() == 2);
...
...
@@ -235,8 +236,8 @@ void SafeTensors::parseHeader() {
        }
        TensorInfo tinfo;
        tinfo.type = dtype;
        tinfo.shape = TensorShape(shape);
        tinfo.type = dtype;
        tinfo.shape = TensorShape(shape);
        tinfo.length = data_offsets[1] - data_offsets[0];
        tinfo.offset = 8 + sizeHeader + data_offsets[0];
...
...
@@ -258,15 +259,15 @@ Tensor SafeTensors::getTensor(const std::string &key) {
    std::shared_ptr<BufferMMap> buffer = info.buffer.lock();
    if (!buffer) {
        buffer = std::make_shared<BufferMMap>(const_cast<char *>(this->mapped->data() + info.offset), info.length, shared_from_this());
        buffer = std::make_shared<BufferMMap>(const_cast<char *>(this->mapped->data() + info.offset), info.length, shared_from_this());
        info.buffer = buffer;
    }
    Tensor result;
    result.shape = info.shape;
    result.shape = info.shape;
    result.scalarType = info.type;
    result.buffer = buffer;
    result.buffer = buffer;
    return result;
}
src/Serialization.h
View file @ 37a27712
...
...
@@ -6,15 +6,15 @@
class BufferMMap : public Buffer {
public:
    BufferMMap(void *ptr, size_t size, std::shared_ptr<void> parent) : parent(parent) {
        this->size = size;
        this->size = size;
        this->device.type = Device::CPU;
        this->ptr = ptr;
        this->ptr = ptr;
        // auto ret = cudaHostRegister(ptr, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
        // if (ret == cudaSuccess) {
        //     this->registered = true;
        // } else {
        //     log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size, cudaGetErrorString(cudaGetLastError())));
        //     this->registered = false;
        //     log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size,
        //     cudaGetErrorString(cudaGetLastError()))); this->registered = false;
        // }
    }
    virtual ~BufferMMap() {
...
...
@@ -22,6 +22,7 @@ public:
        //     checkCUDA(cudaHostUnregister(ptr));
        // }
    }
public:
    std::shared_ptr<void> parent;
    // bool registered;
...
...
@@ -32,7 +33,7 @@ public:
    SafeTensors(const std::string &filename);
    ~SafeTensors();
    virtual bool contains(const std::string &key) const override {
    virtual bool contains(const std::string &key) const override {
        return tensors.contains(key);
    }
    virtual Tensor getTensor(const std::string &key) override;
...
...
@@ -57,4 +58,4 @@ private:
    std::unique_ptr<MMapImpl> mapped;
    bool hostRegistered, memoryPinned;
};
\ No newline at end of file
};
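`SafeTensors` tries its mapping strategies in a fixed order (private mmap, mio, plain read with pinning, plain read without), keeping the first that does not throw. A minimal sketch of that try-in-order pattern (generic `std::function` list; illustration only):

#include <functional>
#include <string>
#include <utility>
#include <vector>

// Sketch: attempt each named loading strategy until one succeeds.
bool try_methods(const std::vector<std::pair<std::string, std::function<void()>>> &methods) {
    for (const auto &[name, method] : methods) {
        try {
            method(); // e.g. mmap the file, then cudaHostRegister the range
            return true;
        } catch (...) {
            // fall through to the next, less demanding strategy
        }
    }
    return false;
}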
src/Tensor.h
View file @ 37a27712
...
...
@@ -3,13 +3,10 @@
#include "common.h"
struct Device {
    enum Type { INVALID_DEVICE_TYPE = 0, CPU, CUDA };
    enum Type { INVALID_DEVICE_TYPE = 0, CPU, CUDA };
    Type type = INVALID_DEVICE_TYPE;
    int idx = 0;
    int idx = 0;
    static constexpr Device cpu(int idx = 0) {
        return Device{CPU, idx};
...
...
@@ -23,21 +20,29 @@ struct Device {
class Buffer : public std::enable_shared_from_this<Buffer> {
public:
    virtual ~Buffer() {}
    void *getPtr() { return ptr; }
    void *getPtr() { return ptr; }
    template <typename T>
    T *getPtr() { return reinterpret_cast<T *>(ptr); }
    T *getPtr() { return reinterpret_cast<T *>(ptr); }
    size_t getSize() { return size; }
    Device getDevice() { return device; }
    size_t getSize() { return size; }
    Device getDevice() { return device; }
    virtual bool isAsyncBuffer() {
    virtual bool isAsyncBuffer() {
        return false;
    }
protected:
    template <typename Derived>
    template <typename Derived>
    std::shared_ptr<Derived> shared_from_base() {
        return std::static_pointer_cast<Derived>(shared_from_this());
    }
...
...
@@ -55,9 +60,9 @@ protected:
class BufferMalloc : public Buffer {
public:
    BufferMalloc(size_t size) {
        this->size = size;
        this->size = size;
        this->device.type = Device::CPU;
        this->ptr = malloc(size);
        this->ptr = malloc(size);
    }
    virtual ~BufferMalloc() {
        free(this->ptr);
...
...
@@ -67,7 +72,7 @@ public:
class BufferHost : public Buffer {
public:
    BufferHost(size_t size) {
        this->size = size;
        this->size = size;
        this->device.type = Device::CPU;
        checkCUDA(cudaHostAlloc(&this->ptr, size, cudaHostAllocPortable));
    }
...
...
@@ -79,7 +84,7 @@ public:
class BufferCUDA : public Buffer {
public:
    BufferCUDA(size_t size) {
        this->size = size;
        this->size = size;
        this->device.type = Device::CUDA;
        // checkCUDA(cudaGetDevice(&this->device.idx));
        this->device.idx = CUDADeviceContext::getDevice();
...
...
@@ -96,7 +101,7 @@ public:
        }
        checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
    }
    virtual bool isAsyncBuffer() override {
    virtual bool isAsyncBuffer() override {
        return true;
    }
};
...
...
@@ -104,7 +109,7 @@ public:
class BufferCUDASync : public Buffer {
public:
    BufferCUDASync(size_t size) {
        this->size = size;
        this->size = size;
        this->device.type = Device::CUDA;
        checkCUDA(cudaGetDevice(&this->device.idx));
        checkCUDA(cudaMalloc(&this->ptr, size));
...
...
@@ -118,8 +123,8 @@ class BufferView : public Buffer {
public:
    BufferView(std::shared_ptr<Buffer> reference, size_t offset, size_t size) : reference(reference) {
        assert(offset + size <= reference->getSize());
        this->ptr = (void *)((std::uint8_t *)reference->getPtr() + offset);
        this->size = size;
        this->ptr = (void *)((std::uint8_t *)reference->getPtr() + offset);
        this->size = size;
        this->device = reference->getDevice();
    }
...
...
@@ -213,23 +218,31 @@ struct TensorShape {
    }
};
class Tensor {
public:
    enum ScalarType {
        INVALID_SCALAR_TYPE,
        INT8, INT16, INT32, INT64, FP16, FP32, BF16, FP8_E4M3, FP8_E5M2,
        INT8, INT16, INT32, INT64, FP16, FP32, BF16, FP8_E4M3, FP8_E5M2,
    };
    struct TensorOptions {
        Device device_;
        ScalarType dtype_;
        Device device() const { return device_; }
        ScalarType dtype() const { return dtype_; }
        Device device() const { return device_; }
        ScalarType dtype() const { return dtype_; }
        TensorOptions device(Device dev) const {
            TensorOptions result(*this);
...
...
@@ -244,56 +257,95 @@ public:
    };
    static const std::map<ScalarType, size_t> scalarSize;
public:
    TensorShape shape;
    ScalarType scalarType;
    std::shared_ptr<Buffer> buffer;
public:
    bool valid() const { return shape.dataExtent.size() > 0; }
    int size(int dim) const { return shape[dim]; }
    bool is_contiguous() const { return shape.is_contiguous(); }
    std::vector<int> sizes() const { return shape.dataExtent; }
    bool valid() const { return shape.dataExtent.size() > 0; }
    int size(int dim) const { return shape[dim]; }
    bool is_contiguous() const { return shape.is_contiguous(); }
    std::vector<int> sizes() const { return shape.dataExtent; }
    bool is_cuda() const { return device().type == Device::CUDA; }
    bool is_cuda() const { return device().type == Device::CUDA; }
    TensorOptions options() const { return TensorOptions{device(), dtype()}; }
    int get_device() const { return device().idx; }
    TensorOptions options() const { return TensorOptions{device(), dtype()}; }
    int get_device() const { return device().idx; }
    template <typename T>
    T *data_ptr() { return reinterpret_cast<T *>(data_ptr()); }
    T *data_ptr() { return reinterpret_cast<T *>(data_ptr()); }
    template <typename T>
    const T *data_ptr() const { return reinterpret_cast<const T *>(data_ptr()); }
    const void *data_ptr() const { return buffer->getPtr<char>() + shape.offset * scalar_size(); }
    void *data_ptr() { return buffer->getPtr<char>() + shape.offset * scalar_size(); }
    const T *data_ptr() const { return reinterpret_cast<const T *>(data_ptr()); }
    Device device() const { return buffer->getDevice(); }
    const void *data_ptr() const { return buffer->getPtr<char>() + shape.offset * scalar_size(); }
    void *data_ptr() { return buffer->getPtr<char>() + shape.offset * scalar_size(); }
    ScalarType scalar_type() const { return scalarType; }
    ScalarType dtype() const { return scalar_type(); }
    Device device() const { return buffer->getDevice(); }
    ScalarType scalar_type() const { return scalarType; }
    ScalarType dtype() const { return scalar_type(); }
    size_t stride(int dim) const { return shape.stride(dim); }
    size_t stride(int dim) const { return shape.stride(dim); }
    size_t numel() const { return shape.size(); }
    size_t ndims() const { return shape.ndims(); }
    size_t numel() const { return shape.size(); }
    size_t ndims() const { return shape.ndims(); }
    size_t dim() const { return ndims(); }
    size_t dim() const { return ndims(); }
    size_t scalar_size() const { return scalarSize.at(scalarType); }
    size_t scalar_size() const { return scalarSize.at(scalarType); }
    Tensor operator[](int idx) const {
        assert(ndims() > 1);
        Tensor result;
        result.shape = std::vector<int>(this->shape.dataExtent.begin() + 1, this->shape.dataExtent.end());
        size_t size = stride(0) * scalar_size();
        result.buffer = std::make_shared<BufferView>(this->buffer, idx * size, size);
        result.shape = std::vector<int>(this->shape.dataExtent.begin() + 1, this->shape.dataExtent.end());
        size_t size = stride(0) * scalar_size();
        result.buffer = std::make_shared<BufferView>(this->buffer, idx * size, size);
        result.scalarType = this->scalarType;
        return result;
    }
    template <typename T>
    const T &at(const std::vector<int> &idx) const {
    const T &at(const std::vector<int> &idx) const {
        assert(ndims() == idx.size());
        int64_t offset = 0;
        for (size_t i = 0; i < ndims(); i++) {
...
...
@@ -304,17 +356,17 @@ public:
    }
    template <typename T>
    T &at(const std::vector<int> &idx) {
    T &at(const std::vector<int> &idx) {
        return const_cast<T &>(const_cast<const Tensor *>(this)->at<T>(idx));
    }
    Tensor slice(int dim, int from, int to) const {
        assert(from <= to);
        Tensor result;
        result.buffer = this->buffer;
        result.buffer = this->buffer;
        result.scalarType = this->scalarType;
        result.shape = TensorShape(this->shape.dataExtent);
        result.shape = TensorShape(this->shape.dataExtent);
        result.shape[dim] = to - from;
        result.shape.dataStride.resize(result.shape.ndims());
        for (int i = 0; i < result.shape.ndims(); i++) {
...
...
@@ -326,7 +378,7 @@ public:
    }
    Tensor transpose(int dim1, int dim2) const {
        Tensor result;
        result.buffer = this->buffer;
        result.buffer = this->buffer;
        result.scalarType = this->scalarType;
        result.shape = TensorShape(this->shape.dataExtent);
...
...
@@ -346,9 +398,9 @@ public:
        assert(shape.size() == this->shape.size());
        assert(this->is_contiguous());
        Tensor result;
        result.buffer = this->buffer;
        result.scalarType = this->scalarType;
        result.shape = shape;
        result.buffer = this->buffer;
        result.scalarType = this->scalarType;
        result.shape = shape;
        result.shape.offset = this->shape.offset;
        return result;
    }
...
...
@@ -363,7 +415,8 @@ public:
    Tensor &zero_() {
        assert(this->is_contiguous());
        checkCUDA(cudaMemsetAsync(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
        checkCUDA(cudaMemsetAsync(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
        return *this;
    }
    Tensor &copy_(Tensor other) {
...
...
@@ -380,23 +433,17 @@ public:
        }
        if (this->device().type == Device::CPU && other.device().type == Device::CPU) {
            memcpy(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size());
            memcpy(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size());
            return *this;
        }
        lockBuffer(this->buffer, getCurrentCUDAStream());
        lockBuffer(other.buffer, getCurrentCUDAStream());
        checkCUDA(cudaMemcpyAsync(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size(), getCopyKind(this->device(), other.device()), getCurrentCUDAStream()));
        checkCUDA(cudaMemcpyAsync(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size(), getCopyKind(this->device(), other.device()), getCurrentCUDAStream()));
        return *this;
    }
...
...
@@ -425,14 +472,15 @@ public:
            assert(false);
        }
        result.scalarType = scalarType;
        result.shape = shape;
        result.shape = shape;
        if (fill) {
            if (device.type == Device::CPU) {
                memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
            } else if (device.type == Device::CUDA) {
                CUDADeviceContext ctx(device.idx);
                checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
                checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
            }
        }
...
...
@@ -450,11 +498,12 @@ public:
        checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentCUDAStream()));
        return result;
    }
    static Tensor allocate_view(TensorShape shape, ScalarType scalarType, std::shared_ptr<Buffer> buffer, size_t offset = 0) {
    static Tensor allocate_view(TensorShape shape, ScalarType scalarType, std::shared_ptr<Buffer> buffer, size_t offset = 0) {
        Tensor result;
        result.buffer = std::make_shared<BufferView>(buffer, offset, shape.size() * scalarSize.at(scalarType));
        result.buffer = std::make_shared<BufferView>(buffer, offset, shape.size() * scalarSize.at(scalarType));
        result.scalarType = scalarType;
        result.shape = shape;
        result.shape = shape;
        return result;
    }
...
...
@@ -468,13 +517,16 @@ public:
        // lockBuffer(this->buffer, getCurrentCUDAStream());
        // lockBuffer(result.buffer, getCurrentCUDAStream());
        // checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault, getCurrentCUDAStream()));
        // if (this->device().type == Device::CPU && device.type == Device::CUDA) {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyHostToDevice, getCurrentCUDAStream()));
        // checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault,
        // getCurrentCUDAStream())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
        //     cudaMemcpyHostToDevice, getCurrentCUDAStream()));
        // } else if (this->device().type == Device::CUDA && device.type == Device::CPU) {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
        //     cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
        // } else {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault, getCurrentCUDAStream()));
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
        //     cudaMemcpyDefault, getCurrentCUDAStream()));
        // }
        return result;
    }
...
...
@@ -516,9 +568,10 @@ private:
    // }
    static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
public:
    // before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU completes
    // before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU
    // completes
    static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
        if (!buffer->isAsyncBuffer()) {
            lockedBuffers[stream].insert(buffer);
...
...
@@ -558,5 +611,5 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
struct TensorsProvider {
    virtual ~TensorsProvider() {}
    virtual bool contains(const std::string &key) const = 0;
    virtual Tensor getTensor(const std::string &key) = 0;
};
\ No newline at end of file
    virtual Tensor getTensor(const std::string &key) = 0;
};
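The `lockBuffer` comment above captures the ownership hazard of async CUDA work: a host-side free must not release memory a stream may still be reading. The fix is to park a `shared_ptr` per stream until synchronization. A minimal sketch of the idea (toy types in place of `Buffer` and `cudaStream_t`):

#include <map>
#include <memory>
#include <set>

// Sketch: keep buffers alive per stream while async work may reference them;
// the per-stream set is cleared once that stream has been synchronized.
using StreamKey = void *;
static std::map<StreamKey, std::set<std::shared_ptr<int>>> locked_buffers;

void lock_buffer(std::shared_ptr<int> buffer, StreamKey stream) {
    locked_buffers[stream].insert(std::move(buffer));
}

void on_stream_synchronized(StreamKey stream) {
    locked_buffers.erase(stream); // refcounts drop; memory may now be freed
}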
src/activation.cpp
View file @ 37a27712
...
...
@@ -22,13 +22,15 @@ Tensor GELU::forward(Tensor x) {
// return out;
// }
// Tensor SiluAndMulQuant::forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor SiluAndMulQuant::forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor
// quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor out = SiluAndMul::forward(x);
// invoke_quant_fuse_sum(quantized_mlp_act_buffer, out, quantized_sum_buffer, quantized_scale_buffer);
// return out;
// }
// Tensor SiluAndMulQuant::forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor SiluAndMulQuant::forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer,
// Tensor quantized_sum_buffer) {
// Tensor out = SiluAndMul::forward(x);
// invoke_quant(quantized_mlp_act_buffer, out, quantized_scale_buffer, {});
// return out;
...
...