Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
37a27712
Unverified
Commit
37a27712
authored
May 01, 2025
by
Muyang Li
Committed by
GitHub
May 01, 2025
Browse files
Merge pull request #340 from mit-han-lab/dev
feat: support PuLID, Double FBCache and TeaCache; better linter
parents
c1d6fc84
760ab022
Changes
192
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
964 additions
and
726 deletions
+964
-726
scripts/build_docker.sh
scripts/build_docker.sh
+1
-1
scripts/build_docker_torch27.sh
scripts/build_docker_torch27.sh
+1
-1
scripts/build_docker_torch28.sh
scripts/build_docker_torch28.sh
+1
-1
scripts/build_linux_wheel.sh
scripts/build_linux_wheel.sh
+1
-1
scripts/build_linux_wheel_cu128.sh
scripts/build_linux_wheel_cu128.sh
+1
-1
scripts/build_linux_wheel_torch2.7_cu128.sh
scripts/build_linux_wheel_torch2.7_cu128.sh
+1
-1
scripts/linux_cleanup.sh
scripts/linux_cleanup.sh
+1
-1
setup.py
setup.py
+1
-1
src/FluxModel.cpp
src/FluxModel.cpp
+300
-268
src/FluxModel.h
src/FluxModel.h
+54
-30
src/Linear.cpp
src/Linear.cpp
+216
-143
src/Linear.h
src/Linear.h
+20
-10
src/Module.cpp
src/Module.cpp
+2
-2
src/Module.h
src/Module.h
+28
-31
src/SanaModel.cpp
src/SanaModel.cpp
+121
-99
src/SanaModel.h
src/SanaModel.h
+29
-6
src/Serialization.cpp
src/Serialization.cpp
+37
-36
src/Serialization.h
src/Serialization.h
+7
-6
src/Tensor.h
src/Tensor.h
+138
-85
src/activation.cpp
src/activation.cpp
+4
-2
No files found.
scripts/build_docker.sh
View file @
37a27712
scripts/build_docker_torch27.sh
View file @
37a27712
scripts/build_docker_torch28.sh
View file @
37a27712
scripts/build_linux_wheel.sh
View file @
37a27712
scripts/build_linux_wheel_cu128.sh
View file @
37a27712
scripts/build_linux_wheel_torch2.7_cu128.sh
View file @
37a27712
scripts/linux_cleanup.sh
View file @
37a27712
setup.py
View file @
37a27712
...
...
@@ -6,7 +6,7 @@ import sys
import
setuptools
import
torch
from
packaging
import
version
as
packaging_version
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDA_HOME
,
CUDAExtension
from
torch.utils.cpp_extension
import
CUDA_HOME
,
BuildExtension
,
CUDAExtension
class
CustomBuildExtension
(
BuildExtension
):
...
...
src/FluxModel.cpp
View file @
37a27712
...
...
@@ -4,20 +4,18 @@
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"
#include <nvtx3/nvToolsExt.h>
#include <pybind11/functional.h>
#include <iostream>
using
spdlog
::
fmt_lib
::
format
;
using
namespace
nunchaku
;
Tensor
forward_mlp
(
GEMM_W4A4
&
fc1
,
GEMM_W4A4
&
fc2
,
Tensor
norm_hidden_states
)
{
Tensor
ff_output
=
fc2
.
forward_quant
(
std
::
get
<
GEMM_W4A4
::
QuantizedActivation
>
(
fc1
.
forward
(
norm_hidden_states
,
GEMM_W4A4
::
FuseOptions
::
GELU_QUANT
,
&
fc2
))
);
Tensor
ff_output
=
fc2
.
forward_quant
(
std
::
get
<
GEMM_W4A4
::
QuantizedActivation
>
(
fc1
.
forward
(
norm_hidden_states
,
GEMM_W4A4
::
FuseOptions
::
GELU_QUANT
,
&
fc2
)));
return
ff_output
;
}
...
...
@@ -26,7 +24,6 @@ Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
// return ff_output;
// }
Tensor
forward_fc
(
GEMM_W4A4
&
fc
,
Tensor
x
)
{
return
fc
.
forward
(
x
);
// return std::get<Tensor>(fc.forward(x));
...
...
@@ -36,16 +33,9 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
// return fc.forward(x);
// }
AdaLayerNormZeroSingle
::
AdaLayerNormZeroSingle
(
int
dim
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
linear
(
dim
,
3
*
dim
,
true
,
dtype
,
device
),
norm
(
dim
,
1e-6
,
false
,
dtype
,
device
)
{
registerChildren
(
linear
,
"linear"
)
(
norm
,
"norm"
)
;
AdaLayerNormZeroSingle
::
AdaLayerNormZeroSingle
(
int
dim
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
linear
(
dim
,
3
*
dim
,
true
,
dtype
,
device
),
norm
(
dim
,
1e-6
,
false
,
dtype
,
device
)
{
registerChildren
(
linear
,
"linear"
)(
norm
,
"norm"
);
}
AdaLayerNormZeroSingle
::
Output
AdaLayerNormZeroSingle
::
forward
(
Tensor
x
,
Tensor
emb
)
{
...
...
@@ -65,15 +55,10 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
return
Output
{
norm_x
,
gate_msa
};
}
AdaLayerNormZero
::
AdaLayerNormZero
(
int
dim
,
bool
pre_only
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
pre_only
(
pre_only
),
linear
(
dim
,
pre_only
?
2
*
dim
:
6
*
dim
,
true
,
dtype
,
device
),
norm
(
dim
,
1e-6
,
false
,
dtype
,
device
)
{
registerChildren
(
linear
,
"linear"
)
(
norm
,
"norm"
)
;
AdaLayerNormZero
::
AdaLayerNormZero
(
int
dim
,
bool
pre_only
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
pre_only
(
pre_only
),
linear
(
dim
,
pre_only
?
2
*
dim
:
6
*
dim
,
true
,
dtype
,
device
),
norm
(
dim
,
1e-6
,
false
,
dtype
,
device
)
{
registerChildren
(
linear
,
"linear"
)(
norm
,
"norm"
);
}
AdaLayerNormZero
::
Output
AdaLayerNormZero
::
forward
(
Tensor
x
,
Tensor
emb
)
{
...
...
@@ -110,10 +95,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
}
}
Attention
::
Attention
(
int
num_heads
,
int
dim_head
,
Device
device
)
:
num_heads
(
num_heads
),
dim_head
(
dim_head
),
force_fp16
(
false
)
{
Attention
::
Attention
(
int
num_heads
,
int
dim_head
,
Device
device
)
:
num_heads
(
num_heads
),
dim_head
(
dim_head
),
force_fp16
(
false
)
{
headmask_type
=
Tensor
::
allocate
({
num_heads
},
Tensor
::
INT32
,
Device
::
cpu
());
for
(
int
i
=
0
;
i
<
num_heads
;
i
++
)
{
headmask_type
.
data_ptr
<
int32_t
>
()[
i
]
=
i
+
1
;
...
...
@@ -134,11 +117,7 @@ Tensor Attention::forward(Tensor qkv) {
Tensor
k
=
reshaped
.
slice
(
2
,
num_heads
,
num_heads
*
2
);
Tensor
v
=
reshaped
.
slice
(
2
,
num_heads
*
2
,
num_heads
*
3
);
Tensor
raw_attn_output
=
mha_fwd
(
q
,
k
,
v
,
0.0
f
,
pow
(
q
.
shape
[
-
1
],
(
-
0.5
)),
false
,
-
1
,
-
1
,
false
).
front
();
Tensor
raw_attn_output
=
mha_fwd
(
q
,
k
,
v
,
0.0
f
,
pow
(
q
.
shape
[
-
1
],
(
-
0.5
)),
false
,
-
1
,
-
1
,
false
).
front
();
assert
(
raw_attn_output
.
shape
[
0
]
==
batch_size
);
assert
(
raw_attn_output
.
shape
[
1
]
==
num_tokens
);
...
...
@@ -175,9 +154,9 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
pool_qkv
=
pool_qkv
.
view
({
batch_size
,
pool_tokens
,
3
,
num_heads
,
dim_head
});
pool_qkv
=
pool_qkv
.
transpose
(
1
,
2
).
transpose
(
2
,
3
);
// [batch_size, 3, num_heads, poolTokens, dim_head]
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
Tensor
pool_q
=
pool_qkv
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
1
);
Tensor
pool_k
=
pool_qkv
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
1
,
2
);
Tensor
pool_s
=
pool_score
.
slice
(
0
,
i
,
i
+
1
);
Tensor
pool_q
=
pool_qkv
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
1
);
Tensor
pool_k
=
pool_qkv
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
1
,
2
);
Tensor
pool_s
=
pool_score
.
slice
(
0
,
i
,
i
+
1
);
gemm_batched_fp16
(
pool_q
,
pool_k
,
pool_s
);
}
}
...
...
@@ -221,10 +200,13 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
spdlog
::
debug
(
"q,k,v={}"
,
q
.
shape
.
str
());
Tensor
raw_attn_output
=
mha_fwd_block
(
q
,
k
,
v
,
cu_seqlens
,
cu_seqlens
,
POOL_SIZE
,
POOL_SIZE
,
Tensor
raw_attn_output
=
mha_fwd_block
(
q
,
k
,
v
,
cu_seqlens
,
cu_seqlens
,
POOL_SIZE
,
POOL_SIZE
,
headmask_type
,
{},
blockmask
,
...
...
@@ -232,8 +214,12 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
num_tokens
,
0.0
f
,
pow
(
q
.
shape
[
-
1
],
(
-
0.5
)),
false
,
false
,
false
,
-
1
,
-
1
).
front
();
false
,
false
,
false
,
-
1
,
-
1
)
.
front
();
debug
(
"raw_attn_output"
,
raw_attn_output
);
...
...
@@ -290,30 +276,22 @@ void Attention::setForceFP16(Module *module, bool value) {
});
}
FluxSingleTransformerBlock
::
FluxSingleTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
int
mlp_ratio
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
dim_head
(
attention_head_dim
/
num_attention_heads
),
num_heads
(
num_attention_heads
),
mlp_hidden_dim
(
dim
*
mlp_ratio
),
norm
(
dim
,
dtype
,
device
),
FluxSingleTransformerBlock
::
FluxSingleTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
int
mlp_ratio
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
dim_head
(
attention_head_dim
/
num_attention_heads
),
num_heads
(
num_attention_heads
),
mlp_hidden_dim
(
dim
*
mlp_ratio
),
norm
(
dim
,
dtype
,
device
),
mlp_fc1
(
dim
,
mlp_hidden_dim
,
true
,
use_fp4
,
dtype
,
device
),
mlp_fc2
(
mlp_hidden_dim
,
dim
,
true
,
use_fp4
,
dtype
,
device
),
qkv_proj
(
dim
,
dim
*
3
,
true
,
use_fp4
,
dtype
,
device
),
norm_q
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
norm_k
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
mlp_fc2
(
mlp_hidden_dim
,
dim
,
true
,
use_fp4
,
dtype
,
device
),
qkv_proj
(
dim
,
dim
*
3
,
true
,
use_fp4
,
dtype
,
device
),
norm_q
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
norm_k
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
attn
(
num_attention_heads
,
attention_head_dim
/
num_attention_heads
,
device
),
out_proj
(
dim
,
dim
,
true
,
use_fp4
,
dtype
,
device
)
{
registerChildren
(
norm
,
"norm"
)
(
mlp_fc1
,
"mlp_fc1"
)
(
mlp_fc2
,
"mlp_fc2"
)
(
qkv_proj
,
"qkv_proj"
)
(
norm_q
,
"norm_q"
)
(
norm_k
,
"norm_k"
)
(
attn
,
"attn"
)
(
out_proj
,
"out_proj"
)
;
out_proj
(
dim
,
dim
,
true
,
use_fp4
,
dtype
,
device
)
{
registerChildren
(
norm
,
"norm"
)(
mlp_fc1
,
"mlp_fc1"
)(
mlp_fc2
,
"mlp_fc2"
)(
qkv_proj
,
"qkv_proj"
)(
norm_q
,
"norm_q"
)(
norm_k
,
"norm_k"
)(
attn
,
"attn"
)(
out_proj
,
"out_proj"
);
}
Tensor
FluxSingleTransformerBlock
::
forward
(
Tensor
hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
)
{
...
...
@@ -334,12 +312,18 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
debug
(
"rotary_emb"
,
rotary_emb
);
if
(
attnImpl
==
AttentionImpl
::
FlashAttention2
)
{
Tensor
qkv
=
Tensor
::
allocate
({
batch_size
,
num_tokens
,
dim
*
3
},
norm_hidden_states
.
scalar_type
(),
norm_hidden_states
.
device
());
Tensor
qkv
=
Tensor
::
allocate
(
{
batch_size
,
num_tokens
,
dim
*
3
},
norm_hidden_states
.
scalar_type
(),
norm_hidden_states
.
device
());
// qkv_proj.forward(norm_hidden_states, qkv, {});
// debug("qkv_raw", qkv);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
qkv_proj
.
forward
(
norm_hidden_states
.
slice
(
0
,
i
,
i
+
1
),
qkv
.
slice
(
0
,
i
,
i
+
1
),
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
qkv_proj
.
forward
(
norm_hidden_states
.
slice
(
0
,
i
,
i
+
1
),
qkv
.
slice
(
0
,
i
,
i
+
1
),
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
}
debug
(
"qkv"
,
qkv
);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
...
...
@@ -352,16 +336,23 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
const
int
num_tokens_pad
=
ceilDiv
(
num_tokens
,
256
)
*
256
;
Tensor
q
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
Tensor
k
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
Tensor
v
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
Tensor
q
=
Tensor
::
allocate
(
{
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
Tensor
k
=
Tensor
::
allocate
(
{
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
Tensor
v
=
Tensor
::
allocate
(
{
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
qkv_proj
.
forward
(
norm_hidden_states
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
,
q
.
slice
(
0
,
i
,
i
+
1
),
k
.
slice
(
0
,
i
,
i
+
1
),
v
.
slice
(
0
,
i
,
i
+
1
),
qkv_proj
.
forward
(
norm_hidden_states
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
,
q
.
slice
(
0
,
i
,
i
+
1
),
k
.
slice
(
0
,
i
,
i
+
1
),
v
.
slice
(
0
,
i
,
i
+
1
),
num_tokens
);
}
...
...
@@ -369,7 +360,9 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
debug
(
"packed_k"
,
k
);
debug
(
"packed_v"
,
v
);
Tensor
o
=
Tensor
::
allocate
({
batch_size
,
num_tokens_pad
,
num_heads
*
dim_head
},
norm_hidden_states
.
scalar_type
(),
norm_hidden_states
.
device
());
Tensor
o
=
Tensor
::
allocate
({
batch_size
,
num_tokens_pad
,
num_heads
*
dim_head
},
norm_hidden_states
.
scalar_type
(),
norm_hidden_states
.
device
());
kernels
::
attention_fp16
(
q
,
k
,
v
,
o
,
pow
(
dim_head
,
(
-
0.5
)));
...
...
@@ -377,16 +370,14 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
attn_output
=
o
.
slice
(
1
,
0
,
num_tokens
);
}
else
{
attn_output
=
Tensor
::
allocate
({
batch_size
,
num_tokens
,
num_heads
*
dim_head
},
o
.
scalar_type
(),
o
.
device
());
checkCUDA
(
cudaMemcpy2DAsync
(
attn_output
.
data_ptr
(),
checkCUDA
(
cudaMemcpy2DAsync
(
attn_output
.
data_ptr
(),
attn_output
.
stride
(
0
)
*
attn_output
.
scalar_size
(),
o
.
data_ptr
(),
o
.
stride
(
0
)
*
o
.
scalar_size
(),
attn_output
.
stride
(
0
)
*
attn_output
.
scalar_size
(),
batch_size
,
cudaMemcpyDeviceToDevice
,
getCurrentCUDAStream
()
));
getCurrentCUDAStream
()));
}
}
else
{
assert
(
false
);
...
...
@@ -394,8 +385,6 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
debug
(
"raw_attn_output"
,
attn_output
);
attn_output
=
forward_fc
(
out_proj
,
attn_output
);
debug
(
"attn_output"
,
attn_output
);
...
...
@@ -413,54 +402,40 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
return
hidden_states
;
}
JointTransformerBlock
::
JointTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
bool
context_pre_only
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
)
,
dim_head
(
attention_head_dim
/
num_attention_heads
)
,
num_heads
(
num_attention_heads
)
,
context_pre_only
(
context_pre_only
)
,
norm1
(
dim
,
false
,
dtype
,
device
)
,
norm1_context
(
dim
,
context_pre_only
,
dtype
,
device
)
,
qkv_proj
(
dim
,
dim
*
3
,
true
,
use_fp4
,
dtype
,
device
),
qkv_proj_context
(
dim
,
dim
*
3
,
true
,
use_fp4
,
dtype
,
device
),
norm
_q
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
norm_
k
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
norm_added_q
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
JointTransformerBlock
::
JointTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
bool
context_pre_only
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
)
,
dim
_head
(
attention_head_dim
/
num_attention_heads
),
num_heads
(
num_attention_heads
),
context_pre_only
(
context_pre_only
),
norm1
(
dim
,
false
,
dtype
,
device
),
norm
1_context
(
dim
,
context_pre_only
,
dtype
,
device
),
qkv_proj
(
dim
,
dim
*
3
,
true
,
use_fp4
,
dtype
,
device
),
qkv_proj_context
(
dim
,
dim
*
3
,
true
,
use_fp4
,
dtype
,
device
),
norm_
q
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
norm_k
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
norm_added_q
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
norm_added_k
(
dim_head
,
1e-6
,
false
,
dtype
,
device
),
attn
(
num_attention_heads
,
attention_head_dim
/
num_attention_heads
,
device
),
out_proj
(
dim
,
dim
,
true
,
use_fp4
,
dtype
,
device
),
out_proj_context
(
dim
,
dim
,
true
,
use_fp4
,
dtype
,
device
),
norm2
(
dim
,
1e-6
,
false
,
dtype
,
device
),
norm2_context
(
dim
,
1e-6
,
false
,
dtype
,
device
),
mlp_fc1
(
dim
,
dim
*
4
,
true
,
use_fp4
,
dtype
,
device
),
mlp_fc2
(
dim
*
4
,
dim
,
true
,
use_fp4
,
dtype
,
device
),
out_proj
(
dim
,
dim
,
true
,
use_fp4
,
dtype
,
device
),
out_proj_context
(
dim
,
dim
,
true
,
use_fp4
,
dtype
,
device
),
norm2
(
dim
,
1e-6
,
false
,
dtype
,
device
),
norm2_context
(
dim
,
1e-6
,
false
,
dtype
,
device
),
mlp_fc1
(
dim
,
dim
*
4
,
true
,
use_fp4
,
dtype
,
device
),
mlp_fc2
(
dim
*
4
,
dim
,
true
,
use_fp4
,
dtype
,
device
),
mlp_context_fc1
(
dim
,
dim
*
4
,
true
,
use_fp4
,
dtype
,
device
),
mlp_context_fc2
(
dim
*
4
,
dim
,
true
,
use_fp4
,
dtype
,
device
)
{
registerChildren
(
norm1
,
"norm1"
)
(
norm1_context
,
"norm1_context"
)
(
qkv_proj
,
"qkv_proj"
)
(
qkv_proj_context
,
"qkv_proj_context"
)
(
norm_q
,
"norm_q"
)
(
norm_k
,
"norm_k"
)
(
norm_added_q
,
"norm_added_q"
)
(
norm_added_k
,
"norm_added_k"
)
(
attn
,
"attn"
)
(
out_proj
,
"out_proj"
)
(
out_proj_context
,
"out_proj_context"
)
(
norm2
,
"norm2"
)
(
norm2_context
,
"norm2_context"
)
(
mlp_fc1
,
"mlp_fc1"
)
(
mlp_fc2
,
"mlp_fc2"
)
(
mlp_context_fc1
,
"mlp_context_fc1"
)
(
mlp_context_fc2
,
"mlp_context_fc2"
)
;
mlp_context_fc2
(
dim
*
4
,
dim
,
true
,
use_fp4
,
dtype
,
device
)
{
registerChildren
(
norm1
,
"norm1"
)(
norm1_context
,
"norm1_context"
)(
qkv_proj
,
"qkv_proj"
)(
qkv_proj_context
,
"qkv_proj_context"
)(
norm_q
,
"norm_q"
)(
norm_k
,
"norm_k"
)(
norm_added_q
,
"norm_added_q"
)(
norm_added_k
,
"norm_added_k"
)(
attn
,
"attn"
)(
out_proj
,
"out_proj"
)(
out_proj_context
,
"out_proj_context"
)(
norm2
,
"norm2"
)(
norm2_context
,
"norm2_context"
)(
mlp_fc1
,
"mlp_fc1"
)(
mlp_fc2
,
"mlp_fc2"
)(
mlp_context_fc1
,
"mlp_context_fc1"
)(
mlp_context_fc2
,
"mlp_context_fc2"
);
}
// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
std
::
tuple
<
Tensor
,
Tensor
>
JointTransformerBlock
::
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
,
Tensor
rotary_emb_context
,
float
sparsityRatio
)
{
std
::
tuple
<
Tensor
,
Tensor
>
JointTransformerBlock
::
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
,
Tensor
rotary_emb_context
,
float
sparsityRatio
)
{
int
batch_size
=
hidden_states
.
shape
[
0
];
assert
(
encoder_hidden_states
.
shape
[
0
]
==
batch_size
);
...
...
@@ -468,14 +443,16 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
nvtxRangePushA
(
"AdaNorm"
);
int
num_tokens_img
=
hidden_states
.
shape
[
1
];
int
num_tokens_txt
=
encoder_hidden_states
.
shape
[
1
];
assert
(
hidden_states
.
shape
[
2
]
==
dim
);
assert
(
encoder_hidden_states
.
shape
[
2
]
==
dim
);
spdlog
::
debug
(
"hidden_states={} encoder_hidden_states={} temb={}"
,
hidden_states
.
shape
.
str
(),
encoder_hidden_states
.
shape
.
str
(),
temb
.
shape
.
str
());
spdlog
::
debug
(
"hidden_states={} encoder_hidden_states={} temb={}"
,
hidden_states
.
shape
.
str
(),
encoder_hidden_states
.
shape
.
str
(),
temb
.
shape
.
str
());
spdlog
::
debug
(
"batch_size={} num_tokens_img={} num_tokens_txt={}"
,
batch_size
,
num_tokens_img
,
num_tokens_txt
);
auto
norm1_output
=
norm1
.
forward
(
hidden_states
,
temb
);
...
...
@@ -511,22 +488,28 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
const
bool
blockSparse
=
sparsityRatio
>
0
;
const
int
poolTokens
=
num_tokens_img
/
POOL_SIZE
+
num_tokens_txt
/
POOL_SIZE
;
concat
=
Tensor
::
allocate
({
batch_size
,
num_tokens_img
+
num_tokens_txt
,
dim
*
3
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
());
concat
=
Tensor
::
allocate
({
batch_size
,
num_tokens_img
+
num_tokens_txt
,
dim
*
3
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
());
pool
=
blockSparse
?
Tensor
::
allocate
({
batch_size
,
poolTokens
,
dim
*
3
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
())
pool
=
blockSparse
?
Tensor
::
allocate
({
batch_size
,
poolTokens
,
dim
*
3
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
())
:
Tensor
{};
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
// img first
Tensor
qkv
=
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
num_tokens_img
);
Tensor
qkv_context
=
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
num_tokens_img
,
num_tokens_img
+
num_tokens_txt
);
Tensor
qkv_context
=
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
num_tokens_img
,
num_tokens_img
+
num_tokens_txt
);
Tensor
pool_qkv
=
pool
.
valid
()
?
pool
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
num_tokens_img
/
POOL_SIZE
)
:
Tensor
{};
Tensor
pool_qkv
=
pool
.
valid
()
?
pool
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
num_tokens_img
/
POOL_SIZE
)
:
Tensor
{};
Tensor
pool_qkv_context
=
pool
.
valid
()
?
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
num_tokens_img
/
POOL_SIZE
,
num_tokens_img
/
POOL_SIZE
+
num_tokens_txt
/
POOL_SIZE
)
?
pool
.
slice
(
0
,
i
,
i
+
1
)
.
slice
(
1
,
num_tokens_img
/
POOL_SIZE
,
num_tokens_img
/
POOL_SIZE
+
num_tokens_txt
/
POOL_SIZE
)
:
Tensor
{};
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
...
...
@@ -534,7 +517,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug
(
"rotary_emb"
,
rotary_emb
);
qkv_proj
.
forward
(
norm1_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
qkv
,
pool_qkv
,
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
qkv_proj
.
forward
(
norm1_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
qkv
,
pool_qkv
,
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
debug
(
"qkv"
,
qkv
);
// qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
...
...
@@ -542,7 +526,12 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug
(
"rotary_emb_context"
,
rotary_emb_context
);
qkv_proj_context
.
forward
(
norm1_context_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
qkv_context
,
pool_qkv_context
,
norm_added_q
.
weight
,
norm_added_k
.
weight
,
rotary_emb_context
);
qkv_proj_context
.
forward
(
norm1_context_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
qkv_context
,
pool_qkv_context
,
norm_added_q
.
weight
,
norm_added_k
.
weight
,
rotary_emb_context
);
debug
(
"qkv_context"
,
qkv_context
);
}
...
...
@@ -577,28 +566,40 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
{
nvtxRangePushA
(
"qkv_proj"
);
concat_q
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_img_pad
+
num_tokens_txt_pad
,
dim_head
},
Tensor
::
FP16
,
norm1_output
.
x
.
device
());
concat_q
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_img_pad
+
num_tokens_txt_pad
,
dim_head
},
Tensor
::
FP16
,
norm1_output
.
x
.
device
());
concat_k
=
Tensor
::
empty_like
(
concat_q
);
concat_v
=
Tensor
::
empty_like
(
concat_q
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
// img first
auto
sliceImg
=
[
&
](
Tensor
x
)
{
return
x
.
slice
(
0
,
i
,
i
+
1
).
slice
(
2
,
0
,
num_tokens_img_pad
);
};
auto
sliceImg
=
[
&
](
Tensor
x
)
{
return
x
.
slice
(
0
,
i
,
i
+
1
).
slice
(
2
,
0
,
num_tokens_img_pad
);
};
auto
sliceTxt
=
[
&
](
Tensor
x
)
{
return
x
.
slice
(
0
,
i
,
i
+
1
).
slice
(
2
,
num_tokens_img_pad
,
num_tokens_img_pad
+
num_tokens_txt_pad
);
return
x
.
slice
(
0
,
i
,
i
+
1
).
slice
(
2
,
num_tokens_img_pad
,
num_tokens_img_pad
+
num_tokens_txt_pad
);
};
qkv_proj
.
forward
(
norm1_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
,
sliceImg
(
concat_q
),
sliceImg
(
concat_k
),
sliceImg
(
concat_v
),
num_tokens_img
);
qkv_proj_context
.
forward
(
norm1_context_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_added_q
.
weight
,
norm_added_k
.
weight
,
rotary_emb_context
,
sliceTxt
(
concat_q
),
sliceTxt
(
concat_k
),
sliceTxt
(
concat_v
),
num_tokens_txt
);
qkv_proj
.
forward
(
norm1_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
,
sliceImg
(
concat_q
),
sliceImg
(
concat_k
),
sliceImg
(
concat_v
),
num_tokens_img
);
qkv_proj_context
.
forward
(
norm1_context_output
.
x
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_added_q
.
weight
,
norm_added_k
.
weight
,
rotary_emb_context
,
sliceTxt
(
concat_q
),
sliceTxt
(
concat_k
),
sliceTxt
(
concat_v
),
num_tokens_txt
);
}
debug
(
"concat_q"
,
concat_q
);
...
...
@@ -608,7 +609,9 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
nvtxRangePop
();
}
raw_attn_output
=
Tensor
::
allocate
({
batch_size
,
num_tokens_img_pad
+
num_tokens_txt_pad
,
num_heads
*
dim_head
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
());
raw_attn_output
=
Tensor
::
allocate
({
batch_size
,
num_tokens_img_pad
+
num_tokens_txt_pad
,
num_heads
*
dim_head
},
norm1_output
.
x
.
scalar_type
(),
norm1_output
.
x
.
device
());
nvtxRangePushA
(
"Attention"
);
...
...
@@ -616,7 +619,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
nvtxRangePop
();
raw_attn_output
=
raw_attn_output
.
view
({
batch_size
,
num_tokens_img_pad
+
num_tokens_txt_pad
,
num_heads
,
dim_head
});
raw_attn_output
=
raw_attn_output
.
view
({
batch_size
,
num_tokens_img_pad
+
num_tokens_txt_pad
,
num_heads
,
dim_head
});
}
else
{
assert
(
false
);
}
...
...
@@ -632,25 +636,28 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor
raw_attn_output_split
;
if
(
batch_size
==
1
)
{
raw_attn_output_split
=
raw_attn_output
.
slice
(
1
,
0
,
num_tokens_img
).
reshape
({
batch_size
,
num_tokens_img
,
num_heads
*
dim_head
});
raw_attn_output_split
=
raw_attn_output
.
slice
(
1
,
0
,
num_tokens_img
).
reshape
({
batch_size
,
num_tokens_img
,
num_heads
*
dim_head
});
}
else
{
raw_attn_output_split
=
Tensor
::
allocate
({
batch_size
,
num_tokens_img
,
num_heads
*
dim_head
},
raw_attn_output
.
scalar_type
(),
raw_attn_output
.
device
());
checkCUDA
(
cudaMemcpy2DAsync
(
raw_attn_output_split
.
data_ptr
(),
raw_attn_output_split
=
Tensor
::
allocate
({
batch_size
,
num_tokens_img
,
num_heads
*
dim_head
},
raw_attn_output
.
scalar_type
(),
raw_attn_output
.
device
());
checkCUDA
(
cudaMemcpy2DAsync
(
raw_attn_output_split
.
data_ptr
(),
num_tokens_img
*
num_heads
*
dim_head
*
raw_attn_output_split
.
scalar_size
(),
raw_attn_output
.
data_ptr
(),
(
num_tokens_img_pad
+
num_tokens_txt_pad
)
*
num_heads
*
dim_head
*
raw_attn_output
.
scalar_size
(),
(
num_tokens_img_pad
+
num_tokens_txt_pad
)
*
num_heads
*
dim_head
*
raw_attn_output
.
scalar_size
(),
num_tokens_img
*
num_heads
*
dim_head
*
raw_attn_output_split
.
scalar_size
(),
batch_size
,
cudaMemcpyDeviceToDevice
,
stream
));
}
spdlog
::
debug
(
"raw_attn_output_split={}"
,
raw_attn_output_split
.
shape
.
str
());
debug
(
"img.raw_attn_output_split"
,
raw_attn_output_split
);
Tensor
attn_output
=
forward_fc
(
out_proj
,
raw_attn_output_split
);
// std::get<Tensor>(out_proj.forward(raw_attn_output_split));
Tensor
attn_output
=
forward_fc
(
out_proj
,
raw_attn_output_split
);
// std::get<Tensor>(out_proj.forward(raw_attn_output_split));
debug
(
"img.attn_output"
,
attn_output
);
#if 1
...
...
@@ -690,7 +697,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
}
if
(
context_pre_only
)
{
return
{
hidden_states
,
encoder_hidden_states
};
return
{
hidden_states
,
encoder_hidden_states
};
}
{
...
...
@@ -700,25 +707,30 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor
raw_attn_output_split
;
if
(
batch_size
==
1
)
{
raw_attn_output_split
=
raw_attn_output
.
slice
(
1
,
num_tokens_img_pad
,
num_tokens_img_pad
+
num_tokens_txt
).
reshape
({
batch_size
,
num_tokens_txt
,
num_heads
*
dim_head
});
raw_attn_output_split
=
raw_attn_output
.
slice
(
1
,
num_tokens_img_pad
,
num_tokens_img_pad
+
num_tokens_txt
)
.
reshape
({
batch_size
,
num_tokens_txt
,
num_heads
*
dim_head
});
}
else
{
raw_attn_output_split
=
Tensor
::
allocate
({
batch_size
,
num_tokens_txt
,
num_heads
*
dim_head
},
raw_attn_output
.
scalar_type
(),
raw_attn_output
.
device
());
checkCUDA
(
cudaMemcpy2DAsync
(
raw_attn_output_split
.
data_ptr
(),
raw_attn_output_split
=
Tensor
::
allocate
({
batch_size
,
num_tokens_txt
,
num_heads
*
dim_head
},
raw_attn_output
.
scalar_type
(),
raw_attn_output
.
device
());
checkCUDA
(
cudaMemcpy2DAsync
(
raw_attn_output_split
.
data_ptr
(),
num_tokens_txt
*
num_heads
*
dim_head
*
raw_attn_output_split
.
scalar_size
(),
raw_attn_output
.
data_ptr
<
char
>
()
+
num_tokens_img_pad
*
num_heads
*
dim_head
*
raw_attn_output_split
.
scalar_size
(),
(
num_tokens_img_pad
+
num_tokens_txt_pad
)
*
num_heads
*
dim_head
*
raw_attn_output
.
scalar_size
(),
raw_attn_output
.
data_ptr
<
char
>
()
+
num_tokens_img_pad
*
num_heads
*
dim_head
*
raw_attn_output_split
.
scalar_size
(),
(
num_tokens_img_pad
+
num_tokens_txt_pad
)
*
num_heads
*
dim_head
*
raw_attn_output
.
scalar_size
(),
num_tokens_txt
*
num_heads
*
dim_head
*
raw_attn_output_split
.
scalar_size
(),
batch_size
,
cudaMemcpyDeviceToDevice
,
stream
));
}
spdlog
::
debug
(
"raw_attn_output_split={}"
,
raw_attn_output_split
.
shape
.
str
());
debug
(
"context.raw_attn_output_split"
,
raw_attn_output_split
);
Tensor
attn_output
=
forward_fc
(
out_proj_context
,
raw_attn_output_split
);
// std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
Tensor
attn_output
=
forward_fc
(
out_proj_context
,
raw_attn_output_split
);
// std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
debug
(
"context.attn_output"
,
attn_output
);
#if 1
...
...
@@ -742,9 +754,9 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
auto
norm_hidden_states
=
encoder_hidden_states
;
#endif
// Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
// Tensor ff_output = mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
// Tensor ff_output =
// mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
debug
(
"context.ff_input"
,
norm_hidden_states
);
Tensor
ff_output
=
forward_mlp
(
mlp_context_fc1
,
mlp_context_fc2
,
norm_hidden_states
);
debug
(
"context.ff_output"
,
ff_output
);
...
...
@@ -761,12 +773,14 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
nvtxRangePop
();
return
{
hidden_states
,
encoder_hidden_states
};
return
{
hidden_states
,
encoder_hidden_states
};
}
FluxModel
::
FluxModel
(
bool
use_fp4
,
bool
offload
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dtype
(
dtype
),
offload
(
offload
)
{
FluxModel
::
FluxModel
(
bool
use_fp4
,
bool
offload
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dtype
(
dtype
),
offload
(
offload
)
{
for
(
int
i
=
0
;
i
<
19
;
i
++
)
{
transformer_blocks
.
push_back
(
std
::
make_unique
<
JointTransformerBlock
>
(
3072
,
24
,
3072
,
false
,
use_fp4
,
dtype
,
device
));
transformer_blocks
.
push_back
(
std
::
make_unique
<
JointTransformerBlock
>
(
3072
,
24
,
3072
,
false
,
use_fp4
,
dtype
,
device
));
registerChildren
(
*
transformer_blocks
.
back
(),
format
(
"transformer_blocks.{}"
,
i
));
if
(
offload
&&
i
>
0
)
{
// don't offload first block
transformer_blocks
.
back
()
->
setLazyLoad
(
true
);
...
...
@@ -774,7 +788,8 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
}
}
for
(
int
i
=
0
;
i
<
38
;
i
++
)
{
single_transformer_blocks
.
push_back
(
std
::
make_unique
<
FluxSingleTransformerBlock
>
(
3072
,
24
,
3072
,
4
,
use_fp4
,
dtype
,
device
));
single_transformer_blocks
.
push_back
(
std
::
make_unique
<
FluxSingleTransformerBlock
>
(
3072
,
24
,
3072
,
4
,
use_fp4
,
dtype
,
device
));
registerChildren
(
*
single_transformer_blocks
.
back
(),
format
(
"single_transformer_blocks.{}"
,
i
));
if
(
offload
)
{
single_transformer_blocks
.
back
()
->
setLazyLoad
(
true
);
...
...
@@ -783,8 +798,7 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
}
}
Tensor
FluxModel
::
forward
(
Tensor
hidden_states
,
Tensor
FluxModel
::
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
...
...
@@ -805,31 +819,42 @@ Tensor FluxModel::forward(
Tensor
concat
;
auto
compute
=
[
&
](
int
layer
)
{
if
(
skip_first_layer
&&
size_t
(
layer
)
==
0
)
return
;
if
(
skip_first_layer
&&
size_t
(
layer
)
==
0
)
return
;
if
(
size_t
(
layer
)
<
transformer_blocks
.
size
())
{
auto
&
block
=
transformer_blocks
.
at
(
layer
);
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
block
->
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
block
->
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
if
(
controlnet_block_samples
.
valid
())
{
const
int
num_controlnet_block_samples
=
controlnet_block_samples
.
shape
[
0
];
int
interval_control
=
ceilDiv
(
transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_block_samples
));
int
interval_control
=
ceilDiv
(
transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_block_samples
));
int
block_index
=
layer
/
interval_control
;
// Xlabs ControlNet
// block_index = layer % num_controlnet_block_samples;
hidden_states
=
kernels
::
add
(
hidden_states
,
controlnet_block_samples
[
block_index
]);
}
if
(
residual_callback
&&
layer
%
2
==
0
)
{
Tensor
cpu_input
=
hidden_states
.
copy
(
Device
::
cpu
());
pybind11
::
gil_scoped_acquire
gil
;
Tensor
cpu_output
=
residual_callback
(
cpu_input
);
Tensor
residual
=
cpu_output
.
copy
(
Device
::
cuda
());
hidden_states
=
kernels
::
add
(
hidden_states
,
residual
);
}
}
else
{
if
(
size_t
(
layer
)
==
transformer_blocks
.
size
())
{
// txt first, same as diffusers
concat
=
Tensor
::
allocate
({
batch_size
,
txt_tokens
+
img_tokens
,
3072
},
dtype
,
device
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
txt_tokens
).
copy_
(
encoder_hidden_states
.
slice
(
0
,
i
,
i
+
1
));
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
hidden_states
.
slice
(
0
,
i
,
i
+
1
));
concat
.
slice
(
0
,
i
,
i
+
1
)
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
)
.
copy_
(
hidden_states
.
slice
(
0
,
i
,
i
+
1
));
}
hidden_states
=
concat
;
encoder_hidden_states
=
{};
}
auto
&
block
=
single_transformer_blocks
.
at
(
layer
-
transformer_blocks
.
size
());
...
...
@@ -837,7 +862,8 @@ Tensor FluxModel::forward(
if
(
controlnet_single_block_samples
.
valid
())
{
const
int
num_controlnet_single_block_samples
=
controlnet_single_block_samples
.
shape
[
0
];
int
interval_control
=
ceilDiv
(
single_transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_single_block_samples
));
int
interval_control
=
ceilDiv
(
single_transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_single_block_samples
));
int
block_index
=
(
layer
-
transformer_blocks
.
size
())
/
interval_control
;
// Xlabs ControlNet
// block_index = layer % num_controlnet_single_block_samples
...
...
@@ -846,6 +872,17 @@ Tensor FluxModel::forward(
slice
=
kernels
::
add
(
slice
,
controlnet_single_block_samples
[
block_index
]);
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
slice
);
}
size_t
local_layer_idx
=
layer
-
transformer_blocks
.
size
();
if
(
residual_callback
&&
local_layer_idx
%
4
==
0
)
{
Tensor
callback_input
=
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
);
Tensor
cpu_input
=
callback_input
.
copy
(
Device
::
cpu
());
pybind11
::
gil_scoped_acquire
gil
;
Tensor
cpu_output
=
residual_callback
(
cpu_input
);
Tensor
residual
=
cpu_output
.
copy
(
Device
::
cuda
());
auto
slice
=
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
);
slice
=
kernels
::
add
(
slice
,
residual
);
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
slice
);
}
}
};
auto
load
=
[
&
](
int
layer
)
{
...
...
@@ -873,8 +910,7 @@ Tensor FluxModel::forward(
return
hidden_states
;
}
std
::
tuple
<
Tensor
,
Tensor
>
FluxModel
::
forward_layer
(
size_t
layer
,
std
::
tuple
<
Tensor
,
Tensor
>
FluxModel
::
forward_layer
(
size_t
layer
,
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
...
...
@@ -883,21 +919,13 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(
Tensor
controlnet_block_samples
,
Tensor
controlnet_single_block_samples
)
{
if
(
layer
<
transformer_blocks
.
size
()){
if
(
layer
<
transformer_blocks
.
size
())
{
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
transformer_blocks
.
at
(
layer
)
->
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
}
else
{
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
transformer_blocks
.
at
(
layer
-
transformer_blocks
.
size
())
->
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
}
else
{
std
::
tie
(
hidden_states
,
encoder_hidden_states
)
=
transformer_blocks
.
at
(
layer
-
transformer_blocks
.
size
())
->
forward
(
hidden_states
,
encoder_hidden_states
,
temb
,
rotary_emb_img
,
rotary_emb_context
,
0.0
f
);
}
const
int
txt_tokens
=
encoder_hidden_states
.
shape
[
1
];
...
...
@@ -915,7 +943,8 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(
}
else
if
(
layer
>=
transformer_blocks
.
size
()
&&
controlnet_single_block_samples
.
valid
())
{
const
int
num_controlnet_single_block_samples
=
controlnet_single_block_samples
.
shape
[
0
];
int
interval_control
=
ceilDiv
(
single_transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_single_block_samples
));
int
interval_control
=
ceilDiv
(
single_transformer_blocks
.
size
(),
static_cast
<
size_t
>
(
num_controlnet_single_block_samples
));
int
block_index
=
(
layer
-
transformer_blocks
.
size
())
/
interval_control
;
// Xlabs ControlNet
// block_index = layer % num_controlnet_single_block_samples
...
...
@@ -925,7 +954,7 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(
hidden_states
.
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
slice
);
}
return
{
hidden_states
,
encoder_hidden_states
};
return
{
hidden_states
,
encoder_hidden_states
};
}
void
FluxModel
::
setAttentionImpl
(
AttentionImpl
impl
)
{
...
...
@@ -936,3 +965,6 @@ void FluxModel::setAttentionImpl(AttentionImpl impl) {
block
->
attnImpl
=
impl
;
}
}
// Install (or clear) the user-supplied residual callback.
// The callback is invoked from FluxModel::forward on selected layers with a
// CPU copy of the hidden states; passing an empty std::function disables it.
void FluxModel::set_residual_callback(std::function<Tensor(const Tensor &)> cb) {
    // Take ownership of the callable; std::move avoids copying captured state.
    residual_callback = std::move(cb);
}
src/FluxModel.h
View file @
37a27712
...
...
@@ -5,6 +5,10 @@
#include "Module.h"
#include "Linear.h"
#include "layernorm.h"
#include <pybind11/functional.h>
namespace
pybind11
{
class
function
;
}
enum
class
AttentionImpl
{
FlashAttention2
=
0
,
...
...
@@ -45,6 +49,7 @@ public:
Tensor
scale_mlp
;
Tensor
gate_mlp
;
};
public:
AdaLayerNormZero
(
int
dim
,
bool
pre_only
,
Tensor
::
ScalarType
dtype
,
Device
device
);
Output
forward
(
Tensor
x
,
Tensor
emb
);
...
...
@@ -83,7 +88,13 @@ public:
static
constexpr
bool
USE_4BIT
=
true
;
using
GEMM
=
std
::
conditional_t
<
USE_4BIT
,
GEMM_W4A4
,
GEMM_W8A8
>
;
FluxSingleTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
int
mlp_ratio
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
);
FluxSingleTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
int
mlp_ratio
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
);
public:
...
...
@@ -109,19 +120,30 @@ public:
static
constexpr
bool
USE_4BIT
=
true
;
using
GEMM
=
std
::
conditional_t
<
USE_4BIT
,
GEMM_W4A4
,
GEMM_W8A8
>
;
JointTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
bool
context_pre_only
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
);
std
::
tuple
<
Tensor
,
Tensor
>
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
,
Tensor
rotary_emb_context
,
float
sparsityRatio
);
JointTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
bool
context_pre_only
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
);
std
::
tuple
<
Tensor
,
Tensor
>
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
,
Tensor
rotary_emb_context
,
float
sparsityRatio
);
public:
const
int
dim
;
const
int
dim_head
;
const
int
num_heads
;
const
bool
context_pre_only
;
AdaLayerNormZero
norm1
;
AttentionImpl
attnImpl
=
AttentionImpl
::
FlashAttention2
;
private:
AdaLayerNormZero
norm1
;
AdaLayerNormZero
norm1_context
;
GEMM
qkv_proj
;
GEMM
qkv_proj_context
;
...
...
@@ -139,8 +161,7 @@ private:
class
FluxModel
:
public
Module
{
public:
FluxModel
(
bool
use_fp4
,
bool
offload
,
Tensor
::
ScalarType
dtype
,
Device
device
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
...
...
@@ -149,8 +170,7 @@ public:
Tensor
controlnet_block_samples
,
Tensor
controlnet_single_block_samples
,
bool
skip_first_layer
=
false
);
std
::
tuple
<
Tensor
,
Tensor
>
forward_layer
(
size_t
layer
,
std
::
tuple
<
Tensor
,
Tensor
>
forward_layer
(
size_t
layer
,
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
...
...
@@ -160,12 +180,16 @@ public:
Tensor
controlnet_single_block_samples
);
void
setAttentionImpl
(
AttentionImpl
impl
);
void
set_residual_callback
(
std
::
function
<
Tensor
(
const
Tensor
&
)
>
cb
);
public:
const
Tensor
::
ScalarType
dtype
;
std
::
vector
<
std
::
unique_ptr
<
JointTransformerBlock
>>
transformer_blocks
;
std
::
vector
<
std
::
unique_ptr
<
FluxSingleTransformerBlock
>>
single_transformer_blocks
;
std
::
function
<
Tensor
(
const
Tensor
&
)
>
residual_callback
;
private:
bool
offload
;
};
src/Linear.cpp
View file @
37a27712
...
...
@@ -9,16 +9,12 @@
using
namespace
nunchaku
;
// Full-precision (FP16/BF16) linear layer.
//
// Allocates an [out_features, in_features] weight matrix and, when requested,
// an [out_features] bias vector on the given device, then registers both as
// named parameters so the serialization machinery can find them. The weight
// is flagged LazyLoad so it can stay on disk until first use.
GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features) {
    weight = Tensor::allocate({out_features, in_features}, dtype, device);
    // An invalid (default-constructed) Tensor marks "no bias".
    bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
    registerParams(weight, "weight", ParamFlags::LazyLoad)(bias, "bias");
}
Tensor
GEMM_F16
::
forward
(
Tensor
x
)
{
...
...
@@ -26,9 +22,9 @@ Tensor GEMM_F16::forward(Tensor x) {
return
out
;
}
GEMV_AWQ
::
GEMV_AWQ
(
int
in_features
,
int
out_features
,
bool
use_bias
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
in_features
(
in_features
),
out_features
(
out_features
),
group_size
(
64
),
lora_rank
(
0
),
lora_scale
(
1.0
f
),
device
(
device
)
{
GEMV_AWQ
::
GEMV_AWQ
(
int
in_features
,
int
out_features
,
bool
use_bias
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
in_features
(
in_features
),
out_features
(
out_features
),
group_size
(
64
),
lora_rank
(
0
),
lora_scale
(
1.0
f
),
device
(
device
)
{
this
->
qweight
=
Tensor
::
allocate
({
out_features
/
4
,
ceilDiv
(
in_features
,
8
)
*
4
},
Tensor
::
INT32
,
device
);
this
->
wscales
=
Tensor
::
allocate
({
ceilDiv
(
in_features
,
group_size
),
out_features
},
dtype
,
device
);
this
->
wzeros
=
Tensor
::
allocate
({
ceilDiv
(
in_features
,
group_size
),
out_features
},
dtype
,
device
);
...
...
@@ -38,14 +34,8 @@ GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::Sca
this
->
lora_down
=
Tensor
::
allocate
({
lora_rank
,
in_features
},
dtype
,
device
,
true
);
this
->
lora_up
=
Tensor
::
allocate
({
out_features
,
lora_rank
},
dtype
,
device
,
true
);
registerParams
(
qweight
,
"qweight"
,
ParamFlags
::
LazyLoad
)
(
wscales
,
"wscales"
)
(
wzeros
,
"wzeros"
)
(
bias
,
"bias"
)
(
lora_down
,
"lora_down"
,
ParamFlags
::
Optional
)
(
lora_up
,
"lora_up"
,
ParamFlags
::
Optional
)
;
registerParams
(
qweight
,
"qweight"
,
ParamFlags
::
LazyLoad
)(
wscales
,
"wscales"
)(
wzeros
,
"wzeros"
)(
bias
,
"bias"
)(
lora_down
,
"lora_down"
,
ParamFlags
::
Optional
)(
lora_up
,
"lora_up"
,
ParamFlags
::
Optional
);
}
void
GEMV_AWQ
::
loadParam
(
std
::
string
key
,
Tensor
&
dst
,
Tensor
src
)
{
...
...
@@ -95,15 +85,12 @@ Tensor GEMV_AWQ::forward(Tensor x) {
return
out
;
}
#define NO_LORA_FUSION 0
GEMM_W4A4
::
GEMM_W4A4
(
int
in_features
,
int
out_features
,
bool
bias
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
in_features
(
in_features
),
out_features
(
out_features
),
in_features_pad
(
ceilDiv
(
in_features
,
128
)
*
128
),
out_features_pad
(
ceilDiv
(
out_features
,
128
)
*
128
),
use_fp4
(
use_fp4
),
lora_rank
(
0
),
dtype
(
dtype
),
device
(
device
)
{
GEMM_W4A4
::
GEMM_W4A4
(
int
in_features
,
int
out_features
,
bool
bias
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
in_features
(
in_features
),
out_features
(
out_features
),
in_features_pad
(
ceilDiv
(
in_features
,
128
)
*
128
),
out_features_pad
(
ceilDiv
(
out_features
,
128
)
*
128
),
use_fp4
(
use_fp4
),
lora_rank
(
0
),
dtype
(
dtype
),
device
(
device
)
{
this
->
qweight
=
Tensor
::
allocate
({
out_features_pad
,
in_features_pad
/
2
},
Tensor
::
INT8
,
device
,
true
);
if
(
use_fp4
)
{
this
->
wscales
=
Tensor
::
allocate
({
in_features_pad
/
16
,
out_features_pad
},
Tensor
::
FP8_E4M3
,
device
,
true
);
...
...
@@ -125,16 +112,9 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
this
->
wcscales
=
Tensor
::
allocate
({
0
},
dtype
,
device
,
true
);
registerParams
(
qweight
,
"qweight"
,
ParamFlags
::
LazyLoad
)
(
wscales
,
"wscales"
)
(
this
->
bias
,
"bias"
)
(
lora_down
,
"lora_down"
,
ParamFlags
::
Optional
)
(
lora_up
,
"lora_up"
,
ParamFlags
::
Optional
)
(
smooth
,
"smooth"
)
(
wtscale
,
"wtscale"
,
ParamFlags
::
Optional
)
(
wcscales
,
"wcscales"
,
ParamFlags
::
Optional
)
;
registerParams
(
qweight
,
"qweight"
,
ParamFlags
::
LazyLoad
)(
wscales
,
"wscales"
)(
this
->
bias
,
"bias"
)(
lora_down
,
"lora_down"
,
ParamFlags
::
Optional
)(
lora_up
,
"lora_up"
,
ParamFlags
::
Optional
)(
smooth
,
"smooth"
)(
wtscale
,
"wtscale"
,
ParamFlags
::
Optional
)(
wcscales
,
"wcscales"
,
ParamFlags
::
Optional
);
#if NO_LORA_FUSION
checkCUBLAS
(
cublasCreate
(
&
handle
));
...
...
@@ -181,11 +161,21 @@ Tensor GEMM_W4A4::forward_silu(Tensor x) {
return
std
::
get
<
Tensor
>
(
this
->
forward
(
x
,
FuseOptions
::
SILU
,
nullptr
));
}
// Quantize-then-multiply convenience overload.
//
// Quantizes the input activation (no GLU fusion) and forwards it through the
// quantized GEMM path. Returns either a plain Tensor or a QuantizedActivation,
// depending on the requested fusion with the next GEMM.
std::variant<Tensor, GEMM_W4A4::QuantizedActivation>
GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
    auto qact = quantize(x, /*fuse_glu=*/false);
    return forward_quant(std::move(qact), fuse, nextGEMM);
}
void
GEMM_W4A4
::
forward
(
Tensor
x
,
Tensor
out
,
Tensor
pool
,
Tensor
norm_q
,
Tensor
norm_k
,
Tensor
rotary_emb
,
Tensor
out_q
,
Tensor
out_k
,
Tensor
out_v
,
int
numTokens
)
{
void
GEMM_W4A4
::
forward
(
Tensor
x
,
Tensor
out
,
Tensor
pool
,
Tensor
norm_q
,
Tensor
norm_k
,
Tensor
rotary_emb
,
Tensor
out_q
,
Tensor
out_k
,
Tensor
out_v
,
int
numTokens
)
{
QuantizedActivation
qact
=
quantize
(
x
,
false
);
#if !NO_LORA_FUSION
...
...
@@ -198,17 +188,59 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
debug("gemm.nolora.out", out);
#endif
kernels
::
gemm_w4a4
(
qact
.
act
,
qweight
,
out
,
{},
qact
.
ascales
,
wscales
,
{},
pool
,
qact
.
lora_act
,
this
->
lora_up
,
{},
{},
norm_q
,
norm_k
,
rotary_emb
,
this
->
bias
,
{},
{},
{},
qact
.
is_unsigned
,
this
->
lora_scales
,
false
,
use_fp4
,
*
this
->
wtscale
.
data_ptr
<
float
>
(),
wcscales
.
numel
()
>
0
?
wcscales
:
Tensor
{},
out_q
,
out_k
,
out_v
,
numTokens
);
kernels
::
gemm_w4a4
(
qact
.
act
,
qweight
,
out
,
{},
qact
.
ascales
,
wscales
,
{},
pool
,
qact
.
lora_act
,
this
->
lora_up
,
{},
{},
norm_q
,
norm_k
,
rotary_emb
,
this
->
bias
,
{},
{},
{},
qact
.
is_unsigned
,
this
->
lora_scales
,
false
,
use_fp4
,
*
this
->
wtscale
.
data_ptr
<
float
>
(),
wcscales
.
numel
()
>
0
?
wcscales
:
Tensor
{},
out_q
,
out_k
,
out_v
,
numTokens
);
debug
(
"gemm.out"
,
out
);
#else
const
int
M
=
(
int
)
qact
.
act
.
numel
()
/
qact
.
act
.
shape
[
-
1
];
kernels
::
gemm_w4a4
(
qact
.
act
,
qweight
,
out
,
{},
qact
.
ascales
,
wscales
,
{},
pool
,
{},
{},
{},
{},
norm_q
,
norm_k
,
rotary_emb
,
this
->
bias
,
{},
qact
.
is_unsigned
,
this
->
lora_scales
);
kernels
::
gemm_w4a4
(
qact
.
act
,
qweight
,
out
,
{},
qact
.
ascales
,
wscales
,
{},
pool
,
{},
{},
{},
{},
norm_q
,
norm_k
,
rotary_emb
,
this
->
bias
,
{},
qact
.
is_unsigned
,
this
->
lora_scales
);
nvtxRangePushA
(
"LoraUp"
);
...
...
@@ -216,10 +248,12 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
static
const
half
zero
=
0.0
;
// lora_up: [M, R] * [OC, R] => [M, OC]
// cublas view: [OC, R] * [M, R]^T
checkCUBLAS
(
cublasHgemm
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
this
->
out_features
,
M
,
this
->
lora_rank
,
checkCUBLAS
(
cublasHgemm
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
this
->
out_features
,
M
,
this
->
lora_rank
,
&
one
,
this
->
lora_up
.
data_ptr
<
half
>
(),
this
->
lora_rank
,
...
...
@@ -233,7 +267,8 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
#endif
}
std
::
variant
<
Tensor
,
GEMM_W4A4
::
QuantizedActivation
>
GEMM_W4A4
::
forward_quant
(
QuantizedActivation
qact
,
FuseOptions
fuse
,
GEMM_W4A4
*
nextGEMM
)
{
std
::
variant
<
Tensor
,
GEMM_W4A4
::
QuantizedActivation
>
GEMM_W4A4
::
forward_quant
(
QuantizedActivation
qact
,
FuseOptions
fuse
,
GEMM_W4A4
*
nextGEMM
)
{
Tensor
out
;
QuantizedActivation
qout
;
...
...
@@ -280,11 +315,35 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
}
#endif
kernels
::
gemm_w4a4
(
qact
.
act
,
qweight
,
out
,
qout
.
act
,
qact
.
ascales
,
wscales
,
qout
.
ascales
,
{},
qact
.
lora_act
,
this
->
lora_up
,
next_lora
,
qout
.
lora_act
,
{},
{},
{},
this
->
bias
,
next_smooth
,
{},
{},
qact
.
is_unsigned
,
this
->
lora_scales
,
fuse
==
FuseOptions
::
SILU
,
use_fp4
,
*
this
->
wtscale
.
data_ptr
<
float
>
(),
wcscales
.
numel
()
>
0
?
wcscales
:
Tensor
{},
{},
{},
{},
0
);
kernels
::
gemm_w4a4
(
qact
.
act
,
qweight
,
out
,
qout
.
act
,
qact
.
ascales
,
wscales
,
qout
.
ascales
,
{},
qact
.
lora_act
,
this
->
lora_up
,
next_lora
,
qout
.
lora_act
,
{},
{},
{},
this
->
bias
,
next_smooth
,
{},
{},
qact
.
is_unsigned
,
this
->
lora_scales
,
fuse
==
FuseOptions
::
SILU
,
use_fp4
,
*
this
->
wtscale
.
data_ptr
<
float
>
(),
wcscales
.
numel
()
>
0
?
wcscales
:
Tensor
{},
{},
{},
{},
0
);
if
(
fuse
==
FuseOptions
::
EMPTY
||
fuse
==
FuseOptions
::
SILU
)
{
debug
(
"gemm.out"
,
out
);
...
...
@@ -294,7 +353,6 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
debug
(
"gemm.lora_act_out"
,
qout
.
lora_act
);
}
#else
if
(
!
out
.
valid
())
{
auto
shape
=
TensorShape
(
qact
.
act
.
shape
.
dataExtent
);
...
...
@@ -302,7 +360,25 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
out
=
Tensor
::
allocate
(
shape
,
Tensor
::
FP16
,
qweight
.
device
());
}
kernels
::
gemm_w4a4
(
qact
.
act
,
qweight
,
out
,
qout
.
act
,
qact
.
ascales
,
wscales
,
qout
.
ascales
,
{},
{},
{},
{},
{},
{},
{},
{},
this
->
bias
,
next_smooth
,
qact
.
is_unsigned
,
this
->
lora_scales
);
kernels
::
gemm_w4a4
(
qact
.
act
,
qweight
,
out
,
qout
.
act
,
qact
.
ascales
,
wscales
,
qout
.
ascales
,
{},
{},
{},
{},
{},
{},
{},
{},
this
->
bias
,
next_smooth
,
qact
.
is_unsigned
,
this
->
lora_scales
);
nvtxRangePushA
(
"LoraUp"
);
...
...
@@ -312,10 +388,12 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
// lora_up: [M, R] * [OC, R]^T => [M, OC]
// cublas view: [R, OC]^T * [R, M] => [OC, M]
// lora_up layout wrong?
checkCUBLAS
(
cublasHgemm
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
this
->
out_features
,
M
,
this
->
lora_rank
,
checkCUBLAS
(
cublasHgemm
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
this
->
out_features
,
M
,
this
->
lora_rank
,
&
one
,
this
->
lora_up
.
data_ptr
<
half
>
(),
this
->
lora_rank
,
...
...
@@ -332,10 +410,12 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
// IC is for next lora (OC of this layer)
// lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M] => [R, M]
checkCUBLAS
(
cublasHgemm
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
this
->
lora_rank
,
M
,
this
->
out_features
,
checkCUBLAS
(
cublasHgemm
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
this
->
lora_rank
,
M
,
this
->
out_features
,
&
one
,
next_lora
.
data_ptr
<
half
>
(),
this
->
lora_rank
,
...
...
@@ -383,7 +463,8 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
debug
(
"quantize.x"
,
x
);
debug
(
"quantize.smooth"
,
this
->
smooth
);
kernels
::
quantize_w4a4_act_fuse_lora
(
x
,
qact
.
act
,
qact
.
ascales
,
this
->
lora_down
,
qact
.
lora_act
,
this
->
smooth
,
fuse_glu
,
use_fp4
);
kernels
::
quantize_w4a4_act_fuse_lora
(
x
,
qact
.
act
,
qact
.
ascales
,
this
->
lora_down
,
qact
.
lora_act
,
this
->
smooth
,
fuse_glu
,
use_fp4
);
debug
(
"quantize.qact"
,
qact
.
act
);
debug
(
"quantize.ascales"
,
qact
.
ascales
);
...
...
@@ -396,10 +477,12 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
// lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M]
checkCUBLAS
(
cublasHgemm
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
this
->
lora_rank
,
M
,
this
->
in_features
,
checkCUBLAS
(
cublasHgemm
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
this
->
lora_rank
,
M
,
this
->
in_features
,
&
one
,
lora_down
.
data_ptr
<
half
>
(),
this
->
lora_rank
,
...
...
@@ -418,18 +501,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
return
qact
;
}
// W8A8 (8-bit weight, 8-bit activation) linear layer.
//
// Allocates the INT8 quantized weight, a per-output-channel scale vector, and
// an optional bias, then registers them as loadable parameters. The quantized
// weight is LazyLoad so large checkpoints can be streamed in on demand.
GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), dtype(dtype) {
    qweight = Tensor::allocate({out_features, in_features}, Tensor::INT8, device);
    // One dequantization scale per output channel, kept in the compute dtype.
    wscales = Tensor::allocate({out_features}, dtype, device);
    this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias");
}
GEMM_W8A8
::
QuantizedActivation
GEMM_W8A8
::
quantize
(
Tensor
x
,
bool
fuse_glu
)
{
...
...
@@ -461,16 +539,11 @@ Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
return
out
;
}
// Depthwise 3x3 convolution (one filter per input channel).
//
// Weight layout is [channels, 3, 3, 1]; an optional per-channel bias is
// allocated when use_bias is set. Both are registered as named parameters.
DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features) {
    weight = Tensor::allocate({in_features, 3, 3, 1}, dtype, device);
    // Default-constructed Tensor acts as the "no bias" sentinel.
    bias = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{};
    registerParams(this->weight, "weight")(this->bias, "bias");
}
Tensor
DWCONV
::
forward
(
Tensor
x
)
{
...
...
src/Linear.h
View file @
37a27712
...
...
@@ -37,6 +37,7 @@ public:
float
lora_scale
;
const
Device
device
;
public:
Tensor
qweight
;
Tensor
wscales
;
...
...
@@ -69,12 +70,18 @@ public:
Tensor
forward
(
Tensor
x
);
Tensor
forward_silu
(
Tensor
x
);
std
::
variant
<
Tensor
,
QuantizedActivation
>
forward
(
Tensor
x
,
FuseOptions
fuse
,
GEMM_W4A4
*
nextGEMM
=
nullptr
);
void
forward
(
Tensor
x
,
Tensor
out
,
Tensor
pool
=
{},
Tensor
norm_q
=
{},
Tensor
norm_k
=
{},
Tensor
rotary_emb
=
{},
Tensor
out_q
=
{},
Tensor
out_k
=
{},
Tensor
out_v
=
{},
int
numTokens
=
0
);
std
::
variant
<
Tensor
,
QuantizedActivation
>
forward_quant
(
QuantizedActivation
qact
,
FuseOptions
fuse
,
GEMM_W4A4
*
nextGEMM
=
nullptr
);
void
forward
(
Tensor
x
,
Tensor
out
,
Tensor
pool
=
{},
Tensor
norm_q
=
{},
Tensor
norm_k
=
{},
Tensor
rotary_emb
=
{},
Tensor
out_q
=
{},
Tensor
out_k
=
{},
Tensor
out_v
=
{},
int
numTokens
=
0
);
std
::
variant
<
Tensor
,
QuantizedActivation
>
forward_quant
(
QuantizedActivation
qact
,
FuseOptions
fuse
,
GEMM_W4A4
*
nextGEMM
=
nullptr
);
Tensor
forward_quant
(
QuantizedActivation
qact
);
public:
...
...
@@ -118,13 +125,16 @@ public:
Tensor
act
;
Tensor
ascales
;
};
public:
GEMM_W8A8
(
int
in_features
,
int
out_features
,
bool
bias
,
Tensor
::
ScalarType
dtype
,
Device
device
);
public:
QuantizedActivation
quantize
(
Tensor
x
,
bool
fuse_glu
);
Tensor
forward_quant
(
QuantizedActivation
qact
);
Tensor
forward
(
Tensor
x
)
{
return
forward_quant
(
quantize
(
x
,
false
));
}
Tensor
forward
(
Tensor
x
)
{
return
forward_quant
(
quantize
(
x
,
false
));
}
public:
const
int
in_features
;
...
...
src/Module.cpp
View file @
37a27712
src/Module.h
View file @
37a27712
...
...
@@ -108,7 +108,8 @@ public:
dst
=
Tensor
::
allocate
(
lazy
.
shape
,
lazy
.
type
,
lazy
.
device
);
if
(
!
src
.
valid
()
&&
!
checkFlag
(
param
.
flags
,
ParamFlags
::
Optional
))
{
throw
std
::
runtime_error
(
spdlog
::
fmt_lib
::
format
(
"Lazy load: Tensor {} has no src"
,
m
->
getPrefix
()
+
key
));
throw
std
::
runtime_error
(
spdlog
::
fmt_lib
::
format
(
"Lazy load: Tensor {} has no src"
,
m
->
getPrefix
()
+
key
));
}
m
->
loadParam
(
key
,
dst
,
src
);
}
...
...
@@ -127,14 +128,10 @@ public:
});
}
void
setLazyLoad
(
bool
val
)
{
traverse
([
val
](
Module
*
m
)
{
m
->
enabledLazyLoad
=
val
;
});
traverse
([
val
](
Module
*
m
)
{
m
->
enabledLazyLoad
=
val
;
});
}
void
setAutoCastFP16
(
bool
val
)
{
traverse
([
val
](
Module
*
m
)
{
m
->
enabledAutoCastFP16
=
val
;
});
traverse
([
val
](
Module
*
m
)
{
m
->
enabledAutoCastFP16
=
val
;
});
}
protected:
...
...
@@ -143,7 +140,8 @@ protected:
Tensor
::
FP16
,
Tensor
::
BF16
,
};
if
(
enabledAutoCastFP16
&&
dst
.
scalar_type
()
!=
src
.
scalar_type
()
&&
whitelist
.
contains
(
dst
.
scalar_type
())
&&
whitelist
.
contains
(
src
.
scalar_type
()))
{
if
(
enabledAutoCastFP16
&&
dst
.
scalar_type
()
!=
src
.
scalar_type
()
&&
whitelist
.
contains
(
dst
.
scalar_type
())
&&
whitelist
.
contains
(
src
.
scalar_type
()))
{
copyWithCast
(
dst
,
src
);
}
else
{
dst
.
copy_
(
src
);
...
...
@@ -227,8 +225,7 @@ struct LayerOffloadHelper {
std
::
unique_ptr
<
CUDAEventWrapper
>
eventLoadDone
;
LayerOffloadHelper
(
bool
offload
,
int
numLayers
,
func_t
funcCompute
,
func_t
funcLoad
,
func_t
funcUnload
)
:
offload
(
offload
),
numLayers
(
numLayers
),
funcCompute
(
funcCompute
),
funcLoad
(
funcLoad
),
funcUnload
(
funcUnload
)
{
:
offload
(
offload
),
numLayers
(
numLayers
),
funcCompute
(
funcCompute
),
funcLoad
(
funcLoad
),
funcUnload
(
funcUnload
)
{
if
(
offload
)
{
streamCompute
=
std
::
make_unique
<
CUDAStreamWrapper
>
();
streamLoad
=
std
::
make_unique
<
CUDAStreamWrapper
>
();
...
...
@@ -305,11 +302,11 @@ private:
}
}
#ifdef _WIN32
#ifdef _WIN32
return
true
;
#else
#else
return
false
;
#endif
#endif
}
void
workaroundFlush
()
{
if
(
!
needWorkaround
)
{
...
...
src/SanaModel.cpp
View file @
37a27712
...
...
@@ -10,18 +10,11 @@
using
spdlog
::
fmt_lib
::
format
;
using
namespace
nunchaku
;
SanaLinearAttention
::
SanaLinearAttention
(
int
dim
,
bool
bias
,
bool
pag
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
dim_pad
(
ceilDiv
(
dim
,
128
)
*
128
),
qkv_proj
(
dim
,
dim_pad
*
3
,
bias
,
use_fp4
,
dtype
,
device
),
out_proj
(
dim_pad
,
dim
,
bias
,
use_fp4
,
dtype
,
device
),
pag_to_v
(
std
::
nullopt
)
{
registerChildren
(
qkv_proj
,
"qkv_proj"
)
(
out_proj
,
"out_proj"
)
;
SanaLinearAttention
::
SanaLinearAttention
(
int
dim
,
bool
bias
,
bool
pag
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
dim
(
dim
),
dim_pad
(
ceilDiv
(
dim
,
128
)
*
128
),
qkv_proj
(
dim
,
dim_pad
*
3
,
bias
,
use_fp4
,
dtype
,
device
),
out_proj
(
dim_pad
,
dim
,
bias
,
use_fp4
,
dtype
,
device
),
pag_to_v
(
std
::
nullopt
)
{
registerChildren
(
qkv_proj
,
"qkv_proj"
)(
out_proj
,
"out_proj"
);
if
(
pag
)
{
pag_to_v
.
emplace
(
dim
,
dim_pad
,
bias
,
use_fp4
,
dtype
,
device
);
...
...
@@ -57,21 +50,35 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
Tensor
q
=
Tensor
::
allocate
({
batch_size
,
num_tokens_pad
,
dim_pad
},
x
.
dtype
(),
x
.
device
());
Tensor
vk
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
HEAD_DIM
+
1
,
HEAD_DIM
},
Tensor
::
FP32
,
x
.
device
());
kernels
::
gemm_w4a4
(
qact
.
act
,
kernels
::
gemm_w4a4
(
qact
.
act
,
qkv_proj
.
qweight
,
{},
{},
qact
.
ascales
,
qkv_proj
.
wscales
,
{},
{},
qact
.
lora_act
,
qkv_proj
.
lora_up
,
{},
{},
{},
{},
{},
qkv_proj
.
bias
,
{},
vk
,
q
,
qact
.
is_unsigned
,
qkv_proj
.
lora_scales
,
false
,
{},
{},
qact
.
lora_act
,
qkv_proj
.
lora_up
,
{},
{},
{},
{},
{},
qkv_proj
.
bias
,
{},
vk
,
q
,
qact
.
is_unsigned
,
qkv_proj
.
lora_scales
,
false
,
qkv_proj
.
use_fp4
,
*
qkv_proj
.
wtscale
.
data_ptr
<
float
>
(),
qkv_proj
.
wcscales
.
numel
()
>
0
?
qkv_proj
.
wcscales
:
Tensor
{},
{},
{},
{},
0
);
{},
{},
{},
0
);
debug
(
"vk"
,
vk
);
debug
(
"q"
,
q
);
...
...
@@ -88,7 +95,6 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
q
=
q_unpad
;
}
// kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales);
// return out_proj.forward(q);
...
...
@@ -129,17 +135,13 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
return
out
;
}
// Multi-head cross-attention between image tokens (queries) and text
// conditioning tokens (keys/values).
//
// The hidden size is num_heads * head_dim throughout. The query and output
// projections use the (optionally FP4) quantized GEMM; the fused KV projection
// produces both K and V (hence the *2 output width). All three submodules are
// registered as children for parameter loading.
MultiHeadCrossAttention::MultiHeadCrossAttention(
    int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : num_heads(num_heads), head_dim(head_dim),
      q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
      kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
      out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device) {
    registerChildren(q_linear, "q_linear")(kv_linear, "kv_linear")(out_proj, "out_proj");
}
Tensor
MultiHeadCrossAttention
::
forward
(
Tensor
x
,
Tensor
cond
,
Tensor
cu_seqlens_img
,
Tensor
cu_seqlens_txt
)
{
...
...
@@ -161,16 +163,22 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
Tensor
k
=
kv
.
slice
(
1
,
0
,
num_heads
);
Tensor
v
=
kv
.
slice
(
1
,
num_heads
,
num_heads
*
2
);
Tensor
attn_output
=
mha_varlen_fwd
(
q
,
k
,
v
,
cu_seqlens_img
,
cu_seqlens_txt
,
num_tokens_img
,
num_tokens_txt
,
Tensor
attn_output
=
mha_varlen_fwd
(
q
,
k
,
v
,
cu_seqlens_img
,
cu_seqlens_txt
,
num_tokens_img
,
num_tokens_txt
,
0.0
f
,
pow
(
q
.
shape
[
-
1
],
(
-
0.5
)),
false
,
false
,
-
1
,
-
1
,
false
).
front
().
view
({
batch_size
,
num_tokens_img
,
num_heads
*
head_dim
});
false
,
false
,
-
1
,
-
1
,
false
)
.
front
()
.
view
({
batch_size
,
num_tokens_img
,
num_heads
*
head_dim
});
// Tensor attn_output = mha_fwd(q, k, v,
// 0.0f,
...
...
@@ -181,17 +189,13 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
return
out_proj
.
forward
(
attn_output
);
}
// Gated MBConv feed-forward block: a 1x1 "inverted" expansion conv producing
// 2*hidden_features channels (value + gate halves), a depthwise conv over the
// expanded features, and a 1x1 pointwise projection back down to in_features.
SanaGLUMBConv::SanaGLUMBConv(
    int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), hidden_features(hidden_features),
      // Expansion: in_features -> 2*hidden_features, with bias (true), optional FP4.
      inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
      // Depthwise conv operates on the doubled (pre-gating) channel count.
      depth_conv(hidden_features * 2, true, dtype, device),
      // Projection: hidden_features -> in_features, no bias (false), optional FP4.
      point_conv(hidden_features, in_features, false, use_fp4, dtype, device) {
    // Register submodules by name for weight loading/serialization.
    registerChildren(inverted_conv, "inverted_conv")(depth_conv, "depth_conv")(point_conv, "point_conv");
}
Tensor
SanaGLUMBConv
::
forward
(
Tensor
x
,
int
H
,
int
W
)
{
...
...
@@ -208,28 +212,34 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
return
point_conv
.
forward_quant
(
qact
);
}
// One Sana transformer block: linear self-attention, multi-head cross-attention
// against the text condition, a GLU-MBConv feed-forward, and two norms, plus a
// learned {6, hidden_size} modulation table (AdaLN-style shift/scale/gate rows —
// presumably 3 pairs for the attn and ff branches; confirm against forward()).
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size,
                                                       int intermediate_size,
                                                       int num_cross_attention_heads,
                                                       bool pag,
                                                       bool use_fp4,
                                                       Tensor::ScalarType dtype,
                                                       Device device)
    : hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
      // Linear self-attention; `pag` enables the perturbed-attention-guidance path.
      attn(hidden_size, false, pag, use_fp4, dtype, device),
      // Cross-attention head_dim is derived from the model width.
      cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
      ff(hidden_size, intermediate_size, use_fp4, dtype, device),
      // Norms use eps=1e-6 and no elementwise affine (false).
      norm1(hidden_size, 1e-6, false, dtype, device),
      norm2(hidden_size, 1e-6, false, dtype, device) {
    // Allocated here; the actual values are filled in by checkpoint loading.
    this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device);
    registerChildren(attn, "attn")(cross_attn, "cross_attn")(ff, "ff");
    registerParams(this->scale_shift_table, "scale_shift_table");
}
Tensor
SanaLinearTransformerBlock
::
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
timestep
,
Tensor
cu_seqlens_img
,
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
)
{
Tensor
SanaLinearTransformerBlock
::
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
timestep
,
Tensor
cu_seqlens_img
,
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
)
{
nvtxRangePushA
(
"SanaLinearTransformerBlock"
);
...
...
@@ -311,9 +321,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
return
hidden_states
;
}
SanaModel
::
SanaModel
(
SanaConfig
config
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
config
(
config
)
{
SanaModel
::
SanaModel
(
SanaConfig
config
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
config
(
config
)
{
const
int
inner_dim
=
config
.
num_attention_heads
*
config
.
attention_head_dim
;
for
(
int
i
=
0
;
i
<
config
.
num_layers
;
i
++
)
{
transformer_blocks
.
push_back
(
std
::
make_unique
<
SanaLinearTransformerBlock
>
(
...
...
@@ -322,20 +330,34 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
config
.
num_cross_attention_heads
,
std
::
find
(
config
.
pag_layers
.
begin
(),
config
.
pag_layers
.
end
(),
i
)
!=
config
.
pag_layers
.
end
(),
config
.
use_fp4
,
dtype
,
device
));
dtype
,
device
));
registerChildren
(
*
transformer_blocks
.
back
(),
format
(
"transformer_blocks.{}"
,
i
));
}
}
// Run the full transformer stack; each block consumes and produces hidden_states.
// `skip_first_layer` starts the loop at layer 1 (used by first-block caching).
// Per-layer PAG is enabled only when `pag` is set AND the layer index appears in
// config.pag_layers.
Tensor SanaModel::forward(Tensor hidden_states,
                          Tensor encoder_hidden_states,
                          Tensor timestep,
                          Tensor cu_seqlens_img,
                          Tensor cu_seqlens_txt,
                          int H,
                          int W,
                          bool pag,
                          bool cfg,
                          bool skip_first_layer) {
    const int firstLayer = skip_first_layer ? 1 : 0;
    for (int layer = firstLayer; layer < config.num_layers; layer++) {
        // Perturbed-attention guidance applies only to layers listed in the config.
        const bool isPagLayer =
            std::find(config.pag_layers.begin(), config.pag_layers.end(), layer) != config.pag_layers.end();
        hidden_states = transformer_blocks[layer]->forward(hidden_states,
                                                           encoder_hidden_states,
                                                           timestep,
                                                           cu_seqlens_img,
                                                           cu_seqlens_txt,
                                                           H,
                                                           W,
                                                           pag && isPagLayer,
                                                           cfg);
    }
    return hidden_states;
}
src/SanaModel.h
View file @
37a27712
...
...
@@ -57,9 +57,23 @@ private:
class
SanaLinearTransformerBlock
:
public
Module
{
public:
SanaLinearTransformerBlock
(
int
hidden_size
,
int
intermediate_size
,
int
num_cross_attention_heads
,
bool
pag
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
timestep
,
Tensor
cu_seqlens_img
,
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
);
SanaLinearTransformerBlock
(
int
hidden_size
,
int
intermediate_size
,
int
num_cross_attention_heads
,
bool
pag
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
timestep
,
Tensor
cu_seqlens_img
,
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
);
public:
const
int
hidden_size
;
...
...
@@ -89,7 +103,16 @@ struct SanaConfig {
class
SanaModel
:
public
Module
{
public:
SanaModel
(
SanaConfig
config
,
Tensor
::
ScalarType
dtype
,
Device
device
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
timestep
,
Tensor
cu_seqlens_img
,
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
,
bool
skip_first_layer
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
timestep
,
Tensor
cu_seqlens_img
,
Tensor
cu_seqlens_txt
,
int
H
,
int
W
,
bool
pag
,
bool
cfg
,
bool
skip_first_layer
);
public:
const
SanaConfig
config
;
...
...
src/Serialization.cpp
View file @
37a27712
...
...
@@ -3,7 +3,6 @@
#include <nlohmann/json.hpp>
#include <mio/mmap.hpp>
using
json
=
nlohmann
::
json
;
using
spdlog
::
fmt_lib
::
format
;
...
...
@@ -121,13 +120,16 @@ SafeTensors::SafeTensors(const std::string &filename) {
auto
methodPrivate
=
[
&
]()
{
this
->
mapped
=
std
::
make_unique
<
MMapImplPrivate
>
(
filename
);
checkCUDA
(
cudaHostRegister
(
const_cast
<
char
*>
(
this
->
mapped
->
data
()),
this
->
mapped
->
size
(),
cudaHostRegisterPortable
));
checkCUDA
(
cudaHostRegister
(
const_cast
<
char
*>
(
this
->
mapped
->
data
()),
this
->
mapped
->
size
(),
cudaHostRegisterPortable
));
this
->
hostRegistered
=
true
;
this
->
memoryPinned
=
true
;
};
auto
methodMio
=
[
&
]()
{
this
->
mapped
=
std
::
make_unique
<
MMapImplMio
>
(
filename
);
checkCUDA
(
cudaHostRegister
(
const_cast
<
char
*>
(
this
->
mapped
->
data
()),
this
->
mapped
->
size
(),
cudaHostRegisterPortable
|
cudaHostRegisterReadOnly
));
checkCUDA
(
cudaHostRegister
(
const_cast
<
char
*>
(
this
->
mapped
->
data
()),
this
->
mapped
->
size
(),
cudaHostRegisterPortable
|
cudaHostRegisterReadOnly
));
this
->
hostRegistered
=
true
;
this
->
memoryPinned
=
true
;
};
...
...
@@ -135,15 +137,13 @@ SafeTensors::SafeTensors(const std::string &filename) {
this
->
mapped
=
std
::
make_unique
<
MMapImplRead
>
(
filename
,
true
);
this
->
memoryPinned
=
true
;
};
auto
methodReadNopin
=
[
&
]()
{
this
->
mapped
=
std
::
make_unique
<
MMapImplRead
>
(
filename
,
false
);
};
auto
methodReadNopin
=
[
&
]()
{
this
->
mapped
=
std
::
make_unique
<
MMapImplRead
>
(
filename
,
false
);
};
const
std
::
map
<
std
::
string
,
std
::
function
<
void
()
>>
methods
=
{
{
"PRIVATE"
,
methodPrivate
},
{
"MIO"
,
methodMio
},
{
"READ"
,
methodRead
},
{
"READNOPIN"
,
methodReadNopin
},
{
"PRIVATE"
,
methodPrivate
},
{
"MIO"
,
methodMio
},
{
"READ"
,
methodRead
},
{
"READNOPIN"
,
methodReadNopin
},
};
auto
tryMethod
=
[
&
](
std
::
string
name
)
{
...
...
@@ -168,7 +168,6 @@ SafeTensors::SafeTensors(const std::string &filename) {
#else
tryMethod
(
"MIO"
)
||
tryMethod
(
"READ"
)
||
tryMethod
(
"READNOPIN"
);
#endif
}
if
(
!
this
->
mapped
)
{
...
...
@@ -192,19 +191,20 @@ SafeTensors::~SafeTensors() {
void
SafeTensors
::
parseHeader
()
{
static
const
std
::
unordered_map
<
std
::
string
,
Tensor
::
ScalarType
>
mapDType
=
{
{
"BF16"
,
Tensor
::
BF16
},
{
"F16"
,
Tensor
::
FP16
},
{
"F32"
,
Tensor
::
FP32
},
{
"I8"
,
Tensor
::
INT8
},
{
"I32"
,
Tensor
::
INT32
},
{
"I64"
,
Tensor
::
INT64
},
{
"F8_E4M3"
,
Tensor
::
FP8_E4M3
},
{
"F8_E5M2"
,
Tensor
::
FP8_E5M2
},
{
"BF16"
,
Tensor
::
BF16
},
{
"F16"
,
Tensor
::
FP16
},
{
"F32"
,
Tensor
::
FP32
},
{
"I8"
,
Tensor
::
INT8
},
{
"I32"
,
Tensor
::
INT32
},
{
"I64"
,
Tensor
::
INT64
},
{
"F8_E4M3"
,
Tensor
::
FP8_E4M3
},
{
"F8_E5M2"
,
Tensor
::
FP8_E5M2
},
};
auto
check
=
[](
bool
cond
,
std
::
source_location
location
=
std
::
source_location
::
current
())
{
if
(
!
cond
)
{
throw
std
::
runtime_error
(
format
(
"Safetensors check failed at {}:{}"
,
location
.
file_name
(),
location
.
line
()));
throw
std
::
runtime_error
(
format
(
"Safetensors check failed at {}:{}"
,
location
.
file_name
(),
location
.
line
()));
}
};
...
...
@@ -222,7 +222,8 @@ void SafeTensors::parseHeader() {
continue
;
}
auto
dtype
=
mapDType
.
at
(
info
[
"dtype"
].
get
<
std
::
string
>
());;
auto
dtype
=
mapDType
.
at
(
info
[
"dtype"
].
get
<
std
::
string
>
());
;
auto
shape
=
info
[
"shape"
].
get
<
std
::
vector
<
int
>>
();
auto
data_offsets
=
info
[
"data_offsets"
].
get
<
std
::
vector
<
uint64_t
>>
();
...
...
@@ -258,7 +259,8 @@ Tensor SafeTensors::getTensor(const std::string &key) {
std
::
shared_ptr
<
BufferMMap
>
buffer
=
info
.
buffer
.
lock
();
if
(
!
buffer
)
{
buffer
=
std
::
make_shared
<
BufferMMap
>
(
const_cast
<
char
*>
(
this
->
mapped
->
data
()
+
info
.
offset
),
info
.
length
,
shared_from_this
());
buffer
=
std
::
make_shared
<
BufferMMap
>
(
const_cast
<
char
*>
(
this
->
mapped
->
data
()
+
info
.
offset
),
info
.
length
,
shared_from_this
());
info
.
buffer
=
buffer
;
}
...
...
@@ -269,4 +271,3 @@ Tensor SafeTensors::getTensor(const std::string &key) {
return
result
;
}
src/Serialization.h
View file @
37a27712
...
...
@@ -13,8 +13,8 @@ public:
// if (ret == cudaSuccess) {
// this->registered = true;
// } else {
// log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size,
cudaGetErrorString(cudaGetLastError())));
// this->registered = false;
// log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size,
//
cudaGetErrorString(cudaGetLastError())));
this->registered = false;
// }
}
virtual
~
BufferMMap
()
{
...
...
@@ -22,6 +22,7 @@ public:
// checkCUDA(cudaHostUnregister(ptr));
// }
}
public:
std
::
shared_ptr
<
void
>
parent
;
// bool registered;
...
...
src/Tensor.h
View file @
37a27712
...
...
@@ -3,10 +3,7 @@
#include "common.h"
struct
Device
{
enum
Type
{
INVALID_DEVICE_TYPE
=
0
,
CPU
,
CUDA
};
enum
Type
{
INVALID_DEVICE_TYPE
=
0
,
CPU
,
CUDA
};
Type
type
=
INVALID_DEVICE_TYPE
;
int
idx
=
0
;
...
...
@@ -24,20 +21,28 @@ class Buffer : public std::enable_shared_from_this<Buffer> {
public:
virtual
~
Buffer
()
{}
void
*
getPtr
()
{
return
ptr
;
}
void
*
getPtr
()
{
return
ptr
;
}
template
<
typename
T
>
T
*
getPtr
()
{
return
reinterpret_cast
<
T
*>
(
ptr
);
}
T
*
getPtr
()
{
return
reinterpret_cast
<
T
*>
(
ptr
);
}
size_t
getSize
()
{
return
size
;
}
Device
getDevice
()
{
return
device
;
}
size_t
getSize
()
{
return
size
;
}
Device
getDevice
()
{
return
device
;
}
virtual
bool
isAsyncBuffer
()
{
return
false
;
}
protected:
template
<
typename
Derived
>
template
<
typename
Derived
>
std
::
shared_ptr
<
Derived
>
shared_from_base
()
{
return
std
::
static_pointer_cast
<
Derived
>
(
shared_from_this
());
}
...
...
@@ -213,23 +218,31 @@ struct TensorShape {
}
};
class
Tensor
{
public:
enum
ScalarType
{
INVALID_SCALAR_TYPE
,
INT8
,
INT16
,
INT32
,
INT64
,
FP16
,
FP32
,
BF16
,
FP8_E4M3
,
FP8_E5M2
,
INT8
,
INT16
,
INT32
,
INT64
,
FP16
,
FP32
,
BF16
,
FP8_E4M3
,
FP8_E5M2
,
};
struct
TensorOptions
{
Device
device_
;
ScalarType
dtype_
;
Device
device
()
const
{
return
device_
;
}
ScalarType
dtype
()
const
{
return
dtype_
;
}
Device
device
()
const
{
return
device_
;
}
ScalarType
dtype
()
const
{
return
dtype_
;
}
TensorOptions
device
(
Device
dev
)
const
{
TensorOptions
result
(
*
this
);
...
...
@@ -244,43 +257,82 @@ public:
};
static
const
std
::
map
<
ScalarType
,
size_t
>
scalarSize
;
public:
TensorShape
shape
;
ScalarType
scalarType
;
std
::
shared_ptr
<
Buffer
>
buffer
;
public:
bool
valid
()
const
{
return
shape
.
dataExtent
.
size
()
>
0
;
}
int
size
(
int
dim
)
const
{
return
shape
[
dim
];
}
bool
is_contiguous
()
const
{
return
shape
.
is_contiguous
();
}
std
::
vector
<
int
>
sizes
()
const
{
return
shape
.
dataExtent
;
}
bool
valid
()
const
{
return
shape
.
dataExtent
.
size
()
>
0
;
}
int
size
(
int
dim
)
const
{
return
shape
[
dim
];
}
bool
is_contiguous
()
const
{
return
shape
.
is_contiguous
();
}
std
::
vector
<
int
>
sizes
()
const
{
return
shape
.
dataExtent
;
}
bool
is_cuda
()
const
{
return
device
().
type
==
Device
::
CUDA
;
}
bool
is_cuda
()
const
{
return
device
().
type
==
Device
::
CUDA
;
}
TensorOptions
options
()
const
{
return
TensorOptions
{
device
(),
dtype
()};
}
int
get_device
()
const
{
return
device
().
idx
;
}
TensorOptions
options
()
const
{
return
TensorOptions
{
device
(),
dtype
()};
}
int
get_device
()
const
{
return
device
().
idx
;
}
template
<
typename
T
>
T
*
data_ptr
()
{
return
reinterpret_cast
<
T
*>
(
data_ptr
());
}
T
*
data_ptr
()
{
return
reinterpret_cast
<
T
*>
(
data_ptr
());
}
template
<
typename
T
>
const
T
*
data_ptr
()
const
{
return
reinterpret_cast
<
const
T
*>
(
data_ptr
());
}
const
T
*
data_ptr
()
const
{
return
reinterpret_cast
<
const
T
*>
(
data_ptr
());
}
const
void
*
data_ptr
()
const
{
return
buffer
->
getPtr
<
char
>
()
+
shape
.
offset
*
scalar_size
();
}
void
*
data_ptr
()
{
return
buffer
->
getPtr
<
char
>
()
+
shape
.
offset
*
scalar_size
();
}
const
void
*
data_ptr
()
const
{
return
buffer
->
getPtr
<
char
>
()
+
shape
.
offset
*
scalar_size
();
}
void
*
data_ptr
()
{
return
buffer
->
getPtr
<
char
>
()
+
shape
.
offset
*
scalar_size
();
}
Device
device
()
const
{
return
buffer
->
getDevice
();
}
Device
device
()
const
{
return
buffer
->
getDevice
();
}
ScalarType
scalar_type
()
const
{
return
scalarType
;
}
ScalarType
dtype
()
const
{
return
scalar_type
();
}
ScalarType
scalar_type
()
const
{
return
scalarType
;
}
ScalarType
dtype
()
const
{
return
scalar_type
();
}
size_t
stride
(
int
dim
)
const
{
return
shape
.
stride
(
dim
);
}
size_t
stride
(
int
dim
)
const
{
return
shape
.
stride
(
dim
);
}
size_t
numel
()
const
{
return
shape
.
size
();
}
size_t
ndims
()
const
{
return
shape
.
ndims
();
}
size_t
numel
()
const
{
return
shape
.
size
();
}
size_t
ndims
()
const
{
return
shape
.
ndims
();
}
size_t
dim
()
const
{
return
ndims
();
}
size_t
dim
()
const
{
return
ndims
();
}
size_t
scalar_size
()
const
{
return
scalarSize
.
at
(
scalarType
);
}
size_t
scalar_size
()
const
{
return
scalarSize
.
at
(
scalarType
);
}
Tensor
operator
[](
int
idx
)
const
{
assert
(
ndims
()
>
1
);
...
...
@@ -293,7 +345,7 @@ public:
}
template
<
typename
T
>
const
T
&
at
(
const
std
::
vector
<
int
>
&
idx
)
const
{
const
T
&
at
(
const
std
::
vector
<
int
>
&
idx
)
const
{
assert
(
ndims
()
==
idx
.
size
());
int64_t
offset
=
0
;
for
(
size_t
i
=
0
;
i
<
ndims
();
i
++
)
{
...
...
@@ -304,7 +356,7 @@ public:
}
template
<
typename
T
>
T
&
at
(
const
std
::
vector
<
int
>
&
idx
)
{
T
&
at
(
const
std
::
vector
<
int
>
&
idx
)
{
return
const_cast
<
T
&>
(
const_cast
<
const
Tensor
*>
(
this
)
->
at
<
T
>
(
idx
));
}
...
...
@@ -363,7 +415,8 @@ public:
Tensor
&
zero_
()
{
assert
(
this
->
is_contiguous
());
checkCUDA
(
cudaMemsetAsync
(
data_ptr
<
char
>
()
+
shape
.
offset
*
scalar_size
(),
0
,
shape
.
size
()
*
scalar_size
(),
getCurrentCUDAStream
()));
checkCUDA
(
cudaMemsetAsync
(
data_ptr
<
char
>
()
+
shape
.
offset
*
scalar_size
(),
0
,
shape
.
size
()
*
scalar_size
(),
getCurrentCUDAStream
()));
return
*
this
;
}
Tensor
&
copy_
(
Tensor
other
)
{
...
...
@@ -380,23 +433,17 @@ public:
}
if
(
this
->
device
().
type
==
Device
::
CPU
&&
other
.
device
().
type
==
Device
::
CPU
)
{
memcpy
(
data_ptr
<
char
>
(),
other
.
data_ptr
<
char
>
(),
shape
.
size
()
*
scalar_size
()
);
memcpy
(
data_ptr
<
char
>
(),
other
.
data_ptr
<
char
>
(),
shape
.
size
()
*
scalar_size
());
return
*
this
;
}
lockBuffer
(
this
->
buffer
,
getCurrentCUDAStream
());
lockBuffer
(
other
.
buffer
,
getCurrentCUDAStream
());
checkCUDA
(
cudaMemcpyAsync
(
data_ptr
<
char
>
(),
checkCUDA
(
cudaMemcpyAsync
(
data_ptr
<
char
>
(),
other
.
data_ptr
<
char
>
(),
shape
.
size
()
*
scalar_size
(),
getCopyKind
(
this
->
device
(),
other
.
device
()),
getCurrentCUDAStream
()
));
getCurrentCUDAStream
()));
return
*
this
;
}
...
...
@@ -432,7 +479,8 @@ public:
memset
(
result
.
buffer
->
getPtr
(),
0xCC
,
result
.
buffer
->
getSize
());
}
else
if
(
device
.
type
==
Device
::
CUDA
)
{
CUDADeviceContext
ctx
(
device
.
idx
);
checkCUDA
(
cudaMemsetAsync
(
result
.
buffer
->
getPtr
(),
0xCC
,
result
.
buffer
->
getSize
(),
getCurrentCUDAStream
()));
checkCUDA
(
cudaMemsetAsync
(
result
.
buffer
->
getPtr
(),
0xCC
,
result
.
buffer
->
getSize
(),
getCurrentCUDAStream
()));
}
}
...
...
@@ -450,7 +498,8 @@ public:
checkCUDA
(
cudaMemsetAsync
(
result
.
buffer
->
getPtr
(),
1
,
result
.
buffer
->
getSize
(),
getCurrentCUDAStream
()));
return
result
;
}
static
Tensor
allocate_view
(
TensorShape
shape
,
ScalarType
scalarType
,
std
::
shared_ptr
<
Buffer
>
buffer
,
size_t
offset
=
0
)
{
static
Tensor
allocate_view
(
TensorShape
shape
,
ScalarType
scalarType
,
std
::
shared_ptr
<
Buffer
>
buffer
,
size_t
offset
=
0
)
{
Tensor
result
;
result
.
buffer
=
std
::
make_shared
<
BufferView
>
(
buffer
,
offset
,
shape
.
size
()
*
scalarSize
.
at
(
scalarType
));
result
.
scalarType
=
scalarType
;
...
...
@@ -468,13 +517,16 @@ public:
// lockBuffer(this->buffer, getCurrentCUDAStream());
// lockBuffer(result.buffer, getCurrentCUDAStream());
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault, getCurrentCUDAStream()));
// if (this->device().type == Device::CPU && device.type == Device::CUDA) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyHostToDevice, getCurrentCUDAStream()));
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault,
// getCurrentCUDAStream())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyHostToDevice, getCurrentCUDAStream()));
// } else if (this->device().type == Device::CUDA && device.type == Device::CPU) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
// } else {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault, getCurrentCUDAStream()));
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyDefault, getCurrentCUDAStream()));
// }
return
result
;
}
...
...
@@ -518,7 +570,8 @@ private:
static
inline
std
::
map
<
cudaStream_t
,
std
::
set
<
std
::
shared_ptr
<
Buffer
>>>
lockedBuffers
;
public:
// before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU completes
// before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU
// completes
static
void
lockBuffer
(
std
::
shared_ptr
<
Buffer
>
buffer
,
cudaStream_t
stream
)
{
if
(
!
buffer
->
isAsyncBuffer
())
{
lockedBuffers
[
stream
].
insert
(
buffer
);
...
...
src/activation.cpp
View file @
37a27712
...
...
@@ -22,13 +22,15 @@ Tensor GELU::forward(Tensor x) {
// return out;
// }
// Tensor SiluAndMulQuant::forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor SiluAndMulQuant::forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor
// quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor out = SiluAndMul::forward(x);
// invoke_quant_fuse_sum(quantized_mlp_act_buffer, out, quantized_sum_buffer, quantized_scale_buffer);
// return out;
// }
// Tensor SiluAndMulQuant::forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor SiluAndMulQuant::forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer,
// Tensor quantized_sum_buffer) {
// Tensor out = SiluAndMul::forward(x);
// invoke_quant(quantized_mlp_act_buffer, out, quantized_scale_buffer, {});
// return out;
...
...
Prev
1
2
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment