Unverified commit ad8097b9 authored by Muyang Li, committed by GitHub

Release v0.2.0

Ready to release v0.2.0
parents 804a6d30 998192ca
#!/bin/bash
PYTHON_VERSION=$1
TORCH_VERSION=$2
CUDA_VERSION=$3
NUNCHAKU_VERSION=$4
# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
if [ "$TORCH_VERSION" == "2.5" ]; then
TORCHVISION_VERSION="0.20"
TORCHAUDIO_VERSION="2.5"
echo "TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
elif [ "$TORCH_VERSION" == "2.6" ]; then
TORCHVISION_VERSION="0.21"
TORCHAUDIO_VERSION="2.6"
echo "TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
else
echo "TORCH_VERSION is not 2.5 or 2.6. Exit."
exit 2
fi
if [ "$CUDA_VERSION" == "12.8" ]; then
CUDA_IMAGE="12.8.1-devel-ubuntu24.04"
echo "CUDA_VERSION is 12.8, setting CUDA_IMAGE to $CUDA_IMAGE"
elif [ "$CUDA_VERSION" == "12.4" ]; then
CUDA_IMAGE="12.4.1-devel-ubuntu22.04"
echo "CUDA_VERSION is 12.4, setting CUDA_IMAGE to $CUDA_IMAGE"
else
echo "CUDA_VERSION is not 12.8 or 12.4. Exit."
exit 2
fi
docker build --no-cache \
--build-arg PYTHON_VERSION=${PYTHON_VERSION} \
--build-arg CUDA_SHORT_VERSION=${CUDA_VERSION//.} \
--build-arg CUDA_IMAGE=${CUDA_IMAGE} \
--build-arg TORCH_VERSION=${TORCH_VERSION} \
--build-arg TORCHVISION_VERSION=${TORCHVISION_VERSION} \
--build-arg TORCHAUDIO_VERSION=${TORCHAUDIO_VERSION} \
-t nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
@@ -7,6 +7,19 @@ CUDA_VERSION=$3
MAX_JOBS=${4:-} # optional
PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}
# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
if [ "$TORCH_VERSION" == "2.5" ]; then
TORCHVISION_VERSION="0.20"
TORCHAUDIO_VERSION="2.5"
echo "TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
elif [ "$TORCH_VERSION" == "2.6" ]; then
TORCHVISION_VERSION="0.21"
TORCHAUDIO_VERSION="2.6"
echo "TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
else
echo "TORCH_VERSION is not 2.5 or 2.6, no changes to versions."
fi
docker run --rm \
-v "$(pwd)":/nunchaku \
pytorch/manylinux-builder:cuda${CUDA_VERSION} \
@@ -16,7 +29,7 @@ docker run --rm \
yum install -y devtoolset-11 && \
source scl_source enable devtoolset-11 && \
gcc --version && g++ --version && \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} torchaudio==${TORCHAUDIO_VERSION} --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
${PYTHON_ROOT_PATH}/bin/pip install build ninja wheel setuptools && \
export NUNCHAKU_INSTALL_MODE=ALL && \
export NUNCHAKU_BUILD_WHEELS=1 && \
...
#!/bin/bash
# Modified from https://github.com/sgl-project/sglang/blob/main/sgl-kernel/build.sh
set -ex
PYTHON_VERSION=$1
TORCH_VERSION=$2 # unused for now
CUDA_VERSION=$3
MAX_JOBS=${4:-} # optional
PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}
# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
#if [ "$TORCH_VERSION" == "2.5" ]; then
# TORCHVISION_VERSION="0.20"
# TORCHAUDIO_VERSION="2.5"
# echo "TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
#elif [ "$TORCH_VERSION" == "2.6" ]; then
# TORCHVISION_VERSION="0.21"
# TORCHAUDIO_VERSION="2.6"
# echo "TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
#else
# echo "TORCH_VERSION is not 2.5 or 2.6, no changes to versions."
#fi
docker run --rm \
-v "$(pwd)":/nunchaku \
pytorch/manylinux2_28-builder:cuda${CUDA_VERSION} \
bash -c "
cd /nunchaku && \
rm -rf build && \
gcc --version && g++ --version && \
${PYTHON_ROOT_PATH}/bin/pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 && \
${PYTHON_ROOT_PATH}/bin/pip install build ninja wheel setuptools && \
export NUNCHAKU_INSTALL_MODE=ALL && \
export NUNCHAKU_BUILD_WHEELS=1 && \
export MAX_JOBS=${MAX_JOBS} && \
${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
"
\ No newline at end of file
#!/bin/bash
# Modified from https://github.com/sgl-project/sglang/blob/main/sgl-kernel/build.sh
set -ex
PYTHON_VERSION=$1
TORCH_VERSION=$2 # unused for now
CUDA_VERSION=$3
MAX_JOBS=${4:-} # optional
PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}
# Check if TORCH_VERSION is 2.5 or 2.6 and set the corresponding versions for TORCHVISION and TORCHAUDIO
#if [ "$TORCH_VERSION" == "2.5" ]; then
# TORCHVISION_VERSION="0.20"
# TORCHAUDIO_VERSION="2.5"
# echo "TORCH_VERSION is 2.5, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
#elif [ "$TORCH_VERSION" == "2.6" ]; then
# TORCHVISION_VERSION="0.21"
# TORCHAUDIO_VERSION="2.6"
# echo "TORCH_VERSION is 2.6, setting TORCHVISION_VERSION to $TORCHVISION_VERSION and TORCHAUDIO_VERSION to $TORCHAUDIO_VERSION"
#else
# echo "TORCH_VERSION is not 2.5 or 2.6, no changes to versions."
#fi
docker run --rm \
-v "$(pwd)":/nunchaku \
pytorch/manylinux2_28-builder:cuda${CUDA_VERSION} \
bash -c "
cd /nunchaku && \
rm -rf build && \
gcc --version && g++ --version && \
${PYTHON_ROOT_PATH}/bin/pip install --pre torch==2.7.0.dev20250307+cu128 torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 && \
${PYTHON_ROOT_PATH}/bin/pip install build ninja wheel setuptools && \
export NUNCHAKU_INSTALL_MODE=ALL && \
export NUNCHAKU_BUILD_WHEELS=1 && \
export MAX_JOBS=${MAX_JOBS} && \
${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
"
\ No newline at end of file
@echo off
setlocal enabledelayedexpansion
:: get arguments
set PYTHON_VERSION=%1
set TORCH_VERSION=%2
set CUDA_VERSION=%3
set CUDA_SHORT_VERSION=%CUDA_VERSION:.=%
echo %CUDA_SHORT_VERSION%
:: setup some variables
if "%TORCH_VERSION%"=="2.5" (
set TORCHVISION_VERSION=0.20
set TORCHAUDIO_VERSION=2.5
) else if "%TORCH_VERSION%"=="2.6" (
set TORCHVISION_VERSION=0.21
set TORCHAUDIO_VERSION=2.6
) else (
echo TORCH_VERSION is not 2.5 or 2.6, no changes to versions.
)
echo setting TORCHVISION_VERSION to %TORCHVISION_VERSION% and TORCHAUDIO_VERSION to %TORCHAUDIO_VERSION%
:: conda environment name
set ENV_NAME=build_env_%PYTHON_VERSION%_%TORCH_VERSION%
echo Using conda environment: %ENV_NAME%
:: create conda environment
call conda create -y -n %ENV_NAME% python=%PYTHON_VERSION%
call conda activate %ENV_NAME%
:: install dependencies
call pip install ninja setuptools wheel build
call pip install --no-cache-dir torch==%TORCH_VERSION% torchvision==%TORCHVISION_VERSION% torchaudio==%TORCHAUDIO_VERSION% --index-url "https://download.pytorch.org/whl/cu%CUDA_SHORT_VERSION%/"
:: set environment variables
set NUNCHAKU_INSTALL_MODE=ALL
set NUNCHAKU_BUILD_WHEELS=1
:: cd to the parent directory
cd /d "%~dp0.."
if exist build rd /s /q build
:: set up Visual Studio compilation environment
call "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
set DISTUTILS_USE_SDK=1
:: build wheels
python -m build --wheel --no-isolation
:: exit conda
call conda deactivate
call conda remove -y -n %ENV_NAME% --all
echo Build complete!
@echo off
setlocal enabledelayedexpansion
:: get arguments
set PYTHON_VERSION=%1
set TORCH_VERSION=%2
set CUDA_VERSION=%3
set CUDA_SHORT_VERSION=%CUDA_VERSION:.=%
echo %CUDA_SHORT_VERSION%
:: conda environment name
set ENV_NAME=build_env_%PYTHON_VERSION%_%TORCH_VERSION%
echo Using conda environment: %ENV_NAME%
:: create conda environment
call conda create -y -n %ENV_NAME% python=%PYTHON_VERSION%
call conda activate %ENV_NAME%
:: install dependencies
call pip install ninja setuptools wheel build
if "%TORCH_VERSION%"=="2.7" (
call pip install --pre torch==2.7.0.dev20250307+cu128 torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
) else (
call pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
)
:: set environment variables
set NUNCHAKU_INSTALL_MODE=ALL
set NUNCHAKU_BUILD_WHEELS=1
:: cd to the parent directory
cd /d "%~dp0.."
if exist build rd /s /q build
:: set up Visual Studio compilation environment
call "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
set DISTUTILS_USE_SDK=1
:: build wheels
python -m build --wheel --no-isolation
:: exit conda
call conda deactivate
call conda remove -y -n %ENV_NAME% --all
echo Build complete!
#!/bin/bash
set -ex
docker run --rm \
-v "$(pwd)":/nunchaku \
pytorch/manylinux-builder:cuda12.4 \
bash -c "cd /nunchaku && rm -rf *"
\ No newline at end of file
@@ -47,12 +47,12 @@ def get_sm_targets() -> list[str]:
sm = f"{capability[0]}{capability[1]}"
if sm == "120" and support_sm120:
sm = "120a"
assert sm in ["75", "80", "86", "89", "120a"], f"Unsupported SM {sm}"
if sm not in ret:
ret.append(sm)
else:
assert install_mode == "ALL"
ret = ["75", "80", "86", "89"]
if support_sm120:
ret.append("120a")
return ret
@@ -142,6 +142,7 @@ if __name__ == "__main__":
*ncond("src/FluxModel.cpp"),
*ncond("src/SanaModel.cpp"),
"src/Serialization.cpp",
"src/Module.cpp",
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
@@ -158,9 +159,14 @@ if __name__ == "__main__":
"src/kernels/layernorm_kernels.cu",
"src/kernels/misc_kernels.cu",
"src/kernels/zgemm/gemm_w4a4.cu",
"src/kernels/zgemm/gemm_w4a4_test.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu",
"src/kernels/zgemm/gemm_w8a8.cu",
"src/kernels/zgemm/attention.cu",
"src/kernels/dwconv.cu",
"src/kernels/gemm_batched.cu",
"src/kernels/gemm_f16.cu",
...
#include "FluxModel.h" #include "FluxModel.h"
#include "kernels/misc_kernels.h" #include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h" #include "kernels/gemm_batched.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h" #include "flash_api.h"
#include "activation.h" #include "activation.h"
...@@ -39,7 +40,7 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) { ...@@ -39,7 +40,7 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device) : AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device) :
dim(dim), dim(dim),
linear(dim, 3 * dim, true, dtype, device), linear(dim, 3 * dim, true, dtype, device),
norm(dim, 1e-6, false, dtype, device) norm(dim, 1e-6, false, dtype, device)
{ {
registerChildren registerChildren
(linear, "linear") (linear, "linear")
...@@ -58,12 +59,12 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor ...@@ -58,12 +59,12 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
debug("x", x); debug("x", x);
Tensor norm_x = norm.forward(x); Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x); debug("norm_x", norm_x);
kernels::mul_add(norm_x, scale_msa, shift_msa); kernels::mul_add(norm_x, scale_msa, shift_msa);
return Output{norm_x, gate_msa}; return Output{norm_x, gate_msa};
} }
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device) : AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device) :
dim(dim), pre_only(pre_only), dim(dim), pre_only(pre_only),
linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device), linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
norm(dim, 1e-6, false, dtype, device) norm(dim, 1e-6, false, dtype, device)
...@@ -90,7 +91,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) { ...@@ -90,7 +91,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
kernels::mul_add(norm_x, scale_msa, shift_msa); kernels::mul_add(norm_x, scale_msa, shift_msa);
debug("norm_x_scaled", norm_x); debug("norm_x_scaled", norm_x);
return Output{norm_x}; return Output{norm_x};
} else { } else {
auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb); auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
...@@ -107,7 +108,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) { ...@@ -107,7 +108,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
} }
Attention::Attention(int num_heads, int dim_head, Device device) : Attention::Attention(int num_heads, int dim_head, Device device) :
num_heads(num_heads), dim_head(dim_head), force_fp16(false) num_heads(num_heads), dim_head(dim_head), force_fp16(false)
{ {
headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu()); headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
...@@ -117,6 +118,33 @@ Attention::Attention(int num_heads, int dim_head, Device device) : ...@@ -117,6 +118,33 @@ Attention::Attention(int num_heads, int dim_head, Device device) :
headmask_type = headmask_type.copy(device); headmask_type = headmask_type.copy(device);
} }
Tensor Attention::forward(Tensor qkv) {
assert(qkv.ndims() == 3);
const Device device = qkv.device();
const int batch_size = qkv.shape[0];
const int num_tokens = qkv.shape[1];
assert(qkv.shape[2] == num_heads * dim_head * 3);
Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
Tensor q = reshaped.slice(2, 0, num_heads);
Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);
Tensor raw_attn_output = mha_fwd(q, k, v,
0.0f,
pow(q.shape[-1], (-0.5)),
false, -1, -1, false
).front();
assert(raw_attn_output.shape[0] == batch_size);
assert(raw_attn_output.shape[1] == num_tokens);
assert(raw_attn_output.shape[2] == num_heads);
assert(raw_attn_output.shape[3] == dim_head);
return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;
@@ -150,7 +178,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
gemm_batched_fp16(pool_q, pool_k, pool_s);
}
}
blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));
if (cu_seqlens_cpu.valid()) {
@@ -226,16 +254,16 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
false
).front();
Tensor raw_attn_output = mha_fwd(q, k, v,
0.0f,
pow(q.shape[-1], (-0.5)),
false, -1, -1, false
).front();
Tensor raw_attn_output = mha_varlen_fwd(
q, k, v,
cu_seqlens, cu_seqlens,
num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
0.0f,
pow(q.shape[-1], (-0.5)),
false, false, -1, -1, false
@@ -260,7 +288,7 @@ void Attention::setForceFP16(Module *module, bool value) {
}
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads),
mlp_hidden_dim(dim * mlp_ratio),
@@ -298,19 +326,50 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
Tensor residual = hidden_states;
Tensor attn_output;
debug("rotary_emb", rotary_emb);
if (attnImpl == AttentionImpl::FlashAttention2) {
Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
// qkv_proj.forward(norm_hidden_states, qkv, {});
// debug("qkv_raw", qkv);
qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
// attn_output = attn.forward(qkv, {}, 0);
attn_output = attn.forward(qkv);
attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
assert(batch_size == 1);
const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
Tensor q = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
qkv_proj.forward(norm_hidden_states, {}, {}, norm_q.weight, norm_k.weight, rotary_emb, q, k, v, num_tokens);
debug("packed_q", q);
debug("packed_k", k);
debug("packed_v", v);
Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));
attn_output = o.slice(1, 0, num_tokens);
} else {
assert(false);
}
debug("raw_attn_output", attn_output); debug("raw_attn_output", attn_output);
attn_output = forward_fc(out_proj, attn_output); attn_output = forward_fc(out_proj, attn_output);
debug("attn_output", attn_output); debug("attn_output", attn_output);
...@@ -319,7 +378,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te ...@@ -319,7 +378,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
hidden_states = kernels::add(attn_output, ff_output); hidden_states = kernels::add(attn_output, ff_output);
debug("attn_ff_output", hidden_states); debug("attn_ff_output", hidden_states);
kernels::mul_add(hidden_states, gate, residual); kernels::mul_add(hidden_states, gate, residual);
nvtxRangePop(); nvtxRangePop();
...@@ -327,7 +386,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te ...@@ -327,7 +386,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
return hidden_states; return hidden_states;
} }
JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) : JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim), dim(dim),
dim_head(attention_head_dim / num_attention_heads), dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads), num_heads(num_attention_heads),
...@@ -384,13 +443,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, ...@@ -384,13 +443,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
int num_tokens_img = hidden_states.shape[1]; int num_tokens_img = hidden_states.shape[1];
int num_tokens_context = encoder_hidden_states.shape[1]; int num_tokens_txt = encoder_hidden_states.shape[1];
assert(hidden_states.shape[2] == dim); assert(hidden_states.shape[2] == dim);
assert(encoder_hidden_states.shape[2] == dim); assert(encoder_hidden_states.shape[2] == dim);
spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}", hidden_states.shape.str(), encoder_hidden_states.shape.str(), temb.shape.str()); spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}", hidden_states.shape.str(), encoder_hidden_states.shape.str(), temb.shape.str());
spdlog::debug("batch_size={} num_tokens_img={} num_tokens_context={}", batch_size, num_tokens_img, num_tokens_context); spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);
auto norm1_output = norm1.forward(hidden_states, temb); auto norm1_output = norm1.forward(hidden_states, temb);
auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb); auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);
...@@ -408,76 +467,141 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, ...@@ -408,76 +467,141 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
nvtxRangePop(); nvtxRangePop();
auto stream = getCurrentCUDAStream(); auto stream = getCurrentCUDAStream();
int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
Tensor raw_attn_output;
if (attnImpl == AttentionImpl::FlashAttention2) {
num_tokens_img_pad = num_tokens_img;
num_tokens_txt_pad = num_tokens_txt;
Tensor concat;
Tensor pool;
{
nvtxRangePushA("qkv_proj");
const bool blockSparse = sparsityRatio > 0;
const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
pool = blockSparse
? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
: Tensor{};
for (int i = 0; i < batch_size; i++) {
// img first
Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
Tensor pool_qkv = pool.valid()
? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
: Tensor{};
Tensor pool_qkv_context = pool.valid()
? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
: Tensor{};
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
// debug("qkv_raw", qkv);
debug("rotary_emb", rotary_emb);
qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
// debug("qkv_context_raw", qkv_context);
debug("rotary_emb_context", rotary_emb_context);
qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
debug("qkv_context", qkv_context);
}
nvtxRangePop();
}
spdlog::debug("concat={}", concat.shape.str());
debug("concat", concat);
assert(concat.shape[2] == num_heads * dim_head * 3);
nvtxRangePushA("Attention");
if (pool.valid()) {
raw_attn_output = attn.forward(concat, pool, sparsityRatio);
} else {
raw_attn_output = attn.forward(concat);
}
nvtxRangePop();
spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
Tensor concat_q, concat_k, concat_v;
{
nvtxRangePushA("qkv_proj");
concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
concat_k = Tensor::empty_like(concat_q);
concat_v = Tensor::empty_like(concat_q);
for (int i = 0; i < batch_size; i++) {
// img first
auto sliceImg = [&](Tensor x) {
return x.slice(0, i, i+1).slice(2, 0, num_tokens_img_pad);
};
auto sliceTxt = [&](Tensor x) {
return x.slice(0, i, i+1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
};
qkv_proj.forward(
norm1_output.x.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb,
sliceImg(concat_q), sliceImg(concat_k), sliceImg(concat_v), num_tokens_img
);
qkv_proj_context.forward(
norm1_context_output.x.slice(0, i, i + 1), {}, {}, norm_added_q.weight, norm_added_k.weight, rotary_emb_context,
sliceTxt(concat_q), sliceTxt(concat_k), sliceTxt(concat_v), num_tokens_txt
);
}
debug("concat_q", concat_q);
debug("concat_k", concat_k);
debug("concat_v", concat_v);
nvtxRangePop();
}
raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head}, norm1_output.x.scalar_type(), norm1_output.x.device());
nvtxRangePushA("Attention");
kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));
nvtxRangePop();
raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
} else {
assert(false);
}
debug("raw_attn_output", raw_attn_output);
{
nvtxRangePushA("o_proj");
auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;
// raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
Tensor raw_attn_output_split;
if (batch_size == 1) {
@@ -485,16 +609,16 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
} else {
raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
checkCUDA(cudaMemcpy2DAsync(
raw_attn_output_split.data_ptr(),
num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
raw_attn_output.data_ptr(),
(num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
stream));
}
spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
debug("img.raw_attn_output_split", raw_attn_output_split);
@@ -546,20 +670,20 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor raw_attn_output_split;
if (batch_size == 1) {
raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt).reshape({batch_size, num_tokens_txt, num_heads * dim_head});
} else {
raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
checkCUDA(cudaMemcpy2DAsync(
raw_attn_output_split.data_ptr(),
num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head * raw_attn_output_split.scalar_size(),
(num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
stream));
}
spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
debug("context.raw_attn_output_split", raw_attn_output_split);
@@ -585,7 +709,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
#else
auto norm_hidden_states = encoder_hidden_states;
#endif
// Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
// Tensor ff_output = mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
@@ -607,7 +731,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
return { hidden_states, encoder_hidden_states };
}
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : dtype(dtype), offload(offload) {
for (int i = 0; i < 19; i++) {
transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
@@ -626,7 +750,16 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
}
}
Tensor FluxModel::forward(
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor rotary_emb_single,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples,
bool skip_first_layer) {
const int batch_size = hidden_states.shape[0];
const Tensor::ScalarType dtype = hidden_states.dtype();
const Device device = hidden_states.device();
@@ -639,9 +772,20 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
Tensor concat;
auto compute = [&](int layer) {
if (skip_first_layer && size_t(layer) == 0) return;
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
if (controlnet_block_samples.valid()) {
const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
int block_index = layer / interval_control;
// Xlabs ControlNet
// block_index = layer % num_controlnet_block_samples;
hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
}
} else {
if (size_t(layer) == transformer_blocks.size()) {
// txt first, same as diffusers
@@ -652,10 +796,23 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
}
hidden_states = concat;
encoder_hidden_states = {};
}
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
if (controlnet_single_block_samples.valid()) {
const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
int block_index = (layer - transformer_blocks.size()) / interval_control;
// Xlabs ControlNet
// block_index = layer % num_controlnet_single_block_samples
auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
}
}
};
auto load = [&](int layer) {
@@ -681,4 +838,58 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
helper.run();
return hidden_states;
}
std::tuple<Tensor, Tensor> FluxModel::forward_layer(
size_t layer,
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples) {
std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_context, 0.0f);
const int txt_tokens = encoder_hidden_states.shape[1];
const int img_tokens = hidden_states.shape[1];
if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
int block_index = layer / interval_control;
// Xlabs ControlNet
// block_index = layer % num_controlnet_block_samples;
hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
} else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
int block_index = (layer - transformer_blocks.size()) / interval_control;
// Xlabs ControlNet
// block_index = layer % num_controlnet_single_block_samples
auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
}
return { hidden_states, encoder_hidden_states };
}
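// Worked example of the ControlNet residual indexing used above (an illustrative
// sketch, not part of this commit): with the 19 joint transformer blocks and, say,
// 4 entries in controlnet_block_samples, interval_control = ceilDiv(19, 4) = 5, so
// layers 0-4 reuse sample 0, layers 5-9 sample 1, layers 10-14 sample 2, and
// layers 15-18 sample 3. The helper names below are hypothetical.
namespace controlnet_indexing_sketch {
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
// Mirrors interval_control / block_index computed in forward() and forward_layer().
constexpr int sample_index(int layer, int num_layers, int num_samples) {
    return layer / ceil_div(num_layers, num_samples);
}
static_assert(sample_index(4, 19, 4) == 0, "layers 0-4 map to sample 0");
static_assert(sample_index(5, 19, 4) == 1, "layer 5 maps to sample 1");
static_assert(sample_index(18, 19, 4) == 3, "the last layer maps to sample 3");
}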
void FluxModel::setAttentionImpl(AttentionImpl impl) {
for (auto &&block : this->transformer_blocks) {
block->attnImpl = impl;
}
for (auto &&block : this->single_transformer_blocks) {
block->attnImpl = impl;
}
}
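// Usage sketch for the new setAttentionImpl() (illustration only; the helper name
// is hypothetical and not part of this diff). It switches every joint and single
// transformer block between the FlashAttention-2 path and the fused NunchakuFP16
// attention kernel declared in FluxModel.h below.
inline void selectAttentionBackend(FluxModel &model, bool useFP16Kernel) {
    model.setAttentionImpl(useFP16Kernel ? AttentionImpl::NunchakuFP16
                                         : AttentionImpl::FlashAttention2);
}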
@@ -6,6 +6,11 @@
#include "Linear.h"
#include "layernorm.h"
enum class AttentionImpl {
FlashAttention2 = 0,
NunchakuFP16,
};
class AdaLayerNormZeroSingle : public Module {
public:
static constexpr bool USE_4BIT = true;
@@ -56,8 +61,9 @@ private:
class Attention : public Module {
public:
static constexpr int POOL_SIZE = 128;
Attention(int num_heads, int dim_head, Device device);
Tensor forward(Tensor qkv);
Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);
static void setForceFP16(Module *module, bool value);
@@ -86,6 +92,8 @@ public:
const int num_heads;
const int mlp_hidden_dim;
AttentionImpl attnImpl = AttentionImpl::FlashAttention2;
private:
AdaLayerNormZeroSingle norm;
GEMM mlp_fc1;
@@ -110,6 +118,8 @@ public:
const int num_heads;
const bool context_pre_only;
AttentionImpl attnImpl = AttentionImpl::FlashAttention2;
private:
AdaLayerNormZero norm1;
AdaLayerNormZero norm1_context;
@@ -129,9 +139,30 @@ private:
class FluxModel : public Module {
public:
FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
Tensor forward(
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor rotary_emb_single,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples,
bool skip_first_layer = false);
std::tuple<Tensor, Tensor> forward_layer(
size_t layer,
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples);
void setAttentionImpl(AttentionImpl impl);
public:
const Tensor::ScalarType dtype;
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
...
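// Call sketch for the extended forward() declared above (illustration only; the
// wrapper name is hypothetical). Passing default-constructed Tensors for the two
// controlnet arguments leaves them invalid, so the .valid() checks inside
// FluxModel::forward skip the ControlNet residual additions entirely.
inline Tensor forwardWithoutControlnet(FluxModel &model,
                                       Tensor hidden_states, Tensor encoder_hidden_states,
                                       Tensor temb, Tensor rotary_emb_img,
                                       Tensor rotary_emb_context, Tensor rotary_emb_single) {
    return model.forward(hidden_states, encoder_hidden_states, temb,
                         rotary_emb_img, rotary_emb_context, rotary_emb_single,
                         /*controlnet_block_samples=*/Tensor{},
                         /*controlnet_single_block_samples=*/Tensor{},
                         /*skip_first_layer=*/false);
}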
@@ -52,13 +52,14 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
Module::loadParam(key, dst, src);
if (key == "lora_down") {
const int new_rank = dst.shape[0];
this->lora_rank = new_rank;
}
} else {
Module::loadParam(key, dst, src);
}
} else {
Module::loadParam(key, dst, src);
@@ -143,16 +144,18 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
Module::loadParam(key, dst, src);
this->lora_rank = dst.shape[1];
this->lora_scales.resize(ceilDiv(this->lora_rank, 16), 1.0f);
} else {
Module::loadParam(key, dst, src);
}
} else if (key == "wcscales") {
assert(src.ndims() == 1);
assert(src.shape[0] == out_features_pad);
dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
Module::loadParam(key, dst, src);
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
@@ -160,7 +163,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else if (src.dtype() == Tensor::FP16) {
*dst.data_ptr<float>() = float(*src.data_ptr<half>());
} else if (src.dtype() == Tensor::FP32) {
Module::loadParam(key, dst, src);
} else {
assert(false);
}
@@ -181,7 +184,7 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x
return forward_quant(quantize(x, false), fuse, nextGEMM);
}
void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
QuantizedActivation qact = quantize(x, false);
#if !NO_LORA_FUSION
@@ -196,7 +199,8 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
kernels::gemm_w4a4(
qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{},
out_q, out_k, out_v, numTokens
);
debug("gemm.out", out);
@@ -277,7 +281,8 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
kernels::gemm_w4a4(
qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{},
{}, {}, {}, 0
);
if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
@@ -446,9 +451,9 @@ GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
}
Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
auto shape = TensorShape(qact.act.shape.dataExtent);
shape[-1] = out_features;
Tensor out = Tensor::allocate(shape, this->dtype, qact.act.device());
kernels::gemm_w8a8(qact.act, this->qweight, out, qact.ascales, this->wscales, this->bias);
debug("gemm.out", out);
...
@@ -69,7 +69,11 @@ public:
Tensor forward(Tensor x);
Tensor forward_silu(Tensor x);
std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
void forward(
Tensor x, Tensor out,
Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {},
Tensor out_q = {}, Tensor out_k = {}, Tensor out_v = {}, int numTokens = 0
);
std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
Tensor forward_quant(QuantizedActivation qact);
...
#include "common.h"
#include "Module.h"
#include "kernels/misc_kernels.h"
void Module::copyWithCast(Tensor dst, Tensor src) {
assert(dst.is_contiguous());
assert(dst.device().type == Device::CUDA);
if (src.device().type == Device::CUDA && src.device().idx == dst.device().idx) {
// src already lives on the same GPU: cast element-wise straight into dst.
nunchaku::kernels::cast(src, dst);
} else {
// src is on the host or another GPU: reuse dst's buffer as a staging area,
// reinterpreted with src's dtype, copy the raw data in, then cast in place.
Tensor tmp;
tmp.buffer = dst.buffer;
tmp.shape = dst.shape;
tmp.scalarType = src.scalarType;
tmp.copy_(src);
nunchaku::kernels::cast(tmp, dst);
}
}
@@ -131,10 +131,23 @@ public:
m->enabledLazyLoad = val;
});
}
void setAutoCastFP16(bool val) {
traverse([val](Module *m) {
m->enabledAutoCastFP16 = val;
});
}
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) {
static const std::set<Tensor::ScalarType> whitelist = {
Tensor::FP16,
Tensor::BF16,
};
if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) && whitelist.contains(src.scalar_type())) {
copyWithCast(dst, src);
} else {
dst.copy_(src);
}
}
struct ChildrenRegisterHelper {
@@ -174,7 +187,7 @@ protected:
}
void debug(std::string name, Tensor tensor) {
if (DebugContext::ctxs.empty() || !tensor.valid()) {
return;
}
std::string prefix = getFullName();
@@ -187,6 +200,9 @@ protected:
}
}
private:
void copyWithCast(Tensor dst, Tensor src);
public:
Module *parent = nullptr;
std::string name = "";
@@ -194,6 +210,7 @@ public:
std::map<std::string, Param> params;
bool enabledLazyLoad = false;
bool enabledAutoCastFP16 = true;
};
struct LayerOffloadHelper {
...
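// Usage sketch for the new auto-cast behaviour (illustration only; the function
// name is hypothetical). With enabledAutoCastFP16 left at its default of true,
// loadParam() transparently casts between FP16 and BF16 weights via copyWithCast;
// disabling it restores the strict dtype-matching copy.
inline void disableHalfPrecisionAutoCast(Module &model) {
    model.setAutoCastFP16(false);  // loadParam() will then require matching dtypes
}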
#include <iostream>
#include "SanaModel.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
@@ -8,6 +10,7 @@
using spdlog::fmt_lib::format;
using namespace nunchaku;
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_pad(ceilDiv(dim, 128) * 128),
@@ -28,7 +31,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_
Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
constexpr int HEAD_DIM = 32;
assert(x.ndims() == 3);
const int batch_size = x.shape[0];
const int num_tokens = x.shape[1];
@@ -45,7 +48,7 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
for (int i = 0; i < batch_size; i++) {
x_pad.slice(0, i, i + 1).slice(1, 0, num_tokens).copy_(x.slice(0, i, i + 1));
}
x = x_pad;
}
@@ -55,18 +58,19 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());
kernels::gemm_w4a4(
qact.act,
qkv_proj.qweight,
{},
{},
qact.ascales,
qkv_proj.wscales,
{}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
vk, q,
qact.is_unsigned, qkv_proj.lora_scales, false,
qkv_proj.use_fp4,
*qkv_proj.wtscale.data_ptr<float>(),
qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
{}, {}, {}, 0
);
debug("vk", vk);
@@ -118,12 +122,12 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
}
this->forward(x_org, out_org);
Tensor v_ptb = this->pag_to_v.value().forward(x_ptb);
this->out_proj.forward(v_ptb, out_ptb);
return out;
}
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
num_heads(num_heads), head_dim(head_dim),
@@ -143,7 +147,7 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
assert(cond.ndims() == 2);
assert(cu_seqlens_img.ndims() == 1);
assert(cu_seqlens_txt.ndims() == 1);
const int batch_size = x.shape[0];
const int num_tokens_img = x.shape[1];
const int num_tokens_txt = cond.shape[0];
@@ -163,21 +167,21 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
num_tokens_img, num_tokens_txt,
0.0f,
pow(q.shape[-1], (-0.5)),
false, false,
-1, -1,
false
).front().view({batch_size, num_tokens_img, num_heads * head_dim});
// Tensor attn_output = mha_fwd(q, k, v, // Tensor attn_output = mha_fwd(q, k, v,
// 0.0f, // 0.0f,
// pow(q.shape[-1], (-0.5)), // pow(q.shape[-1], (-0.5)),
// false, -1, -1, false // false, -1, -1, false
// ).front().view({B, N, num_heads * head_dim}); // ).front().view({B, N, num_heads * head_dim});
return out_proj.forward(attn_output); return out_proj.forward(attn_output);
} }
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) : SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
in_features(in_features), hidden_features(hidden_features), in_features(in_features), hidden_features(hidden_features),
inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device), inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
depth_conv(hidden_features * 2, true, dtype, device), depth_conv(hidden_features * 2, true, dtype, device),
...@@ -204,7 +208,7 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) { ...@@ -204,7 +208,7 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
return point_conv.forward_quant(qact); return point_conv.forward_quant(qact);
} }
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) : SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads), hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
attn(hidden_size, false, pag, use_fp4, dtype, device), attn(hidden_size, false, pag, use_fp4, dtype, device),
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device), cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
...@@ -240,7 +244,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_ ...@@ -240,7 +244,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
kernels::mul_add_batch(timestep, {}, false, 0, this->scale_shift_table, false); kernels::mul_add_batch(timestep, {}, false, 0, this->scale_shift_table, false);
debug("shifted_timestep", timestep); debug("shifted_timestep", timestep);
std::array<Tensor, 6> chunked; std::array<Tensor, 6> chunked;
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
chunked[i] = timestep.slice(1, i, i + 1); chunked[i] = timestep.slice(1, i, i + 1);
...@@ -299,7 +303,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_ ...@@ -299,7 +303,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
nvtxRangePop(); nvtxRangePop();
} }
nvtxRangePop(); nvtxRangePop();
debug("hidden_states_out", hidden_states); debug("hidden_states_out", hidden_states);
...@@ -307,7 +311,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_ ...@@ -307,7 +311,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
return hidden_states; return hidden_states;
} }
SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) : SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) :
config(config) config(config)
{ {
const int inner_dim = config.num_attention_heads * config.attention_head_dim; const int inner_dim = config.num_attention_heads * config.attention_head_dim;
...@@ -324,8 +328,8 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) ...@@ -324,8 +328,8 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
} }
} }
Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) { Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
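// when skip_first_layer is set, block 0 is skipped and iteration starts at block 1
// (e.g. so a caller can reuse a cached first-block output; the motivation is not stated in this diff)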
for (int i = 0; i < config.num_layers; i++) { for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
auto &&block = transformer_blocks[i]; auto &&block = transformer_blocks[i];
hidden_states = block->forward( hidden_states = block->forward(
hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W, hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
......
...@@ -89,7 +89,7 @@ struct SanaConfig { ...@@ -89,7 +89,7 @@ struct SanaConfig {
class SanaModel : public Module { class SanaModel : public Module {
public: public:
SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device); SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg); Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer);
public: public:
const SanaConfig config; const SanaConfig config;
......
...@@ -81,7 +81,8 @@ public: ...@@ -81,7 +81,8 @@ public:
BufferCUDA(size_t size) { BufferCUDA(size_t size) {
this->size = size; this->size = size;
this->device.type = Device::CUDA; this->device.type = Device::CUDA;
checkCUDA(cudaGetDevice(&this->device.idx)); // checkCUDA(cudaGetDevice(&this->device.idx));
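// read the current device through CUDADeviceContext so its thread-local cache is used
// instead of issuing a cudaGetDevice call for every allocation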
this->device.idx = CUDADeviceContext::getDevice();
if (size == 0) { if (size == 0) {
this->ptr = nullptr; this->ptr = nullptr;
} }
...@@ -418,6 +419,7 @@ public: ...@@ -418,6 +419,7 @@ public:
result.buffer = std::make_shared<BufferMalloc>(shape.size() * scalarSize.at(scalarType)); result.buffer = std::make_shared<BufferMalloc>(shape.size() * scalarSize.at(scalarType));
} else if (device.type == Device::CUDA) { } else if (device.type == Device::CUDA) {
// TODO: cross device allocate // TODO: cross device allocate
CUDADeviceContext ctx(device.idx);
result.buffer = std::make_shared<BufferCUDA>(shape.size() * scalarSize.at(scalarType)); result.buffer = std::make_shared<BufferCUDA>(shape.size() * scalarSize.at(scalarType));
} else { } else {
assert(false); assert(false);
...@@ -429,6 +431,7 @@ public: ...@@ -429,6 +431,7 @@ public:
if (device.type == Device::CPU) { if (device.type == Device::CPU) {
memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize()); memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
} else if (device.type == Device::CUDA) { } else if (device.type == Device::CUDA) {
CUDADeviceContext ctx(device.idx);
checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream())); checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
} }
} }
......
...@@ -107,16 +107,97 @@ struct CUDAEventWrapper { ...@@ -107,16 +107,97 @@ struct CUDAEventWrapper {
} }
}; };
/**
* 1. hold one when entering from external code (set `device` to -1 to avoid a device change)
* 2. hold one when switching device
* 3. hold one with `disableCache` when calling external code that may change the device
*/
class CUDADeviceContext {
public:
CUDADeviceContext(int device = -1, bool disableCache = false) : disableCache(disableCache) {
if (cacheDisabled()) {
// no previous context => we may have entered from external code, reset the cache
// previous context has disableCache set => external code may have run, reset the cache
currentDeviceCache = -1;
}
ctxs.push(this);
lastDevice = getDevice();
if (device >= 0) {
setDevice(device);
}
if (disableCache) {
// we are about to call external code, reset cache
currentDeviceCache = -1;
}
}
CUDADeviceContext(const CUDADeviceContext &) = delete;
CUDADeviceContext(CUDADeviceContext &&) = delete;
~CUDADeviceContext() {
if (disableCache) {
// returned from external code, cache is not reliable, reset
currentDeviceCache = -1;
}
setDevice(lastDevice);
assert(ctxs.top() == this);
ctxs.pop();
if (cacheDisabled()) {
// ctxs.empty() => we are about to return to external code, reset cache
// otherwise => we are nested inside a context that has disableCache set; external code may continue to run, reset
currentDeviceCache = -1;
}
}
const bool disableCache;
int lastDevice;
public:
static int getDevice() {
int idx = -1;
if (cacheDisabled() || currentDeviceCache < 0) {
checkCUDA(cudaGetDevice(&idx));
} else {
idx = currentDeviceCache;
}
currentDeviceCache = cacheDisabled() ? -1 : idx;
return idx;
}
private:
static void setDevice(int idx) {
// TODO: deal with stream when switching device
assert(idx >= 0);
if (!cacheDisabled() && currentDeviceCache == idx) {
return;
}
checkCUDA(cudaSetDevice(idx));
currentDeviceCache = cacheDisabled() ? -1 : idx;
}
private:
static inline thread_local std::stack<CUDADeviceContext *> ctxs;
static inline thread_local int currentDeviceCache = -1;
static bool cacheDisabled() {
return ctxs.empty() || ctxs.top()->disableCache;
}
};
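A minimal sketch of the three usage patterns listed in the class comment; the surrounding function names and the externalLibraryCall placeholder are illustrative, not part of this diff:

// 1. entry point invoked from external (e.g. PyTorch) code: keep the caller's device,
//    but make sure the first getDevice() re-queries CUDA instead of trusting a stale cache
void entryFromExternal() {
    CUDADeviceContext guard;                       // device = -1: no cudaSetDevice
    // ... launch kernels on whatever device the caller selected ...
}

// 2. temporarily switch to a specific device; the destructor restores the previous one
void runOnDevice(int deviceIdx) {
    CUDADeviceContext guard(deviceIdx);
    // ... allocations and kernel launches now target deviceIdx ...
}

// 3. about to call back into external code that may itself call cudaSetDevice
void callExternal() {
    CUDADeviceContext guard(-1, /*disableCache=*/true);
    // externalLibraryCall();                      // hypothetical external call
}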
inline cudaDeviceProp *getCurrentDeviceProperties() { inline cudaDeviceProp *getCurrentDeviceProperties() {
static thread_local cudaDeviceProp prop; static thread_local std::map<int, cudaDeviceProp> props;
static thread_local bool propAvailable = false;
if (!propAvailable) { int deviceId = CUDADeviceContext::getDevice();
int device; if (!props.contains(deviceId)) {
checkCUDA(cudaGetDevice(&device)); cudaDeviceProp prop;
checkCUDA(cudaGetDeviceProperties(&prop, device)); checkCUDA(cudaGetDeviceProperties(&prop, deviceId));
propAvailable = true; props[deviceId] = prop;
} }
return &prop; return &props.at(deviceId);
} }
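Properties are now cached per device (still thread-local), so the helper stays correct when CUDADeviceContext switches devices between calls. An illustrative query, not taken from this diff:

int smCount = getCurrentDeviceProperties()->multiProcessorCount;   // of the currently selected device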
template<typename T> template<typename T>
......
...@@ -22,6 +22,7 @@ Tensor from_torch(at::Tensor input) { ...@@ -22,6 +22,7 @@ Tensor from_torch(at::Tensor input) {
} }
static const std::map<at::ScalarType, Tensor::ScalarType> mapType = { static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
{ at::ScalarType::Char, Tensor::INT8 },
{ at::ScalarType::Byte, Tensor::INT8 }, { at::ScalarType::Byte, Tensor::INT8 },
{ at::ScalarType::Int, Tensor::INT32 }, { at::ScalarType::Int, Tensor::INT32 },
{ at::ScalarType::Long, Tensor::INT64 }, { at::ScalarType::Long, Tensor::INT64 },
...@@ -36,7 +37,7 @@ Tensor from_torch(at::Tensor input) { ...@@ -36,7 +37,7 @@ Tensor from_torch(at::Tensor input) {
result.scalarType = mapType.at(input.scalar_type()); result.scalarType = mapType.at(input.scalar_type());
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input)); result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
// Tensor::lockBuffer(result.buffer, getCurrentCUDAStream()); Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
return result; return result;
} }
......
...@@ -13,9 +13,9 @@ public: ...@@ -13,9 +13,9 @@ public:
this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU; this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
this->device.idx = this->tensor.get_device(); this->device.idx = this->tensor.get_device();
} }
virtual bool isAsyncBuffer() override { virtual bool isAsyncBuffer() override {
// TODO: figure out how torch manages memory // TODO: figure out how torch manages memory
return true; return this->device.type == Device::CUDA;
} }
private: private:
at::Tensor tensor; at::Tensor tensor;
...@@ -30,4 +30,22 @@ public: ...@@ -30,4 +30,22 @@ public:
}; };
Tensor from_torch(at::Tensor input); Tensor from_torch(at::Tensor input);
at::Tensor to_torch(Tensor input); at::Tensor to_torch(Tensor input);
\ No newline at end of file
class TensorsProviderTorch : public TensorsProvider {
public:
TensorsProviderTorch(std::map<std::string, at::Tensor> dict) : storage(std::move(dict)) {}
virtual bool contains(const std::string &key) const override {
return storage.contains(key);
}
virtual Tensor getTensor(const std::string &key) override {
if (!storage.contains(key)) {
return Tensor{};
}
return from_torch(storage.at(key));
}
private:
std::map<std::string, at::Tensor> storage;
};
\ No newline at end of file
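A usage sketch for the new provider, assuming the map is handed over by the Python bindings; the key names and shapes are illustrative:

std::map<std::string, at::Tensor> dict;
dict["linear.weight"] = at::zeros({128, 128}, at::kHalf);
TensorsProviderTorch provider(std::move(dict));
Tensor w = provider.getTensor("linear.weight");   // zero-copy wrapper created by from_torch
bool hasBias = provider.contains("linear.bias");  // false; getTensor would return an empty Tensor{}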