nunchaku · Commit ad8097b9 (Unverified)
Authored Apr 04, 2025 by Muyang Li; committed by GitHub on Apr 04, 2025

Release v0.2.0

Ready to release v0.2.0

Parents: 804a6d30, 998192ca
Changes: 142
Showing 20 changed files with 776 additions and 152 deletions (+776, -152)
scripts/build_docker.sh                        +40   -0
scripts/build_linux_wheel.sh                   +14   -1
scripts/build_linux_wheel_cu128.sh             +36   -0
scripts/build_linux_wheel_torch2.7_cu128.sh    +36   -0
scripts/build_windows_wheel.cmd                +55   -0
scripts/build_windows_wheel_cu128.cmd          +48   -0
scripts/linux_cleanup.sh                       +2    -6
setup.py                                       +10   -4
src/FluxModel.cpp                              +295  -84
src/FluxModel.h                                +33   -2
src/Linear.cpp                                 +17   -12
src/Linear.h                                   +5    -1
src/Module.cpp                                 +19   -0
src/Module.h                                   +19   -2
src/SanaModel.cpp                              +29   -25
src/SanaModel.h                                +1    -1
src/Tensor.h                                   +4    -1
src/common.h                                   +90   -9
src/interop/torch.cpp                          +2    -1
src/interop/torch.h                            +21   -3
scripts/build_docker.sh (new file, mode 100644)

#!/bin/bash
PYTHON_VERSION=$1
TORCH_VERSION=$2
CUDA_VERSION=$3
NUNCHAKU_VERSION=$4

# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
if [ "$TORCH_VERSION" == "2.5" ]; then
    TORCHVISION_VERSION="0.20"
    TORCHAUDIO_VERSION="2.5"
    echo "TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
elif [ "$TORCH_VERSION" == "2.6" ]; then
    TORCHVISION_VERSION="0.21"
    TORCHAUDIO_VERSION="2.6"
    echo "TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
else
    echo "TORCH_VERSION is not 2.5 or 2.6. Exit."
    exit 2
fi

if [ "$CUDA_VERSION" == "12.8" ]; then
    CUDA_IMAGE="12.8.1-devel-ubuntu24.04"
    echo "CUDA_VERSION is 12.8, setting CUDA_IMAGE to $CUDA_IMAGE"
elif [ "$CUDA_VERSION" == "12.4" ]; then
    CUDA_IMAGE="12.4.1-devel-ubuntu22.04"
    echo "CUDA_VERSION is 12.4, setting CUDA_IMAGE to $CUDA_IMAGE"
else
    echo "CUDA_VERSION is not 12.8 or 12.4. Exit."
    exit 2
fi

docker build --no-cache \
    --build-arg PYTHON_VERSION=${PYTHON_VERSION} \
    --build-arg CUDA_SHORT_VERSION=${CUDA_VERSION//.} \
    --build-arg CUDA_IMAGE=${CUDA_IMAGE} \
    --build-arg TORCH_VERSION=${TORCH_VERSION} \
    --build-arg TORCHVISION_VERSION=${TORCHVISION_VERSION} \
    --build-arg TORCHAUDIO_VERSION=${TORCHAUDIO_VERSION} \
    -t nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
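For orientation, a usage sketch of this script: it takes the Python, Torch, CUDA, and nunchaku versions as positional arguments and tags the image accordingly. The concrete version values and tag below are illustrative, not taken from the commit.

  # Hypothetical invocation; would produce the image nunchaku:v0.2.0-py3.11-torch2.6-cuda12.4
  bash scripts/build_docker.sh 3.11 2.6 12.4 v0.2.0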
scripts/build_linux_wheels.sh → scripts/build_linux_wheel.sh (renamed)

@@ -7,6 +7,19 @@ CUDA_VERSION=$3
 MAX_JOBS=${4:-}  # optional

 PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}

+# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
+if [ "$TORCH_VERSION" == "2.5" ]; then
+    TORCHVISION_VERSION="0.20"
+    TORCHAUDIO_VERSION="2.5"
+    echo "TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
+elif [ "$TORCH_VERSION" == "2.6" ]; then
+    TORCHVISION_VERSION="0.21"
+    TORCHAUDIO_VERSION="2.6"
+    echo "TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
+else
+    echo "TORCH_VERSION is not 2.5 or 2.6, no changes to versions."
+fi
+
 docker run --rm \
     -v "$(pwd)":/nunchaku \
     pytorch/manylinux-builder:cuda${CUDA_VERSION} \
@@ -16,7 +29,7 @@ docker run --rm \
     yum install -y devtoolset-11 && \
     source scl_source enable devtoolset-11 && \
     gcc --version && g++ --version && \
-    ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==${TORCH_VERSION} numpy --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
+    ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} torchaudio==${TORCHAUDIO_VERSION} --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
     ${PYTHON_ROOT_PATH}/bin/pip install build ninja wheel setuptools && \
     export NUNCHAKU_INSTALL_MODE=ALL && \
     export NUNCHAKU_BUILD_WHEELS=1 && \
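A hedged usage sketch of the renamed wheel script, with illustrative argument values: it takes the Python version, Torch version, CUDA version, and an optional MAX_JOBS cap, and now installs the matching torchvision/torchaudio inside the manylinux container.

  # Hypothetical: cp312 wheel against torch 2.6 + CUDA 12.4, capping parallel compile jobs at 8
  bash scripts/build_linux_wheel.sh 3.12 2.6 12.4 8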
scripts/build_linux_wheel_cu128.sh (new file, mode 100644)

#!/bin/bash
# Modified from https://github.com/sgl-project/sglang/blob/main/sgl-kernel/build.sh
set -ex

PYTHON_VERSION=$1
TORCH_VERSION=$2  # has no use for now
CUDA_VERSION=$3
MAX_JOBS=${4:-}  # optional

PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}

# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
#if [ "$TORCH_VERSION" == "2.5" ]; then
#    TORCHVISION_VERSION="0.20"
#    TORCHAUDIO_VERSION="2.5"
#    echo "TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
#elif [ "$TORCH_VERSION" == "2.6" ]; then
#    TORCHVISION_VERSION="0.21"
#    TORCHAUDIO_VERSION="2.6"
#    echo "TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
#else
#    echo "TORCH_VERSION is not 2.5 or 2.6, no changes to versions."
#fi

docker run --rm \
    -v "$(pwd)":/nunchaku \
    pytorch/manylinux2_28-builder:cuda${CUDA_VERSION} \
    bash -c "
    cd /nunchaku && \
    rm -rf build && \
    gcc --version && g++ --version && \
    ${PYTHON_ROOT_PATH}/bin/pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 && \
    ${PYTHON_ROOT_PATH}/bin/pip install build ninja wheel setuptools && \
    export NUNCHAKU_INSTALL_MODE=ALL && \
    export NUNCHAKU_BUILD_WHEELS=1 && \
    export MAX_JOBS=${MAX_JOBS} && \
    ${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
    "
\ No newline at end of file
scripts/build_linux_wheel_torch2.7_cu128.sh (new file, mode 100644)

#!/bin/bash
# Modified from https://github.com/sgl-project/sglang/blob/main/sgl-kernel/build.sh
set -ex

PYTHON_VERSION=$1
TORCH_VERSION=$2  # has no use for now
CUDA_VERSION=$3
MAX_JOBS=${4:-}  # optional

PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}

# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
#if [ "$TORCH_VERSION" == "2.5" ]; then
#    TORCHVISION_VERSION="0.20"
#    TORCHAUDIO_VERSION="2.5"
#    echo "TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
#elif [ "$TORCH_VERSION" == "2.6" ]; then
#    TORCHVISION_VERSION="0.21"
#    TORCHAUDIO_VERSION="2.6"
#    echo "TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
#else
#    echo "TORCH_VERSION is not 2.5 or 2.6, no changes to versions."
#fi

docker run --rm \
    -v "$(pwd)":/nunchaku \
    pytorch/manylinux2_28-builder:cuda${CUDA_VERSION} \
    bash -c "
    cd /nunchaku && \
    rm -rf build && \
    gcc --version && g++ --version && \
    ${PYTHON_ROOT_PATH}/bin/pip install --pre torch==2.7.0.dev20250307+cu128 torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 && \
    ${PYTHON_ROOT_PATH}/bin/pip install build ninja wheel setuptools && \
    export NUNCHAKU_INSTALL_MODE=ALL && \
    export NUNCHAKU_BUILD_WHEELS=1 && \
    export MAX_JOBS=${MAX_JOBS} && \
    ${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
    "
\ No newline at end of file
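The two cu128 scripts share the same calling convention; a minimal usage sketch, assuming Docker and the pytorch/manylinux2_28-builder image are available (argument values are illustrative):

  # Hypothetical: cp311 wheel against the latest cu128 nightly Torch
  bash scripts/build_linux_wheel_cu128.sh 3.11 2.7 12.8 8
  # Hypothetical: same, but pinned to the torch 2.7.0.dev20250307+cu128 snapshot
  bash scripts/build_linux_wheel_torch2.7_cu128.sh 3.11 2.7 12.8 8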
scripts/build_windows_wheel.cmd (new file, mode 100644)

@echo off
setlocal enabledelayedexpansion

:: get arguments
set PYTHON_VERSION=%1
set TORCH_VERSION=%2
set CUDA_VERSION=%3
set CUDA_SHORT_VERSION=%CUDA_VERSION:.=%
echo %CUDA_SHORT_VERSION%

:: setup some variables
if "%TORCH_VERSION%" == "2.5" (
    set TORCHVISION_VERSION=0.20
    set TORCHAUDIO_VERSION=2.5
) else if "%TORCH_VERSION%" == "2.6" (
    set TORCHVISION_VERSION=0.21
    set TORCHAUDIO_VERSION=2.6
) else (
    echo TORCH_VERSION is not 2.5 or 2.6, no changes to versions.
)
echo setting TORCHVISION_VERSION to %TORCHVISION_VERSION% and TORCHAUDIO_VERSION to %TORCHAUDIO_VERSION%

:: conda environment name
set ENV_NAME=build_env_%PYTHON_VERSION%_%TORCH_VERSION%
echo Using conda environment: %ENV_NAME%

:: create conda environment
call conda create -y -n %ENV_NAME% python=%PYTHON_VERSION%
call conda activate %ENV_NAME%

:: install dependencies
call pip install ninja setuptools wheel build
call pip install --no-cache-dir torch==%TORCH_VERSION% torchvision==%TORCHVISION_VERSION% torchaudio==%TORCHAUDIO_VERSION% --index-url "https://download.pytorch.org/whl/cu%CUDA_SHORT_VERSION%/"

:: set environment variables
set NUNCHAKU_INSTALL_MODE=ALL
set NUNCHAKU_BUILD_WHEELS=1

:: cd to the parent directory
cd /d "%~dp0.."
if exist build rd /s /q build

:: set up Visual Studio compilation environment
call "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
set DISTUTILS_USE_SDK=1

:: build wheels
python -m build --wheel --no-isolation

:: exit conda
call conda deactivate
call conda remove -y -n %ENV_NAME% --all

echo Build complete!
scripts/build_windows_wheel_cu128.cmd (new file, mode 100644)

@echo off
setlocal enabledelayedexpansion

:: get arguments
set PYTHON_VERSION=%1
set TORCH_VERSION=%2
set CUDA_VERSION=%3
set CUDA_SHORT_VERSION=%CUDA_VERSION:.=%
echo %CUDA_SHORT_VERSION%

:: conda environment name
set ENV_NAME=build_env_%PYTHON_VERSION%_%TORCH_VERSION%
echo Using conda environment: %ENV_NAME%

:: create conda environment
call conda create -y -n %ENV_NAME% python=%PYTHON_VERSION%
call conda activate %ENV_NAME%

:: install dependencies
call pip install ninja setuptools wheel build
if "%TORCH_VERSION%" == "2.7" (
    call pip install --pre torch==2.7.0.dev20250307+cu128 torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
) else (
    call pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
)

:: set environment variables
set NUNCHAKU_INSTALL_MODE=ALL
set NUNCHAKU_BUILD_WHEELS=1

:: cd to the parent directory
cd /d "%~dp0.."
if exist build rd /s /q build

:: set up Visual Studio compilation environment
call "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
set DISTUTILS_USE_SDK=1

:: build wheels
python -m build --wheel --no-isolation

:: exit conda
call conda deactivate
call conda remove -y -n %ENV_NAME% --all

echo Build complete!
scripts/linux_cleanup.sh

 #!/bin/bash
 set -ex

-#docker run --rm \
-#    -v "$(pwd)":/nunchaku \
-#    pytorch/manylinux-builder:cuda12.4 \
-#    bash -c "cd /nunchaku && rm -r *"
-docker run --rm \
+docker run --rm -it \
     -v "$(pwd)":/nunchaku \
     pytorch/manylinux-builder:cuda12.4 \
-    bash
+    bash -c "cd /nunchaku && rm -rf *"
\ No newline at end of file
setup.py

@@ -47,12 +47,12 @@ def get_sm_targets() -> list[str]:
             sm = f"{capability[0]}{capability[1]}"
             if sm == "120" and support_sm120:
                 sm = "120a"
-            assert sm in ["80", "86", "89", "120a"], f"Unsupported SM {sm}"
+            assert sm in ["75", "80", "86", "89", "120a"], f"Unsupported SM {sm}"
             if sm not in ret:
                 ret.append(sm)
     else:
         assert install_mode == "ALL"
-        ret = ["80", "86", "89"]
+        ret = ["75", "80", "86", "89"]
         if support_sm120:
             ret.append("120a")
     return ret
@@ -142,6 +142,7 @@ if __name__ == "__main__":
                 *ncond("src/FluxModel.cpp"),
                 *ncond("src/SanaModel.cpp"),
                 "src/Serialization.cpp",
+                "src/Module.cpp",
                 *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
                 *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
                 *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
@@ -158,9 +159,14 @@ if __name__ == "__main__":
                 "src/kernels/layernorm_kernels.cu",
                 "src/kernels/misc_kernels.cu",
                 "src/kernels/zgemm/gemm_w4a4.cu",
-                "src/kernels/zgemm/gemm_w4a4_launch_fp16.cu",
-                "src/kernels/zgemm/gemm_w4a4_launch_bf16.cu",
+                "src/kernels/zgemm/gemm_w4a4_test.cu",
+                "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu",
+                "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu",
+                "src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu",
+                "src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu",
+                "src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu",
                 "src/kernels/zgemm/gemm_w8a8.cu",
+                "src/kernels/zgemm/attention.cu",
                 "src/kernels/dwconv.cu",
                 "src/kernels/gemm_batched.cu",
                 "src/kernels/gemm_f16.cu",
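The widened SM list matters when building for every target; a minimal sketch of a source build that exercises this path (the environment variable names come from the build scripts above, the rest is illustrative):

  # Hypothetical local build covering SM 75/80/86/89, plus 120a when supported
  export NUNCHAKU_INSTALL_MODE=ALL
  export NUNCHAKU_BUILD_WHEELS=1
  python -m build --wheel --no-isolation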
src/FluxModel.cpp

 #include "FluxModel.h"
 #include "kernels/misc_kernels.h"
 #include "kernels/gemm_batched.h"
+#include "kernels/zgemm/zgemm.h"
 #include "flash_api.h"
 #include "activation.h"

@@ -39,7 +40,7 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
 AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     linear(dim, 3 * dim, true, dtype, device),
     norm(dim, 1e-6, false, dtype, device)
 {
     registerChildren
         (linear, "linear")
@@ -58,12 +59,12 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
     debug("x", x);
     Tensor norm_x = norm.forward(x);
     debug("norm_x", norm_x);
     kernels::mul_add(norm_x, scale_msa, shift_msa);
     return Output{norm_x, gate_msa};
 }

 AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device) :
     dim(dim), pre_only(pre_only),
     linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
     norm(dim, 1e-6, false, dtype, device)
@@ -90,7 +91,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
         kernels::mul_add(norm_x, scale_msa, shift_msa);
         debug("norm_x_scaled", norm_x);
         return Output{norm_x};
     } else {
         auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
@@ -107,7 +108,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
 }

 Attention::Attention(int num_heads, int dim_head, Device device) :
     num_heads(num_heads), dim_head(dim_head), force_fp16(false)
 {
     headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
@@ -117,6 +118,33 @@ Attention::Attention(int num_heads, int dim_head, Device device) :
     headmask_type = headmask_type.copy(device);
 }

+Tensor Attention::forward(Tensor qkv) {
+    assert(qkv.ndims() == 3);
+
+    const Device device = qkv.device();
+
+    const int batch_size = qkv.shape[0];
+    const int num_tokens = qkv.shape[1];
+    assert(qkv.shape[2] == num_heads * dim_head * 3);
+
+    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
+    Tensor q = reshaped.slice(2, 0, num_heads);
+    Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
+    Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);
+
+    Tensor raw_attn_output = mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false).front();
+
+    assert(raw_attn_output.shape[0] == batch_size);
+    assert(raw_attn_output.shape[1] == num_tokens);
+    assert(raw_attn_output.shape[2] == num_heads);
+    assert(raw_attn_output.shape[3] == dim_head);
+
+    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
+}

 Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
     const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;
@@ -150,7 +178,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
                 gemm_batched_fp16(pool_q, pool_k, pool_s);
             }
         }
         blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));

         if (cu_seqlens_cpu.valid()) {
@@ -226,16 +254,16 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
             false
         ).front();
         Tensor raw_attn_output = mha_fwd(q, k, v,
             0.0f,
             pow(q.shape[-1], (-0.5)),
             false, -1, -1, false
         ).front();
         Tensor raw_attn_output = mha_varlen_fwd(
             q, k, v,
             cu_seqlens, cu_seqlens,
-            num_tokens_img + num_tokens_context, num_tokens_img + num_tokens_context,
+            num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
             0.0f,
             pow(q.shape[-1], (-0.5)),
             false, false, -1, -1, false
@@ -260,7 +288,7 @@ void Attention::setForceFP16(Module *module, bool value) {
 }

 FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_head(attention_head_dim / num_attention_heads),
     num_heads(num_attention_heads),
     mlp_hidden_dim(dim * mlp_ratio),
@@ -298,19 +326,50 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
(This hunk re-works the single-block attention so it is selected by attnImpl; the resulting code is shown below. The previous FlashAttention-2-only path moves into the first branch, and a NunchakuFP16 path that packs padded q/k/v and calls kernels::attention_fp16 is added.)

    Tensor residual = hidden_states;

    debug("rotary_emb", rotary_emb);

    Tensor attn_output;

    if (attnImpl == AttentionImpl::FlashAttention2) {
        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
        Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
        // qkv_proj.forward(norm_hidden_states, qkv, {});
        // debug("qkv_raw", qkv);
        qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
        debug("qkv", qkv);
        // attn_output = attn.forward(qkv, {}, 0);
        attn_output = attn.forward(qkv);
        attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        assert(batch_size == 1);
        const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

        Tensor q = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

        qkv_proj.forward(norm_hidden_states, {}, {}, norm_q.weight, norm_k.weight, rotary_emb, q, k, v, num_tokens);
        debug("packed_q", q);
        debug("packed_k", k);
        debug("packed_v", v);

        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
        kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));

        attn_output = o.slice(1, 0, num_tokens);
    } else {
        assert(false);
    }

    debug("raw_attn_output", attn_output);
    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);
@@ -319,7 +378,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     hidden_states = kernels::add(attn_output, ff_output);
     debug("attn_ff_output", hidden_states);
     kernels::mul_add(hidden_states, gate, residual);
     nvtxRangePop();
@@ -327,7 +386,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     return hidden_states;
 }

 JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_head(attention_head_dim / num_attention_heads),
     num_heads(num_attention_heads),
@@ -384,13 +443,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     int num_tokens_img = hidden_states.shape[1];
-    int num_tokens_context = encoder_hidden_states.shape[1];
+    int num_tokens_txt = encoder_hidden_states.shape[1];

     assert(hidden_states.shape[2] == dim);
     assert(encoder_hidden_states.shape[2] == dim);

     spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}", hidden_states.shape.str(), encoder_hidden_states.shape.str(), temb.shape.str());
-    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_context={}", batch_size, num_tokens_img, num_tokens_context);
+    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

     auto norm1_output = norm1.forward(hidden_states, temb);
     auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);
@@ -408,76 +467,141 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
(This hunk re-indents the existing FlashAttention-2 qkv_proj/attention path under an attnImpl check, renames num_tokens_context to num_tokens_txt, and adds a NunchakuFP16 path; the resulting code is shown below.)

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();

    Tensor concat;
    Tensor pool;

    const bool blockSparse = sparsityRatio > 0;
    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;

    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;

        {
            nvtxRangePushA("qkv_proj");

            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
            pool = blockSparse
                ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
                : Tensor{};

            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
                Tensor pool_qkv = pool.valid()
                    ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
                    : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                    ? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                    : Tensor{};

                debug("rotary_emb", rotary_emb);
                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);
                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);

                debug("rotary_emb_context", rotary_emb_context);
                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);
                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
                debug("qkv_context", qkv_context);
            }

            nvtxRangePop();
        }

        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);
        assert(concat.shape[2] == num_heads * dim_head * 3);

        nvtxRangePushA("Attention");
        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }
        nvtxRangePop();

        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;

        Tensor concat_q, concat_k, concat_v;
        {
            nvtxRangePushA("qkv_proj");

            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);

            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
                auto sliceTxt = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad); };

                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb, sliceImg(concat_q), sliceImg(concat_k), sliceImg(concat_v), num_tokens_img);
                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), {}, {}, norm_added_q.weight, norm_added_k.weight, rotary_emb_context, sliceTxt(concat_q), sliceTxt(concat_k), sliceTxt(concat_v), num_tokens_txt);
            }

            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);

            nvtxRangePop();
        }

        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head}, norm1_output.x.scalar_type(), norm1_output.x.device());

        nvtxRangePushA("Attention");
        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));
        nvtxRangePop();

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
    } else {
        assert(false);
    }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
        Tensor raw_attn_output_split;
        if (batch_size == 1) {
@@ -485,16 +609,16 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
         } else {
             raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
             checkCUDA(cudaMemcpy2DAsync(
                 raw_attn_output_split.data_ptr(),
                 num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                 raw_attn_output.data_ptr(),
-                (num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
+                (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
                 num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                 batch_size,
                 cudaMemcpyDeviceToDevice,
                 stream));
         }
         spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
         debug("img.raw_attn_output_split", raw_attn_output_split);
@@ -546,20 +670,20 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
         Tensor raw_attn_output_split;
         if (batch_size == 1) {
-            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img, num_tokens_img + num_tokens_context).reshape({batch_size, num_tokens_context, num_heads * dim_head});
+            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt).reshape({batch_size, num_tokens_txt, num_heads * dim_head});
         } else {
-            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_context, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
+            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
             checkCUDA(cudaMemcpy2DAsync(
                 raw_attn_output_split.data_ptr(),
-                num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
-                raw_attn_output.data_ptr<char>() + num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head * raw_attn_output_split.scalar_size(),
-                (num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
+                (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
-                num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                 batch_size,
                 cudaMemcpyDeviceToDevice,
                 stream));
         }
         spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
         debug("context.raw_attn_output_split", raw_attn_output_split);
@@ -585,7 +709,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
 #else
     auto norm_hidden_states = encoder_hidden_states;
 #endif
     // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
     // Tensor ff_output = mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
@@ -607,7 +731,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     return {hidden_states, encoder_hidden_states};
 }

-FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : offload(offload) {
+FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : dtype(dtype), offload(offload) {
     for (int i = 0; i < 19; i++) {
         transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
@@ -626,7 +750,16 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
     }
 }

-Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single) {
+Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single,
+                          Tensor controlnet_block_samples, Tensor controlnet_single_block_samples, bool skip_first_layer) {
     const int batch_size = hidden_states.shape[0];
     const Tensor::ScalarType dtype = hidden_states.dtype();
     const Device device = hidden_states.device();
@@ -639,9 +772,20 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
     Tensor concat;

     auto compute = [&](int layer) {
+        if (skip_first_layer && size_t(layer) == 0)
+            return;
         if (size_t(layer) < transformer_blocks.size()) {
             auto &block = transformer_blocks.at(layer);
             std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+            if (controlnet_block_samples.valid()) {
+                const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
+                int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
+                int block_index = layer / interval_control;
+                // Xlabs ControlNet
+                // block_index = layer % num_controlnet_block_samples;
+                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
+            }
         } else {
             if (size_t(layer) == transformer_blocks.size()) {
                 // txt first, same as diffusers
@@ -652,10 +796,23 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
             }
                 hidden_states = concat;
                 encoder_hidden_states = {};
             }
             auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
             hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
+            if (controlnet_single_block_samples.valid()) {
+                const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
+                int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
+                int block_index = (layer - transformer_blocks.size()) / interval_control;
+                // Xlabs ControlNet
+                // block_index = layer % num_controlnet_single_block_samples
+                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
+                slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
+                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
+            }
         }
     };

     auto load = [&](int layer) {
@@ -681,4 +838,58 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
     helper.run();

     return hidden_states;
-}
\ No newline at end of file
+}
+
+std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer, Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor controlnet_block_samples, Tensor controlnet_single_block_samples) {
+    std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+
+    const int txt_tokens = encoder_hidden_states.shape[1];
+    const int img_tokens = hidden_states.shape[1];
+
+    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
+        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
+        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
+        int block_index = layer / interval_control;
+        // Xlabs ControlNet
+        // block_index = layer % num_controlnet_block_samples;
+        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
+    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
+        const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
+        int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
+        int block_index = (layer - transformer_blocks.size()) / interval_control;
+        // Xlabs ControlNet
+        // block_index = layer % num_controlnet_single_block_samples
+        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
+        slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
+        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
+    }
+
+    return {hidden_states, encoder_hidden_states};
+}
+
+void FluxModel::setAttentionImpl(AttentionImpl impl) {
+    for (auto &&block : this->transformer_blocks) {
+        block->attnImpl = impl;
+    }
+    for (auto &&block : this->single_transformer_blocks) {
+        block->attnImpl = impl;
+    }
+}
src/FluxModel.h

@@ -6,6 +6,11 @@
 #include "Linear.h"
 #include "layernorm.h"

+enum class AttentionImpl {
+    FlashAttention2 = 0,
+    NunchakuFP16,
+};
+
 class AdaLayerNormZeroSingle : public Module {
 public:
     static constexpr bool USE_4BIT = true;
@@ -56,8 +61,9 @@ private:
 class Attention : public Module {
 public:
     static constexpr int POOL_SIZE = 128;

     Attention(int num_heads, int dim_head, Device device);
+    Tensor forward(Tensor qkv);
     Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);

     static void setForceFP16(Module *module, bool value);
@@ -86,6 +92,8 @@ public:
     const int num_heads;
     const int mlp_hidden_dim;

+    AttentionImpl attnImpl = AttentionImpl::FlashAttention2;
+
 private:
     AdaLayerNormZeroSingle norm;
     GEMM mlp_fc1;
@@ -110,6 +118,8 @@ public:
     const int num_heads;
     const bool context_pre_only;

+    AttentionImpl attnImpl = AttentionImpl::FlashAttention2;
+
 private:
     AdaLayerNormZero norm1;
     AdaLayerNormZero norm1_context;
@@ -129,9 +139,30 @@ private:
 class FluxModel : public Module {
 public:
     FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
-    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
+    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single,
+                   Tensor controlnet_block_samples, Tensor controlnet_single_block_samples, bool skip_first_layer = false);
+    std::tuple<Tensor, Tensor> forward_layer(size_t layer, Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context,
+                                             Tensor controlnet_block_samples, Tensor controlnet_single_block_samples);
+    void setAttentionImpl(AttentionImpl impl);

 public:
+    const Tensor::ScalarType dtype;
+
     std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
     std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
src/Linear.cpp
View file @
ad8097b9
...
@@ -52,13 +52,14 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
...
@@ -52,13 +52,14 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
if
(
key
==
"lora_down"
||
key
==
"lora_up"
)
{
if
(
key
==
"lora_down"
||
key
==
"lora_up"
)
{
assert
(
src
.
ndims
()
==
2
);
assert
(
src
.
ndims
()
==
2
);
if
(
dst
.
shape
.
dataExtent
!=
src
.
shape
.
dataExtent
)
{
if
(
dst
.
shape
.
dataExtent
!=
src
.
shape
.
dataExtent
)
{
dst
=
src
.
copy
(
this
->
device
);
dst
=
Tensor
::
allocate
(
src
.
shape
.
dataExtent
,
dst
.
scalar_type
(),
this
->
device
);
Module
::
loadParam
(
key
,
dst
,
src
);
if
(
key
==
"lora_down"
)
{
if
(
key
==
"lora_down"
)
{
const
int
new_rank
=
dst
.
shape
[
0
];
const
int
new_rank
=
dst
.
shape
[
0
];
this
->
lora_rank
=
new_rank
;
this
->
lora_rank
=
new_rank
;
}
}
}
else
{
}
else
{
dst
.
copy_
(
src
);
Module
::
loadParam
(
key
,
dst
,
src
);
}
}
}
else
{
}
else
{
Module
::
loadParam
(
key
,
dst
,
src
);
Module
::
loadParam
(
key
,
dst
,
src
);
...
@@ -143,16 +144,18 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
...
@@ -143,16 +144,18 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
if
(
key
==
"lora_down"
||
key
==
"lora_up"
)
{
if
(
key
==
"lora_down"
||
key
==
"lora_up"
)
{
assert
(
src
.
ndims
()
==
2
);
assert
(
src
.
ndims
()
==
2
);
if
(
dst
.
shape
.
dataExtent
!=
src
.
shape
.
dataExtent
)
{
if
(
dst
.
shape
.
dataExtent
!=
src
.
shape
.
dataExtent
)
{
dst
=
src
.
copy
(
this
->
device
);
dst
=
Tensor
::
allocate
(
src
.
shape
.
dataExtent
,
dst
.
scalar_type
(),
this
->
device
);
Module
::
loadParam
(
key
,
dst
,
src
);
this
->
lora_rank
=
dst
.
shape
[
1
];
this
->
lora_rank
=
dst
.
shape
[
1
];
this
->
lora_scales
.
resize
(
ceilDiv
(
this
->
lora_rank
,
16
),
1.0
f
);
this
->
lora_scales
.
resize
(
ceilDiv
(
this
->
lora_rank, 16), 1.0f);
            } else {
-               dst.copy_(src);
+               Module::loadParam(key, dst, src);
            }
        }
    } else if (key == "wcscales") {
        assert(src.ndims() == 1);
        assert(src.shape[0] == out_features_pad);
-       dst = src.copy(this->device);
+       dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
+       Module::loadParam(key, dst, src);
    } else if (key == "wtscale") {
        assert(src.numel() == 1);
        if (src.dtype() == Tensor::BF16) {
...
@@ -160,7 +163,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
        } else if (src.dtype() == Tensor::FP16) {
            *dst.data_ptr<float>() = float(*src.data_ptr<half>());
        } else if (src.dtype() == Tensor::FP32) {
-           dst.copy_(src);
+           Module::loadParam(key, dst, src);
        } else {
            assert(false);
        }
    }
...
@@ -181,7 +184,7 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x
    return forward_quant(quantize(x, false), fuse, nextGEMM);
}

-void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
+void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
    QuantizedActivation qact = quantize(x, false);

#if !NO_LORA_FUSION
...
@@ -196,7 +199,8 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
    kernels::gemm_w4a4(
        qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool,
        qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias,
        {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
-       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{}
+       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{},
+       out_q, out_k, out_v, numTokens
    );

    debug("gemm.out", out);
...
@@ -277,7 +281,8 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
    kernels::gemm_w4a4(
        qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {},
        qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias,
        next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
-       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{}
+       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{},
+       {}, {}, {}, 0
    );

    if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
...
@@ -446,9 +451,9 @@ GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
}

Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
-   auto oshape = qact.act.shape;
-   oshape[-1] = out_features;
-   Tensor out = Tensor::allocate(oshape, this->dtype, qact.act.device());
+   auto shape = TensorShape(qact.act.shape.dataExtent);
+   shape[-1] = out_features;
+   Tensor out = Tensor::allocate(shape, this->dtype, qact.act.device());
    kernels::gemm_w8a8(qact.act, this->qweight, out, qact.ascales, this->wscales, this->bias);

    debug("gemm.out", out);
...
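The recurring change in this file replaces direct `dst.copy_(src)` calls with `Module::loadParam(key, dst, src)`, so the FP16/BF16 auto-cast added to the base class (see src/Module.h below) applies to every GEMM parameter. A minimal sketch of an override following the same pattern, assuming only the `loadParam` signature from this commit; `MyLayer` and the `"my_scales"` key are illustrative placeholders, not part of the codebase:

void MyLayer::loadParam(std::string key, Tensor &dst, Tensor src) {
    if (key == "my_scales") {                 // illustrative key
        assert(src.ndims() == 1);
        // Allocate on the module's device, then delegate so the base class can cast dtypes.
        dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
        Module::loadParam(key, dst, src);
    } else {
        Module::loadParam(key, dst, src);
    }
}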
src/Linear.h
...
@@ -69,7 +69,11 @@ public:
    Tensor forward(Tensor x);
    Tensor forward_silu(Tensor x);
    std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
-   void forward(Tensor x, Tensor out, Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {});
+   void forward(Tensor x,
+                Tensor out,
+                Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {},
+                Tensor out_q = {}, Tensor out_k = {}, Tensor out_v = {},
+                int numTokens = 0);
    std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
    Tensor forward_quant(QuantizedActivation qact);
...
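The new trailing parameters let a single call produce attention inputs directly: the fused kernel can write packed Q/K/V tensors in addition to the plain GEMM output. A minimal sketch of a call site, assuming only the declaration above; tensor shapes and the exact meaning of `numTokens` are not specified by this header:

void run_qkv_projection(GEMM_W4A4 &qkv_proj, Tensor x, Tensor out,
                        Tensor norm_q, Tensor norm_k, Tensor rotary_emb,
                        Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
    // Existing call sites keep working: the new arguments default to {} and 0.
    qkv_proj.forward(x, out);

    // New call sites can additionally request packed Q/K/V from the fused kernel.
    qkv_proj.forward(x, out, /*pool=*/{}, norm_q, norm_k, rotary_emb,
                     out_q, out_k, out_v, numTokens);
}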
src/Module.cpp
0 → 100644 (new file)

#include "common.h"
#include "Module.h"

#include "kernels/misc_kernels.h"

void Module::copyWithCast(Tensor dst, Tensor src) {
    assert(dst.is_contiguous());
    assert(dst.device().type == Device::CUDA);
    if (src.device().type == Device::CUDA && src.device().idx == dst.device().idx) {
        // Same CUDA device: cast directly from src into dst.
        nunchaku::kernels::cast(src, dst);
    } else {
        // Different device (or host source): stage src into dst's buffer first,
        // reinterpreted with src's scalar type, then cast in place.
        // (FP16 and BF16 share the same element size, so the buffer is large enough.)
        Tensor tmp;
        tmp.buffer = dst.buffer;
        tmp.shape = dst.shape;
        tmp.scalarType = src.scalarType;
        tmp.copy_(src);
        nunchaku::kernels::cast(tmp, dst);
    }
}
src/Module.h
...
@@ -131,10 +131,23 @@ public:
            m->enabledLazyLoad = val;
        });
    }

+   void setAutoCastFP16(bool val) {
+       traverse([val](Module *m) {
+           m->enabledAutoCastFP16 = val;
+       });
+   }

protected:
    virtual void loadParam(std::string key, Tensor &dst, Tensor src) {
-       dst.copy_(src);
+       static const std::set<Tensor::ScalarType> whitelist = {
+           Tensor::FP16,
+           Tensor::BF16,
+       };
+       if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() &&
+           whitelist.contains(dst.scalar_type()) && whitelist.contains(src.scalar_type())) {
+           copyWithCast(dst, src);
+       } else {
+           dst.copy_(src);
+       }
    }

    struct ChildrenRegisterHelper {
...
@@ -174,7 +187,7 @@ protected:
    }

    void debug(std::string name, Tensor tensor) {
-       if (DebugContext::ctxs.empty()) {
+       if (DebugContext::ctxs.empty() || !tensor.valid()) {
            return;
        }
        std::string prefix = getFullName();
...
@@ -187,6 +200,9 @@ protected:
        }
    }

+private:
+   void copyWithCast(Tensor dst, Tensor src);
+
public:
    Module *parent = nullptr;
    std::string name = "";
...
@@ -194,6 +210,7 @@ public:
    std::map<std::string, Param> params;

    bool enabledLazyLoad = false;
+   bool enabledAutoCastFP16 = true;
};

struct LayerOffloadHelper {
...
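`setAutoCastFP16` walks the module tree the same way `setLazyLoad` does, and `loadParam` now casts FP16<->BF16 weights on the fly (via `copyWithCast` in src/Module.cpp above) instead of requiring the checkpoint dtype to match the module dtype. A minimal sketch of toggling it, assuming nothing beyond the declarations above:

void configure_loading(Module &model, bool checkpoint_dtype_matches) {
    // Enabled by default; disabling it makes loadParam fall back to a plain dst.copy_(src),
    // which is slightly cheaper when the checkpoint already uses the module's dtype.
    model.setAutoCastFP16(!checkpoint_dtype_matches);
}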
src/SanaModel.cpp

+#include <iostream>
+
 #include "SanaModel.h"
 #include "kernels/zgemm/zgemm.h"
 #include "flash_api.h"
...
@@ -8,6 +10,7 @@
using spdlog::fmt_lib::format;

using namespace nunchaku;

SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    dim_pad(ceilDiv(dim, 128) * 128),
...
@@ -28,7 +31,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_
Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
    constexpr int HEAD_DIM = 32;

    assert(x.ndims() == 3);
    const int batch_size = x.shape[0];
    const int num_tokens = x.shape[1];
...
@@ -45,7 +48,7 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
        for (int i = 0; i < batch_size; i++) {
            x_pad.slice(0, i, i + 1).slice(1, 0, num_tokens).copy_(x.slice(0, i, i + 1));
        }
        x = x_pad;
    }
...
@@ -55,18 +58,19 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
    Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());

    kernels::gemm_w4a4(
        qact.act, qkv_proj.qweight, {}, {},
        qact.ascales, qkv_proj.wscales, {}, {},
        qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
        vk, q,
        qact.is_unsigned, qkv_proj.lora_scales, false,
        qkv_proj.use_fp4,
        *qkv_proj.wtscale.data_ptr<float>(),
-       qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{}
+       qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
+       {}, {}, {}, 0
    );

    debug("vk", vk);
...
@@ -118,12 +122,12 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
    }

    this->forward(x_org, out_org);

    Tensor v_ptb = this->pag_to_v.value().forward(x_ptb);
    this->out_proj.forward(v_ptb, out_ptb);

    return out;
}

MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    num_heads(num_heads), head_dim(head_dim),
...
@@ -143,7 +147,7 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
    assert(cond.ndims() == 2);
    assert(cu_seqlens_img.ndims() == 1);
    assert(cu_seqlens_txt.ndims() == 1);

    const int batch_size = x.shape[0];
    const int num_tokens_img = x.shape[1];
    const int num_tokens_txt = cond.shape[0];
...
@@ -163,21 +167,21 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
        num_tokens_img, num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false,
        -1, -1,
        false
    ).front().view({batch_size, num_tokens_img, num_heads * head_dim});

    // Tensor attn_output = mha_fwd(q, k, v,
    //     0.0f,
    //     pow(q.shape[-1], (-0.5)),
    //     false, -1, -1, false
    // ).front().view({B, N, num_heads * head_dim});

    return out_proj.forward(attn_output);
}

SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    in_features(in_features), hidden_features(hidden_features),
    inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
    depth_conv(hidden_features * 2, true, dtype, device),
...
@@ -204,7 +208,7 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
    return point_conv.forward_quant(qact);
}

SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
    attn(hidden_size, false, pag, use_fp4, dtype, device),
    cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
...
@@ -240,7 +244,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
    kernels::mul_add_batch(timestep, {}, false, 0, this->scale_shift_table, false);
    debug("shifted_timestep", timestep);

    std::array<Tensor, 6> chunked;
    for (int i = 0; i < 6; i++) {
        chunked[i] = timestep.slice(1, i, i + 1);
...
@@ -299,7 +303,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
            nvtxRangePop();
        }
        nvtxRangePop();

    debug("hidden_states_out", hidden_states);
...
@@ -307,7 +311,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
    return hidden_states;
}

SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) :
    config(config)
{
    const int inner_dim = config.num_attention_heads * config.attention_head_dim;
...
@@ -324,8 +328,8 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
    }
}

-Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
+Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
-   for (int i = 0; i < config.num_layers; i++) {
+   for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
        auto &&block = transformer_blocks[i];
        hidden_states = block->forward(
            hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
...
src/SanaModel.h
...
@@ -89,7 +89,7 @@ struct SanaConfig {
class SanaModel : public Module {
public:
    SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
-   Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
+   Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer);

public:
    const SanaConfig config;
...
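With `skip_first_layer = true` the transformer loop in src/SanaModel.cpp starts at block 1, so the caller is responsible for `hidden_states` already reflecting the first block (for example, reusing it across steps); that reading is an interpretation of the loop change, not documented behavior. A minimal caller sketch using only this declaration:

Tensor denoise_step(SanaModel &model, Tensor hidden_states, Tensor encoder_hidden_states,
                    Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt,
                    int H, int W, bool reuse_first_block) {
    return model.forward(hidden_states, encoder_hidden_states, timestep,
                         cu_seqlens_img, cu_seqlens_txt, H, W,
                         /*pag=*/false, /*cfg=*/false,
                         /*skip_first_layer=*/reuse_first_block);
}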
src/Tensor.h
...
@@ -81,7 +81,8 @@ public:
    BufferCUDA(size_t size) {
        this->size = size;
        this->device.type = Device::CUDA;
-       checkCUDA(cudaGetDevice(&this->device.idx));
+       // checkCUDA(cudaGetDevice(&this->device.idx));
+       this->device.idx = CUDADeviceContext::getDevice();
        if (size == 0) {
            this->ptr = nullptr;
        }
...
@@ -418,6 +419,7 @@ public:
            result.buffer = std::make_shared<BufferMalloc>(shape.size() * scalarSize.at(scalarType));
        } else if (device.type == Device::CUDA) {
            // TODO: cross device allocate
+           CUDADeviceContext ctx(device.idx);
            result.buffer = std::make_shared<BufferCUDA>(shape.size() * scalarSize.at(scalarType));
        } else {
            assert(false);
...
@@ -429,6 +431,7 @@ public:
        if (device.type == Device::CPU) {
            memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
        } else if (device.type == Device::CUDA) {
+           CUDADeviceContext ctx(device.idx);
            checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
        }
    }
...
src/common.h
...
@@ -107,16 +107,97 @@ struct CUDAEventWrapper {
    }
};

+/**
+ * 1. hold one when entered from external code (set `device` to -1 to avoid a device change)
+ * 2. hold one when switching device
+ * 3. hold one with `disableCache` when calling external code that may change the device
+ */
+class CUDADeviceContext {
+public:
+    CUDADeviceContext(int device = -1, bool disableCache = false) : disableCache(disableCache) {
+        if (cacheDisabled()) {
+            // no previous context => we might have entered from external code, reset cache
+            // previous context has the reset flag on => external code may have run, reset
+            currentDeviceCache = -1;
+        }
+        ctxs.push(this);
+        lastDevice = getDevice();
+        if (device >= 0) {
+            setDevice(device);
+        }
+        if (disableCache) {
+            // we are about to call external code, reset cache
+            currentDeviceCache = -1;
+        }
+    }
+    CUDADeviceContext(const CUDADeviceContext &) = delete;
+    CUDADeviceContext(CUDADeviceContext &&) = delete;
+    ~CUDADeviceContext() {
+        if (disableCache) {
+            // returned from external code, cache is not reliable, reset
+            currentDeviceCache = -1;
+        }
+        setDevice(lastDevice);
+        assert(ctxs.top() == this);
+        ctxs.pop();
+        if (cacheDisabled()) {
+            // ctxs.empty() => we are about to return to external code, reset cache
+            // otherwise => we are nested in a context with the reset flag on and external code may still run, reset
+            currentDeviceCache = -1;
+        }
+    }
+
+    const bool disableCache;
+    int lastDevice;
+
+public:
+    static int getDevice() {
+        int idx = -1;
+        if (cacheDisabled() || currentDeviceCache < 0) {
+            checkCUDA(cudaGetDevice(&idx));
+        } else {
+            idx = currentDeviceCache;
+        }
+        currentDeviceCache = cacheDisabled() ? -1 : idx;
+        return idx;
+    }
+
+private:
+    static void setDevice(int idx) {
+        // TODO: deal with the stream when switching device
+        assert(idx >= 0);
+        if (!cacheDisabled() && currentDeviceCache == idx) {
+            return;
+        }
+        checkCUDA(cudaSetDevice(idx));
+        currentDeviceCache = cacheDisabled() ? -1 : idx;
+    }
+
+private:
+    static inline thread_local std::stack<CUDADeviceContext *> ctxs;
+    static inline thread_local int currentDeviceCache = -1;
+
+    static bool cacheDisabled() {
+        return ctxs.empty() || ctxs.top()->disableCache;
+    }
+};

inline cudaDeviceProp *getCurrentDeviceProperties() {
-   static thread_local cudaDeviceProp prop;
-   static thread_local bool propAvailable = false;
-   if (!propAvailable) {
-       int device;
-       checkCUDA(cudaGetDevice(&device));
-       checkCUDA(cudaGetDeviceProperties(&prop, device));
-       propAvailable = true;
+   static thread_local std::map<int, cudaDeviceProp> props;
+   int deviceId = CUDADeviceContext::getDevice();
+   if (!props.contains(deviceId)) {
+       cudaDeviceProp prop;
+       checkCUDA(cudaGetDeviceProperties(&prop, deviceId));
+       props[deviceId] = prop;
    }
-   return &prop;
+   return &props.at(deviceId);
}

template<typename T>
...
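A minimal sketch of the three usage patterns named in the class comment, plus an allocation that relies on the BufferCUDA change in src/Tensor.h above (the buffer now takes its device index from CUDADeviceContext::getDevice()). The device index and shape are illustrative:

void device_context_patterns(int otherDevice) {
    CUDADeviceContext entry;                                    // 1. entered from external code; device = -1, no switch
    {
        CUDADeviceContext switched(otherDevice);                // 2. switch GPUs; the destructor restores the previous device
        Device dev;
        dev.type = Device::CUDA;
        dev.idx  = otherDevice;
        Tensor t = Tensor::allocate({1024, 1024}, Tensor::FP16, dev);  // lands on otherDevice
        (void)t;
    }
    {
        CUDADeviceContext external(-1, /*disableCache=*/true);  // 3. about to call code that may change the device
        // ... external call here ...
    }
}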
src/interop/torch.cpp
...
@@ -22,6 +22,7 @@ Tensor from_torch(at::Tensor input) {
    }
    static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
+       {at::ScalarType::Char, Tensor::INT8},
        {at::ScalarType::Byte, Tensor::INT8},
        {at::ScalarType::Int, Tensor::INT32},
        {at::ScalarType::Long, Tensor::INT64},
...
@@ -36,7 +37,7 @@ Tensor from_torch(at::Tensor input) {
    result.scalarType = mapType.at(input.scalar_type());
    result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
-   // Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
+   Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
    return result;
}
...
src/interop/torch.h
...
@@ -13,9 +13,9 @@ public:
        this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
        this->device.idx = this->tensor.get_device();
    }
    virtual bool isAsyncBuffer() override {
        // TODO: figure out how torch manages memory
-       return true;
+       return this->device.type == Device::CUDA;
    }
private:
    at::Tensor tensor;
...
@@ -30,4 +30,22 @@ public:
};

Tensor from_torch(at::Tensor input);
-at::Tensor to_torch(Tensor input);
-\ No newline at end of file
+at::Tensor to_torch(Tensor input);
+
+class TensorsProviderTorch : public TensorsProvider {
+public:
+    TensorsProviderTorch(std::map<std::string, at::Tensor> dict) : storage(std::move(dict)) {}
+
+    virtual bool contains(const std::string &key) const override {
+        return storage.contains(key);
+    }
+
+    virtual Tensor getTensor(const std::string &key) override {
+        if (!storage.contains(key)) {
+            return Tensor{};
+        }
+        return from_torch(storage.at(key));
+    }
+
+private:
+    std::map<std::string, at::Tensor> storage;
+};
\ No newline at end of file
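TensorsProviderTorch adapts a PyTorch state dict to the TensorsProvider interface so the C++ modules can pull weights by key. A minimal sketch of constructing and querying it, assuming only the class above; the key string is an illustrative placeholder and the eventual consumer of the provider is not shown in this commit:

#include <map>
#include <string>
#include <ATen/ATen.h>
#include "torch.h"   // this header

Tensor fetch_weight(std::map<std::string, at::Tensor> state_dict) {
    TensorsProviderTorch provider(std::move(state_dict));
    const std::string key = "transformer_blocks.0.qkv_proj.qweight";  // illustrative key
    if (!provider.contains(key)) {
        return Tensor{};             // getTensor would also return an empty Tensor here
    }
    return provider.getTensor(key);  // wraps the torch storage via from_torch, no copy
}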