Unverified Commit 37a27712 authored by Muyang Li, committed by GitHub

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter
parents c1d6fc84 760ab022
@@ -39,4 +39,4 @@ docker build -f docker/Dockerfile --no-cache \
    -t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
@@ -27,4 +27,4 @@ docker build -f docker/Dockerfile.torch27 --no-cache \
    -t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
@@ -39,4 +39,4 @@ docker build -f docker/Dockerfile.torch28 --no-cache \
    -t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
@@ -35,4 +35,4 @@ docker run --rm \
        export NUNCHAKU_BUILD_WHEELS=1 && \
        export MAX_JOBS=${MAX_JOBS} && \
        ${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
    "
\ No newline at end of file
@@ -33,4 +33,4 @@ docker run --rm \
        export NUNCHAKU_BUILD_WHEELS=1 && \
        export MAX_JOBS=${MAX_JOBS} && \
        ${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
    "
\ No newline at end of file
@@ -33,4 +33,4 @@ docker run --rm \
        export NUNCHAKU_BUILD_WHEELS=1 && \
        export MAX_JOBS=${MAX_JOBS} && \
        ${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
    "
\ No newline at end of file
@@ -4,4 +4,4 @@ set -ex
docker run --rm \
    -v "$(pwd)":/nunchaku \
    pytorch/manylinux-builder:cuda12.4 \
    bash -c "cd /nunchaku && rm -rf *"
\ No newline at end of file
@@ -6,7 +6,7 @@ import sys
import setuptools
import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension


class CustomBuildExtension(BuildExtension):
...
@@ -4,20 +4,18 @@
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"

#include <nvtx3/nvToolsExt.h>
#include <pybind11/functional.h>

#include <iostream>

using spdlog::fmt_lib::format;
using namespace nunchaku;

Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
    Tensor ff_output = fc2.forward_quant(std::get<GEMM_W4A4::QuantizedActivation>(
        fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)));
    return ff_output;
}
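For readers skimming the diff: forward_mlp runs the whole MLP with a fused epilogue, i.e. fc1's GELU activation and the re-quantization for fc2 happen inside the first GEMM kernel. A sketch of the unfused equivalent, purely for illustration (it assumes the single-argument GEMM_W4A4::forward overload used by forward_fc below and the GELU::forward helper referenced in comments later in this file):

// Illustrative only: what forward_mlp computes, without the fused
// GELU_QUANT epilogue. Not part of this commit.
Tensor forward_mlp_unfused(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
    Tensor h = fc1.forward(norm_hidden_states); // first GEMM, no fused epilogue
    h = GELU::forward(h);                       // activation as a separate kernel
    return fc2.forward(h);                      // second GEMM quantizes its input itself
}

The fused path avoids materializing the full-precision activation and quantizing it in separate passes.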
@@ -26,7 +24,6 @@ Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
//     return ff_output;
// }

Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
    return fc.forward(x);
    // return std::get<Tensor>(fc.forward(x));
@@ -36,16 +33,9 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
//     return fc.forward(x);
// }

AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device)
    : dim(dim), linear(dim, 3 * dim, true, dtype, device), norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
@@ -65,15 +55,10 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
    return Output{norm_x, gate_msa};
}
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device)
    : dim(dim), pre_only(pre_only), linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
      norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
@@ -110,10 +95,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
    }
}
Attention::Attention(int num_heads, int dim_head, Device device)
    : num_heads(num_heads), dim_head(dim_head), force_fp16(false) {
    headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
    for (int i = 0; i < num_heads; i++) {
        headmask_type.data_ptr<int32_t>()[i] = i + 1;
@@ -124,27 +107,23 @@ Attention::Attention(int num_heads, int dim_head, Device device)

Tensor Attention::forward(Tensor qkv) {
    assert(qkv.ndims() == 3);

    const Device device = qkv.device();

    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
    Tensor q = reshaped.slice(2, 0, num_heads);
    Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
    Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);

    Tensor raw_attn_output = mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false).front();

    assert(raw_attn_output.shape[0] == batch_size);
    assert(raw_attn_output.shape[1] == num_tokens);
    assert(raw_attn_output.shape[2] == num_heads);
    assert(raw_attn_output.shape[3] == dim_head);

    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}
@@ -153,13 +132,13 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    assert(qkv.ndims() == 3);

    const Device device = qkv.device();

    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    constexpr int POOL_SIZE = 128;
    const int pool_tokens = ceilDiv(num_tokens, POOL_SIZE);
    Tensor blockmask;
@@ -173,11 +152,11 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    if (pool_qkv.valid() && sparsityRatio > 0) {
        pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3); // [batch_size, 3, num_heads, poolTokens, dim_head]
        for (int i = 0; i < batch_size; i++) {
            Tensor pool_q = pool_qkv.slice(0, i, i + 1).slice(1, 0, 1);
            Tensor pool_k = pool_qkv.slice(0, i, i + 1).slice(1, 1, 2);
            Tensor pool_s = pool_score.slice(0, i, i + 1);
            gemm_batched_fp16(pool_q, pool_k, pool_s);
        }
    }
@@ -197,7 +176,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
        }
    }
    if (!cu_seqlens_cpu.valid()) {
        cu_seqlens_cpu = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
        cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
        for (int i = 1; i <= batch_size; i++) {
            cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
@@ -215,25 +194,32 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    Tensor cu_seqlens = cu_seqlens_cpu.copy(device);

    Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
    Tensor q = reshaped.slice(1, 0, num_heads);
    Tensor k = reshaped.slice(1, num_heads, num_heads * 2);
    Tensor v = reshaped.slice(1, num_heads * 2, num_heads * 3);

    spdlog::debug("q,k,v={}", q.shape.str());

    Tensor raw_attn_output = mha_fwd_block(q,
                                           k,
                                           v,
                                           cu_seqlens,
                                           cu_seqlens,
                                           POOL_SIZE,
                                           POOL_SIZE,
                                           headmask_type,
                                           {},
                                           blockmask,
                                           num_tokens,
                                           num_tokens,
                                           0.0f,
                                           pow(q.shape[-1], (-0.5)),
                                           false,
                                           false,
                                           false,
                                           -1,
                                           -1)
                                    .front();

    debug("raw_attn_output", raw_attn_output);
@@ -290,30 +276,22 @@ void Attention::setForceFP16(Module *module, bool value) {
    });
}
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim,
                                                       int num_attention_heads,
                                                       int attention_head_dim,
                                                       int mlp_ratio,
                                                       bool use_fp4,
                                                       Tensor::ScalarType dtype,
                                                       Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      mlp_hidden_dim(dim * mlp_ratio), norm(dim, dtype, device),
      mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
      mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      norm_q(dim_head, 1e-6, false, dtype, device), norm_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device) {
    registerChildren(norm, "norm")(mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(qkv_proj, "qkv_proj")(norm_q, "norm_q")(
        norm_k, "norm_k")(attn, "attn")(out_proj, "out_proj");
}

Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {
@@ -334,12 +312,18 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {
    debug("rotary_emb", rotary_emb);

    if (attnImpl == AttentionImpl::FlashAttention2) {
        Tensor qkv = Tensor::allocate(
            {batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());

        // qkv_proj.forward(norm_hidden_states, qkv, {});
        // debug("qkv_raw", qkv);

        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             qkv.slice(0, i, i + 1),
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb);
        }
        debug("qkv", qkv);

        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
@@ -352,24 +336,33 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {
        const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

        Tensor q = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor k = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor v = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             {},
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb,
                             q.slice(0, i, i + 1),
                             k.slice(0, i, i + 1),
                             v.slice(0, i, i + 1),
                             num_tokens);
        }
        debug("packed_q", q);
        debug("packed_k", k);
        debug("packed_v", v);

        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head},
                                    norm_hidden_states.scalar_type(),
                                    norm_hidden_states.device());

        kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));
@@ -377,16 +370,14 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {
            attn_output = o.slice(1, 0, num_tokens);
        } else {
            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
            checkCUDA(cudaMemcpy2DAsync(attn_output.data_ptr(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        o.data_ptr(),
                                        o.stride(0) * o.scalar_size(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        getCurrentCUDAStream()));
        }
    } else {
        assert(false);
@@ -394,8 +385,6 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {
    debug("raw_attn_output", attn_output);

    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);
@@ -413,54 +402,40 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {
    return hidden_states;
}
JointTransformerBlock::JointTransformerBlock(int dim,
                                             int num_attention_heads,
                                             int attention_head_dim,
                                             bool context_pre_only,
                                             bool use_fp4,
                                             Tensor::ScalarType dtype,
                                             Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      context_pre_only(context_pre_only), norm1(dim, false, dtype, device),
      norm1_context(dim, context_pre_only, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
      norm_k(dim_head, 1e-6, false, dtype, device), norm_added_q(dim_head, 1e-6, false, dtype, device),
      norm_added_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device), out_proj_context(dim, dim, true, use_fp4, dtype, device),
      norm2(dim, 1e-6, false, dtype, device), norm2_context(dim, 1e-6, false, dtype, device),
      mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device), mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
      mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
      mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device) {
    registerChildren(norm1, "norm1")(norm1_context, "norm1_context")(qkv_proj, "qkv_proj")(qkv_proj_context,
                                                                                           "qkv_proj_context")(
        norm_q, "norm_q")(norm_k, "norm_k")(norm_added_q, "norm_added_q")(norm_added_k, "norm_added_k")(attn, "attn")(
        out_proj, "out_proj")(out_proj_context, "out_proj_context")(norm2, "norm2")(norm2_context, "norm2_context")(
        mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(mlp_context_fc1, "mlp_context_fc1")(mlp_context_fc2, "mlp_context_fc2");
}

// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
                                                          Tensor encoder_hidden_states,
                                                          Tensor temb,
                                                          Tensor rotary_emb,
                                                          Tensor rotary_emb_context,
                                                          float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);
@@ -468,17 +443,19 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
    nvtxRangePushA("AdaNorm");

    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];
    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
@@ -511,30 +488,37 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
        const bool blockSparse = sparsityRatio > 0;
        const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;

        concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                  norm1_output.x.scalar_type(),
                                  norm1_output.x.device());
        pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                              norm1_output.x.scalar_type(),
                                              norm1_output.x.device())
                           : Tensor{};

        for (int i = 0; i < batch_size; i++) {
            // img first
            Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
            Tensor qkv_context =
                concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

            Tensor pool_qkv =
                pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
            Tensor pool_qkv_context = pool.valid()
                                          ? pool.slice(0, i, i + 1)
                                                .slice(1,
                                                       num_tokens_img / POOL_SIZE,
                                                       num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                          : Tensor{};

            // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
            // debug("qkv_raw", qkv);

            debug("rotary_emb", rotary_emb);
            qkv_proj.forward(
                norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
            debug("qkv", qkv);

            // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
@@ -542,7 +526,12 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
            debug("rotary_emb_context", rotary_emb_context);
            qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                     qkv_context,
                                     pool_qkv_context,
                                     norm_added_q.weight,
                                     norm_added_k.weight,
                                     rotary_emb_context);
            debug("qkv_context", qkv_context);
        }
@@ -577,28 +566,40 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
    {
        nvtxRangePushA("qkv_proj");

        concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                    Tensor::FP16,
                                    norm1_output.x.device());
        concat_k = Tensor::empty_like(concat_q);
        concat_v = Tensor::empty_like(concat_q);

        for (int i = 0; i < batch_size; i++) {
            // img first
            auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
            auto sliceTxt = [&](Tensor x) {
                return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
            };

            qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                             {},
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb,
                             sliceImg(concat_q),
                             sliceImg(concat_k),
                             sliceImg(concat_v),
                             num_tokens_img);

            qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                     {},
                                     {},
                                     norm_added_q.weight,
                                     norm_added_k.weight,
                                     rotary_emb_context,
                                     sliceTxt(concat_q),
                                     sliceTxt(concat_k),
                                     sliceTxt(concat_v),
                                     num_tokens_txt);
        }

        debug("concat_q", concat_q);
@@ -608,7 +609,9 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
        nvtxRangePop();
    }

    raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                       norm1_output.x.scalar_type(),
                                       norm1_output.x.device());

    nvtxRangePushA("Attention");
@@ -616,7 +619,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
        nvtxRangePop();

        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
    } else {
        assert(false);
    }
@@ -632,25 +636,28 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
@@ -690,7 +697,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
    }

    if (context_pre_only) {
        return {hidden_states, encoder_hidden_states};
    }

    {
@@ -700,25 +707,30 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
@@ -742,9 +754,9 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);
@@ -761,12 +773,14 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
    nvtxRangePop();

    return {hidden_states, encoder_hidden_states};
}
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
    : dtype(dtype), offload(offload) {
    for (int i = 0; i < 19; i++) {
        transformer_blocks.push_back(
            std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
        if (offload && i > 0) { // don't offload first block
            transformer_blocks.back()->setLazyLoad(true);
@@ -774,7 +788,8 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
        }
    }
    for (int i = 0; i < 38; i++) {
        single_transformer_blocks.push_back(
            std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device));
        registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
        if (offload) {
            single_transformer_blocks.back()->setLazyLoad(true);
@@ -783,19 +798,18 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
    }
}
Tensor FluxModel::forward(Tensor hidden_states,
                          Tensor encoder_hidden_states,
                          Tensor temb,
                          Tensor rotary_emb_img,
                          Tensor rotary_emb_context,
                          Tensor rotary_emb_single,
                          Tensor controlnet_block_samples,
                          Tensor controlnet_single_block_samples,
                          bool skip_first_layer) {
    const int batch_size = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device = hidden_states.device();

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];
@@ -805,45 +819,68 @@ Tensor FluxModel::forward(Tensor hidden_states,
    Tensor concat;

    auto compute = [&](int layer) {
        if (skip_first_layer && size_t(layer) == 0)
            return;
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) =
                block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
            if (controlnet_block_samples.valid()) {
                const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
                int interval_control =
                    ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;
                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
            if (residual_callback && layer % 2 == 0) {
                Tensor cpu_input = hidden_states.copy(Device::cpu());
                pybind11::gil_scoped_acquire gil;
                Tensor cpu_output = residual_callback(cpu_input);
                Tensor residual = cpu_output.copy(Device::cuda());
                hidden_states = kernels::add(hidden_states, residual);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // txt first, same as diffusers
                concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
                for (int i = 0; i < batch_size; i++) {
                    concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
                    concat.slice(0, i, i + 1)
                        .slice(1, txt_tokens, txt_tokens + img_tokens)
                        .copy_(hidden_states.slice(0, i, i + 1));
                }
                hidden_states = concat;
                encoder_hidden_states = {};
            }
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
            if (controlnet_single_block_samples.valid()) {
                const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
                int interval_control =
                    ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples
                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
            size_t local_layer_idx = layer - transformer_blocks.size();
            if (residual_callback && local_layer_idx % 4 == 0) {
                Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                Tensor cpu_input = callback_input.copy(Device::cpu());
                pybind11::gil_scoped_acquire gil;
                Tensor cpu_output = residual_callback(cpu_input);
                Tensor residual = cpu_output.copy(Device::cuda());
                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice = kernels::add(slice, residual);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
        }
@@ -873,31 +910,22 @@ Tensor FluxModel::forward(Tensor hidden_states,
    return hidden_states;
}
std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                                    Tensor hidden_states,
                                                    Tensor encoder_hidden_states,
                                                    Tensor temb,
                                                    Tensor rotary_emb_img,
                                                    Tensor rotary_emb_context,
                                                    Tensor controlnet_block_samples,
                                                    Tensor controlnet_single_block_samples) {
    if (layer < transformer_blocks.size()) {
        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    } else {
        std::tie(hidden_states, encoder_hidden_states) =
            transformer_blocks.at(layer - transformer_blocks.size())
                ->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    }

    const int txt_tokens = encoder_hidden_states.shape[1];
@@ -907,7 +935,7 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index = layer / interval_control;

        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;
@@ -915,17 +943,18 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
        const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
        int interval_control =
            ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
        int block_index = (layer - transformer_blocks.size()) / interval_control;

        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples
        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
        slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

    return {hidden_states, encoder_hidden_states};
}
void FluxModel::setAttentionImpl(AttentionImpl impl) {
@@ -936,3 +965,6 @@ void FluxModel::setAttentionImpl(AttentionImpl impl) {
        block->attnImpl = impl;
    }
}
void FluxModel::set_residual_callback(std::function<Tensor(const Tensor &)> cb) {
    residual_callback = std::move(cb);
}
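The set_residual_callback hook above is the new extension point in this commit (presumably for PuLID-style residual injection): FluxModel::forward copies the hidden states to the CPU, invokes the callback under the Python GIL, copies the returned tensor back to CUDA, and adds it to the hidden states (every 2nd joint block and every 4th single block, per the modulo checks above). A minimal wiring sketch using only types from this diff; the identity-style callback body is a placeholder, not part of the commit:

#include <functional>

// Sketch: attach a residual hook to an already-constructed FluxModel.
void attach_residual_hook(FluxModel &model) {
    model.set_residual_callback([](const Tensor &hidden_cpu) -> Tensor {
        // hidden_cpu lives on the CPU; whatever this returns is copied back
        // to CUDA and added to the hidden states by FluxModel::forward.
        return hidden_cpu; // placeholder residual
    });
}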
@@ -5,6 +5,10 @@
#include "Module.h"
#include "Linear.h"
#include "layernorm.h"
#include <pybind11/functional.h>
namespace pybind11 {
class function;
}
enum class AttentionImpl {
    FlashAttention2 = 0,
@@ -14,7 +18,7 @@ enum class AttentionImpl {
class AdaLayerNormZeroSingle : public Module {
public:
    static constexpr bool USE_4BIT = true;
    using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;

    struct Output {
        Tensor x;
@@ -36,7 +40,7 @@ private:
class AdaLayerNormZero : public Module {
public:
    static constexpr bool USE_4BIT = true;
    using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;

    struct Output {
        Tensor x;
@@ -45,6 +49,7 @@ public:
        Tensor scale_mlp;
        Tensor gate_mlp;
    };

public:
    AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device);
    Output forward(Tensor x, Tensor emb);
@@ -81,9 +86,15 @@ private:
class FluxSingleTransformerBlock : public Module {
public:
    static constexpr bool USE_4BIT = true;
    using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

    FluxSingleTransformerBlock(int dim,
                               int num_attention_heads,
                               int attention_head_dim,
                               int mlp_ratio,
                               bool use_fp4,
                               Tensor::ScalarType dtype,
                               Device device);

    Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);

public:
@@ -107,21 +118,32 @@ private:
class JointTransformerBlock : public Module {
public:
    static constexpr bool USE_4BIT = true;
    using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

    JointTransformerBlock(int dim,
                          int num_attention_heads,
                          int attention_head_dim,
                          bool context_pre_only,
                          bool use_fp4,
                          Tensor::ScalarType dtype,
                          Device device);
    std::tuple<Tensor, Tensor> forward(Tensor hidden_states,
                                       Tensor encoder_hidden_states,
                                       Tensor temb,
                                       Tensor rotary_emb,
                                       Tensor rotary_emb_context,
                                       float sparsityRatio);

public:
    const int dim;
    const int dim_head;
    const int num_heads;
    const bool context_pre_only;

    AdaLayerNormZero norm1;

    AttentionImpl attnImpl = AttentionImpl::FlashAttention2;

private:
    AdaLayerNormZero norm1_context;
    GEMM qkv_proj;
    GEMM qkv_proj_context;
@@ -139,33 +161,35 @@ private:
class FluxModel : public Module { class FluxModel : public Module {
public: public:
FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device); FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
Tensor forward( Tensor forward(Tensor hidden_states,
Tensor hidden_states, Tensor encoder_hidden_states,
Tensor encoder_hidden_states, Tensor temb,
Tensor temb, Tensor rotary_emb_img,
Tensor rotary_emb_img, Tensor rotary_emb_context,
Tensor rotary_emb_context, Tensor rotary_emb_single,
Tensor rotary_emb_single, Tensor controlnet_block_samples,
Tensor controlnet_block_samples, Tensor controlnet_single_block_samples,
Tensor controlnet_single_block_samples, bool skip_first_layer = false);
bool skip_first_layer = false); std::tuple<Tensor, Tensor> forward_layer(size_t layer,
std::tuple<Tensor, Tensor> forward_layer( Tensor hidden_states,
size_t layer, Tensor encoder_hidden_states,
Tensor hidden_states, Tensor temb,
Tensor encoder_hidden_states, Tensor rotary_emb_img,
Tensor temb, Tensor rotary_emb_context,
Tensor rotary_emb_img, Tensor controlnet_block_samples,
Tensor rotary_emb_context, Tensor controlnet_single_block_samples);
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples);
void setAttentionImpl(AttentionImpl impl); void setAttentionImpl(AttentionImpl impl);
void set_residual_callback(std::function<Tensor(const Tensor &)> cb);
public: public:
const Tensor::ScalarType dtype; const Tensor::ScalarType dtype;
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks; std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks; std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
std::function<Tensor(const Tensor &)> residual_callback;
private: private:
bool offload; bool offload;
}; };
\ No newline at end of file
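residual_callback is a public member so the forward pass can consult it between block groups. A hedged sketch of the intended call site (the actual invocation sits in FluxModel::forward, which is not part of this hunk):

    // Sketch only: how the result is combined with the stream is defined in
    // FluxModel::forward, e.g. for PuLID-style feature injection.
    if (residual_callback) {
        hidden_states = residual_callback(hidden_states);
    }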
...@@ -9,16 +9,12 @@ ...@@ -9,16 +9,12 @@
using namespace nunchaku; using namespace nunchaku;
GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) : GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
in_features(in_features), out_features(out_features) : in_features(in_features), out_features(out_features) {
{
this->weight = Tensor::allocate({out_features, in_features}, dtype, device); this->weight = Tensor::allocate({out_features, in_features}, dtype, device);
this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{}; this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
registerParams registerParams(weight, "weight", ParamFlags::LazyLoad)(bias, "bias");
(weight, "weight", ParamFlags::LazyLoad)
(bias, "bias")
;
} }
Tensor GEMM_F16::forward(Tensor x) { Tensor GEMM_F16::forward(Tensor x) {
...@@ -26,26 +22,20 @@ Tensor GEMM_F16::forward(Tensor x) { ...@@ -26,26 +22,20 @@ Tensor GEMM_F16::forward(Tensor x) {
return out; return out;
} }
GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) : GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f), device(device) : in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f),
{ device(device) {
this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device); this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device); this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
this->wzeros = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device); this->wzeros = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{}; this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
// !!! lora layout is different from w4a4 !!! // !!! lora layout is different from w4a4 !!!
this->lora_down = Tensor::allocate({lora_rank, in_features}, dtype, device, true); this->lora_down = Tensor::allocate({lora_rank, in_features}, dtype, device, true);
this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true); this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
registerParams registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(wzeros, "wzeros")(bias, "bias")(
(qweight, "qweight", ParamFlags::LazyLoad) lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional);
(wscales, "wscales")
(wzeros, "wzeros")
(bias, "bias")
(lora_down, "lora_down", ParamFlags::Optional)
(lora_up, "lora_up", ParamFlags::Optional)
;
} }
void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) { void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
...@@ -56,7 +46,7 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) { ...@@ -56,7 +46,7 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
Module::loadParam(key, dst, src); Module::loadParam(key, dst, src);
if (key == "lora_down") { if (key == "lora_down") {
const int new_rank = dst.shape[0]; const int new_rank = dst.shape[0];
this->lora_rank = new_rank; this->lora_rank = new_rank;
} }
} else { } else {
Module::loadParam(key, dst, src); Module::loadParam(key, dst, src);
...@@ -70,7 +60,7 @@ Tensor GEMV_AWQ::forward(Tensor x) { ...@@ -70,7 +60,7 @@ Tensor GEMV_AWQ::forward(Tensor x) {
debug("x", x); debug("x", x);
const int M = (int)x.numel() / x.shape[-1]; const int M = (int)x.numel() / x.shape[-1];
Tensor out = gemv_awq(x, this->qweight, this->wscales, this->wzeros, M, out_features, in_features, group_size); Tensor out = gemv_awq(x, this->qweight, this->wscales, this->wzeros, M, out_features, in_features, group_size);
if (bias.valid()) { if (bias.valid()) {
// TODO: batch // TODO: batch
// assert(out.numel() == bias.numel()); // assert(out.numel() == bias.numel());
...@@ -91,19 +81,16 @@ Tensor GEMV_AWQ::forward(Tensor x) { ...@@ -91,19 +81,16 @@ Tensor GEMV_AWQ::forward(Tensor x) {
} }
debug("out", out); debug("out", out);
return out; return out;
} }
#define NO_LORA_FUSION 0 #define NO_LORA_FUSION 0
GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device) : GEMM_W4A4::GEMM_W4A4(
in_features(in_features), out_features(out_features), int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device)
in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128), : in_features(in_features), out_features(out_features), in_features_pad(ceilDiv(in_features, 128) * 128),
use_fp4(use_fp4), out_features_pad(ceilDiv(out_features, 128) * 128), use_fp4(use_fp4), lora_rank(0), dtype(dtype), device(device) {
lora_rank(0), dtype(dtype), device(device)
{
this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true); this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
if (use_fp4) { if (use_fp4) {
this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true); this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
...@@ -114,27 +101,20 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, ...@@ -114,27 +101,20 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{}; this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};
this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true); this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true); this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
// TODO: smooth factor in non-Lora fusion // TODO: smooth factor in non-Lora fusion
this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true); this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);
// FIXME: reset wtscale and wcscales to default values when reloading the weights // FIXME: reset wtscale and wcscales to default values when reloading the weights
this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true); this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
*this->wtscale.data_ptr<float>() = 1.0f; *this->wtscale.data_ptr<float>() = 1.0f;
this->wcscales = Tensor::allocate({0}, dtype, device, true); this->wcscales = Tensor::allocate({0}, dtype, device, true);
registerParams registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias")(
(qweight, "qweight", ParamFlags::LazyLoad) lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional)(smooth, "smooth")(
(wscales, "wscales") wtscale, "wtscale", ParamFlags::Optional)(wcscales, "wcscales", ParamFlags::Optional);
(this->bias, "bias")
(lora_down, "lora_down", ParamFlags::Optional)
(lora_up, "lora_up", ParamFlags::Optional)
(smooth, "smooth")
(wtscale, "wtscale", ParamFlags::Optional)
(wcscales, "wcscales", ParamFlags::Optional)
;
#if NO_LORA_FUSION #if NO_LORA_FUSION
checkCUBLAS(cublasCreate(&handle)); checkCUBLAS(cublasCreate(&handle));
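All of the padded dimensions in the constructor above round up to the 4-bit kernels' 128-wide tiles via ceilDiv. Worked example (ceilDivExample is a stand-in for the project's ceilDiv helper):

    constexpr int ceilDivExample(int a, int b) { return (a + b - 1) / b; }
    static_assert(ceilDivExample(3000, 128) * 128 == 3072); // padded up
    static_assert(ceilDivExample(3072, 128) * 128 == 3072); // already aligned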
...@@ -181,11 +161,21 @@ Tensor GEMM_W4A4::forward_silu(Tensor x) { ...@@ -181,11 +161,21 @@ Tensor GEMM_W4A4::forward_silu(Tensor x) {
return std::get<Tensor>(this->forward(x, FuseOptions::SILU, nullptr)); return std::get<Tensor>(this->forward(x, FuseOptions::SILU, nullptr));
} }
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) { std::variant<Tensor, GEMM_W4A4::QuantizedActivation>
GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
return forward_quant(quantize(x, false), fuse, nextGEMM); return forward_quant(quantize(x, false), fuse, nextGEMM);
} }
void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) { void GEMM_W4A4::forward(Tensor x,
Tensor out,
Tensor pool,
Tensor norm_q,
Tensor norm_k,
Tensor rotary_emb,
Tensor out_q,
Tensor out_k,
Tensor out_v,
int numTokens) {
QuantizedActivation qact = quantize(x, false); QuantizedActivation qact = quantize(x, false);
#if !NO_LORA_FUSION #if !NO_LORA_FUSION
...@@ -198,42 +188,87 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor ...@@ -198,42 +188,87 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
debug("gemm.nolora.out", out); debug("gemm.nolora.out", out);
#endif #endif
kernels::gemm_w4a4( kernels::gemm_w4a4(qact.act,
qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false, qweight,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{}, out,
out_q, out_k, out_v, numTokens {},
); qact.ascales,
wscales,
{},
pool,
qact.lora_act,
this->lora_up,
{},
{},
norm_q,
norm_k,
rotary_emb,
this->bias,
{},
{},
{},
qact.is_unsigned,
this->lora_scales,
false,
use_fp4,
*this->wtscale.data_ptr<float>(),
wcscales.numel() > 0 ? wcscales : Tensor{},
out_q,
out_k,
out_v,
numTokens);
debug("gemm.out", out); debug("gemm.out", out);
#else #else
const int M = (int)qact.act.numel() / qact.act.shape[-1]; const int M = (int)qact.act.numel() / qact.act.shape[-1];
kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, {}, {}, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned, this->lora_scales); kernels::gemm_w4a4(qact.act,
qweight,
out,
{},
qact.ascales,
wscales,
{},
pool,
{},
{},
{},
{},
norm_q,
norm_k,
rotary_emb,
this->bias,
{},
qact.is_unsigned,
this->lora_scales);
nvtxRangePushA("LoraUp"); nvtxRangePushA("LoraUp");
static const half one = 1.0; static const half one = 1.0;
static const half zero = 0.0; static const half zero = 0.0;
// lora_up: [M, R] * [OC, R] => [M, OC] // lora_up: [M, R] * [OC, R] => [M, OC]
// cublas view: [OC, R] * [M, R]^T // cublas view: [OC, R] * [M, R]^T
checkCUBLAS(cublasHgemm( checkCUBLAS(cublasHgemm(handle,
handle, CUBLAS_OP_T,
CUBLAS_OP_T, CUBLAS_OP_N, CUBLAS_OP_N,
this->out_features, M, this->lora_rank, this->out_features,
&one, M,
this->lora_up.data_ptr<half>(), this->lora_rank,
this->lora_rank, &one,
qact.lora_act.data_ptr<half>(), this->lora_up.data_ptr<half>(),
this->lora_rank, this->lora_rank,
&one, qact.lora_act.data_ptr<half>(),
out.data_ptr<half>(), this->lora_rank,
this->out_features)); &one,
out.data_ptr<half>(),
this->out_features));
nvtxRangePop(); nvtxRangePop();
#endif #endif
} }
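The cublasHgemm call above uses the usual row-major trick: a row-major matrix handed to cuBLAS is its own transpose in column-major terms. A standalone sketch of the same index bookkeeping, with handle/one as above and W, X, Y, OC, M, R as placeholders for lora_up, lora_act, out and their sizes:

    // Row-major Y[M][OC] += X[M][R] * W[OC][R]^T, issued column-major:
    //   m=OC, n=M, k=R
    //   A-slot: W, lda=R, op=T  (W row-major [OC,R] == col-major [R,OC])
    //   B-slot: X, ldb=R, op=N  (X row-major [M,R]  == col-major [R,M])
    //   C-slot: Y, ldc=OC       (col-major [OC,M]   == row-major [M,OC])
    checkCUBLAS(cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                            OC, M, R,
                            &one, W, R,
                            X, R,
                            &one, Y, OC));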
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) { std::variant<Tensor, GEMM_W4A4::QuantizedActivation>
GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
Tensor out; Tensor out;
QuantizedActivation qout; QuantizedActivation qout;
...@@ -246,8 +281,8 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu ...@@ -246,8 +281,8 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
// auto shape = TensorShape(qact.act.shape.dataExtent); // auto shape = TensorShape(qact.act.shape.dataExtent);
// shape[-1] = out_features; // shape[-1] = out_features;
auto shape = TensorShape(qact.actShape.dataExtent); auto shape = TensorShape(qact.actShape.dataExtent);
shape[-1] = out_features; shape[-1] = out_features;
out = Tensor::allocate(shape, dtype, device); out = Tensor::allocate(shape, dtype, device);
} else { } else {
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, device); qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, device);
if (use_fp4) { if (use_fp4) {
...@@ -255,11 +290,11 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu ...@@ -255,11 +290,11 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
} else { } else {
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, device); qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, device);
} }
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device); qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
qout.is_unsigned = !use_fp4; qout.is_unsigned = !use_fp4;
qout.actShape = qact.actShape; qout.actShape = qact.actShape;
next_lora = nextGEMM->lora_down; next_lora = nextGEMM->lora_down;
next_smooth = nextGEMM->smooth; next_smooth = nextGEMM->smooth;
} }
...@@ -280,11 +315,35 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu ...@@ -280,11 +315,35 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
} }
#endif #endif
kernels::gemm_w4a4( kernels::gemm_w4a4(qact.act,
qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU, qweight,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{}, out,
{}, {}, {}, 0 qout.act,
); qact.ascales,
wscales,
qout.ascales,
{},
qact.lora_act,
this->lora_up,
next_lora,
qout.lora_act,
{},
{},
{},
this->bias,
next_smooth,
{},
{},
qact.is_unsigned,
this->lora_scales,
fuse == FuseOptions::SILU,
use_fp4,
*this->wtscale.data_ptr<float>(),
wcscales.numel() > 0 ? wcscales : Tensor{},
{},
{},
{},
0);
if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) { if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
debug("gemm.out", out); debug("gemm.out", out);
...@@ -294,36 +353,55 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu ...@@ -294,36 +353,55 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
debug("gemm.lora_act_out", qout.lora_act); debug("gemm.lora_act_out", qout.lora_act);
} }
#else #else
if (!out.valid()) { if (!out.valid()) {
auto shape = TensorShape(qact.act.shape.dataExtent); auto shape = TensorShape(qact.act.shape.dataExtent);
shape[-1] = out_features; shape[-1] = out_features;
out = Tensor::allocate(shape, Tensor::FP16, qweight.device()); out = Tensor::allocate(shape, Tensor::FP16, qweight.device());
} }
kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, {}, {}, {}, {}, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned, this->lora_scales); kernels::gemm_w4a4(qact.act,
qweight,
out,
qout.act,
qact.ascales,
wscales,
qout.ascales,
{},
{},
{},
{},
{},
{},
{},
{},
this->bias,
next_smooth,
qact.is_unsigned,
this->lora_scales);
nvtxRangePushA("LoraUp"); nvtxRangePushA("LoraUp");
static const half one = 1.0; static const half one = 1.0;
static const half zero = 0.0; static const half zero = 0.0;
// lora_up: [M, R] * [OC, R]^T => [M, OC] // lora_up: [M, R] * [OC, R]^T => [M, OC]
// cublas view: [R, OC]^T * [R, M] => [OC, M] // cublas view: [R, OC]^T * [R, M] => [OC, M]
// lora_up layout wrong? // lora_up layout wrong?
checkCUBLAS(cublasHgemm( checkCUBLAS(cublasHgemm(handle,
handle, CUBLAS_OP_T,
CUBLAS_OP_T, CUBLAS_OP_N, CUBLAS_OP_N,
this->out_features, M, this->lora_rank, this->out_features,
&one, M,
this->lora_up.data_ptr<half>(), this->lora_rank,
this->lora_rank, &one,
qact.lora_act.data_ptr<half>(), this->lora_up.data_ptr<half>(),
this->lora_rank, this->lora_rank,
&one, qact.lora_act.data_ptr<half>(),
out.data_ptr<half>(), this->lora_rank,
this->out_features)); &one,
out.data_ptr<half>(),
this->out_features));
nvtxRangePop(); nvtxRangePop();
...@@ -332,18 +410,20 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu ...@@ -332,18 +410,20 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
// IC is for next lora (OC of this layer) // IC is for next lora (OC of this layer)
// lora_down: [M, IC] * [IC, R] => [M, R] // lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M] => [R, M] // cublas view: [R, IC] * [IC, M] => [R, M]
checkCUBLAS(cublasHgemm( checkCUBLAS(cublasHgemm(handle,
handle, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N, CUBLAS_OP_N,
this->lora_rank, M, this->out_features, this->lora_rank,
&one, M,
next_lora.data_ptr<half>(), this->out_features,
this->lora_rank, &one,
out.data_ptr<half>(), next_lora.data_ptr<half>(),
this->out_features, this->lora_rank,
&zero, out.data_ptr<half>(),
qout.lora_act.data_ptr<half>(), this->out_features,
this->lora_rank)); &zero,
qout.lora_act.data_ptr<half>(),
this->lora_rank));
out = {}; out = {};
...@@ -363,7 +443,7 @@ Tensor GEMM_W4A4::forward_quant(QuantizedActivation qact) { ...@@ -363,7 +443,7 @@ Tensor GEMM_W4A4::forward_quant(QuantizedActivation qact) {
GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) { GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
const int actualM = x.numel() / x.shape[-1]; const int actualM = x.numel() / x.shape[-1];
const int M = ceilDiv(actualM, 256) * 256; const int M = ceilDiv(actualM, 256) * 256;
// auto shape = TensorShape(x.shape.dataExtent); // auto shape = TensorShape(x.shape.dataExtent);
// shape[-1] = in_features / 2; // shape[-1] = in_features / 2;
...@@ -375,39 +455,42 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) { ...@@ -375,39 +455,42 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
} else { } else {
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, device); qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, device);
} }
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device); qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
qact.is_unsigned = false; qact.is_unsigned = false;
qact.actShape = x.shape.dataExtent; qact.actShape = x.shape.dataExtent;
#if !NO_LORA_FUSION #if !NO_LORA_FUSION
debug("quantize.x", x); debug("quantize.x", x);
debug("quantize.smooth", this->smooth); debug("quantize.smooth", this->smooth);
kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4); kernels::quantize_w4a4_act_fuse_lora(
x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);
debug("quantize.qact", qact.act); debug("quantize.qact", qact.act);
debug("quantize.ascales", qact.ascales); debug("quantize.ascales", qact.ascales);
debug("quantize.lora_act", qact.lora_act); debug("quantize.lora_act", qact.lora_act);
#else #else
static const half one = 1.0; static const half one = 1.0;
static const half zero = 0.0; static const half zero = 0.0;
nvtxRangePushA("LoraDown"); nvtxRangePushA("LoraDown");
// lora_down: [M, IC] * [IC, R] => [M, R] // lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M] // cublas view: [R, IC] * [IC, M]
checkCUBLAS(cublasHgemm( checkCUBLAS(cublasHgemm(handle,
handle, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N, CUBLAS_OP_N,
this->lora_rank, M, this->in_features, this->lora_rank,
&one, M,
lora_down.data_ptr<half>(), this->in_features,
this->lora_rank, &one,
x.data_ptr<half>(), lora_down.data_ptr<half>(),
this->in_features, this->lora_rank,
&zero, x.data_ptr<half>(),
qact.lora_act.data_ptr<half>(), this->in_features,
this->lora_rank)); &zero,
qact.lora_act.data_ptr<half>(),
this->lora_rank));
nvtxRangePop(); nvtxRangePop();
...@@ -418,18 +501,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) { ...@@ -418,18 +501,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
return qact; return qact;
} }
GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) : GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device)
in_features(in_features), out_features(out_features), dtype(dtype) : in_features(in_features), out_features(out_features), dtype(dtype) {
{
this->qweight = Tensor::allocate({out_features, in_features}, Tensor::INT8, device); this->qweight = Tensor::allocate({out_features, in_features}, Tensor::INT8, device);
this->wscales = Tensor::allocate({out_features}, dtype, device); this->wscales = Tensor::allocate({out_features}, dtype, device);
this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{}; this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
registerParams registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias");
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(this->bias, "bias")
;
} }
GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) { GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
...@@ -438,7 +516,7 @@ GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) { ...@@ -438,7 +516,7 @@ GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
if (fuse_glu) { if (fuse_glu) {
qshape[-1] /= 2; qshape[-1] /= 2;
} }
qact.act = Tensor::allocate(qshape, Tensor::INT8, x.device()); qact.act = Tensor::allocate(qshape, Tensor::INT8, x.device());
qact.ascales = Tensor::allocate({(int)x.numel() / x.shape[-1]}, this->dtype, x.device()); qact.ascales = Tensor::allocate({(int)x.numel() / x.shape[-1]}, this->dtype, x.device());
debug("quantize.x", x); debug("quantize.x", x);
...@@ -453,7 +531,7 @@ GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) { ...@@ -453,7 +531,7 @@ GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) { Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
auto shape = TensorShape(qact.act.shape.dataExtent); auto shape = TensorShape(qact.act.shape.dataExtent);
shape[-1] = out_features; shape[-1] = out_features;
Tensor out = Tensor::allocate(shape, this->dtype, qact.act.device()); Tensor out = Tensor::allocate(shape, this->dtype, qact.act.device());
kernels::gemm_w8a8(qact.act, this->qweight, out, qact.ascales, this->wscales, this->bias); kernels::gemm_w8a8(qact.act, this->qweight, out, qact.ascales, this->wscales, this->bias);
...@@ -461,18 +539,13 @@ Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) { ...@@ -461,18 +539,13 @@ Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
return out; return out;
} }
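GEMM_W8A8 quantizes activations per token (one scale per row, the [M]-shaped ascales) and weights per output channel (the [out_features]-shaped wscales), so dequantization is a rank-1 rescale of the INT32 accumulator. Conceptual sketch:

    // out[m][n] = acc_i32[m][n] * ascales[m] * wscales[n] + bias[n]
    float dequant_w8a8(int32_t acc, float ascale, float wscale, float bias) {
        return float(acc) * ascale * wscale + bias;
    }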
DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) : DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) : in_features(in_features) {
in_features(in_features)
{
this->weight = Tensor::allocate({in_features, 3, 3, 1}, dtype, device); this->weight = Tensor::allocate({in_features, 3, 3, 1}, dtype, device);
this->bias = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{}; this->bias = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{};
registerParams registerParams(this->weight, "weight")(this->bias, "bias");
(this->weight, "weight")
(this->bias, "bias")
;
} }
Tensor DWCONV::forward(Tensor x) { Tensor DWCONV::forward(Tensor x) {
return dwconv_f16(x, this->weight, {}, this->bias); return dwconv_f16(x, this->weight, {}, this->bias);
} }
\ No newline at end of file
...@@ -37,6 +37,7 @@ public: ...@@ -37,6 +37,7 @@ public:
float lora_scale; float lora_scale;
const Device device; const Device device;
public: public:
Tensor qweight; Tensor qweight;
Tensor wscales; Tensor wscales;
...@@ -69,12 +70,18 @@ public: ...@@ -69,12 +70,18 @@ public:
Tensor forward(Tensor x); Tensor forward(Tensor x);
Tensor forward_silu(Tensor x); Tensor forward_silu(Tensor x);
std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr); std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
void forward( void forward(Tensor x,
Tensor x, Tensor out, Tensor out,
Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {}, Tensor pool = {},
Tensor out_q = {}, Tensor out_k = {}, Tensor out_v = {}, int numTokens = 0 Tensor norm_q = {},
); Tensor norm_k = {},
std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr); Tensor rotary_emb = {},
Tensor out_q = {},
Tensor out_k = {},
Tensor out_v = {},
int numTokens = 0);
std::variant<Tensor, QuantizedActivation>
forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
Tensor forward_quant(QuantizedActivation qact); Tensor forward_quant(QuantizedActivation qact);
public: public:
...@@ -86,7 +93,7 @@ public: ...@@ -86,7 +93,7 @@ public:
const int in_features_pad; const int in_features_pad;
const int out_features_pad; const int out_features_pad;
const bool use_fp4; const bool use_fp4;
int lora_rank; int lora_rank;
std::vector<float> lora_scales; // every 16 ranks share a scale std::vector<float> lora_scales; // every 16 ranks share a scale
...@@ -118,13 +125,16 @@ public: ...@@ -118,13 +125,16 @@ public:
Tensor act; Tensor act;
Tensor ascales; Tensor ascales;
}; };
public: public:
GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device); GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
public: public:
QuantizedActivation quantize(Tensor x, bool fuse_glu); QuantizedActivation quantize(Tensor x, bool fuse_glu);
Tensor forward_quant(QuantizedActivation qact); Tensor forward_quant(QuantizedActivation qact);
Tensor forward(Tensor x) { return forward_quant(quantize(x, false)); } Tensor forward(Tensor x) {
return forward_quant(quantize(x, false));
}
public: public:
const int in_features; const int in_features;
...@@ -149,4 +159,4 @@ public: ...@@ -149,4 +159,4 @@ public:
public: public:
Tensor weight; Tensor weight;
Tensor bias; Tensor bias;
}; };
\ No newline at end of file
...@@ -10,8 +10,8 @@ void Module::copyWithCast(Tensor dst, Tensor src) { ...@@ -10,8 +10,8 @@ void Module::copyWithCast(Tensor dst, Tensor src) {
nunchaku::kernels::cast(src, dst); nunchaku::kernels::cast(src, dst);
} else { } else {
Tensor tmp; Tensor tmp;
tmp.buffer = dst.buffer; tmp.buffer = dst.buffer;
tmp.shape = dst.shape; tmp.shape = dst.shape;
tmp.scalarType = src.scalarType; tmp.scalarType = src.scalarType;
tmp.copy_(src); tmp.copy_(src);
nunchaku::kernels::cast(tmp, dst); nunchaku::kernels::cast(tmp, dst);
......
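The fallback path in copyWithCast works because the auto-cast whitelist only pairs FP16 with BF16, which share an element size: dst's storage is re-viewed with src's scalar type, the bytes are copied raw, and the cast then converts the values in place. Annotated sketch of that aliasing step:

    // Alias dst's buffer with src's dtype (no allocation), raw-copy, then
    // cast in place. Safe here only because FP16 and BF16 are both 2 bytes.
    Tensor tmp;
    tmp.buffer     = dst.buffer;
    tmp.shape      = dst.shape;
    tmp.scalarType = src.scalarType;
    tmp.copy_(src);
    nunchaku::kernels::cast(tmp, dst);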
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
class Module { class Module {
protected: protected:
enum class ParamFlags : int { enum class ParamFlags : int {
None = 0, None = 0,
Optional = 1, Optional = 1,
LazyLoad = 2, LazyLoad = 2,
}; };
...@@ -19,7 +19,7 @@ protected: ...@@ -19,7 +19,7 @@ protected:
Tensor src; Tensor src;
}; };
struct Param { struct Param {
Tensor *tensor = nullptr; Tensor *tensor = nullptr;
ParamFlags flags = ParamFlags::None; ParamFlags flags = ParamFlags::None;
TensorLazyLoadInfo lazyInfo; TensorLazyLoadInfo lazyInfo;
...@@ -50,7 +50,7 @@ public: ...@@ -50,7 +50,7 @@ public:
std::string getPrefix() const { std::string getPrefix() const {
std::string fullName = getFullName(); std::string fullName = getFullName();
std::string prefix = fullName.empty() ? "" : fullName + "."; std::string prefix = fullName.empty() ? "" : fullName + ".";
return prefix; return prefix;
} }
...@@ -80,7 +80,7 @@ public: ...@@ -80,7 +80,7 @@ public:
continue; continue;
} }
// keep loading params if param is not released // keep loading params if param is not released
} }
this->loadParam(key, *param.tensor, src); this->loadParam(key, *param.tensor, src);
// tensor->copy_(src); // tensor->copy_(src);
} }
...@@ -99,8 +99,8 @@ public: ...@@ -99,8 +99,8 @@ public:
} }
TensorLazyLoadInfo &lazy = param.lazyInfo; TensorLazyLoadInfo &lazy = param.lazyInfo;
Tensor &dst = *param.tensor; Tensor &dst = *param.tensor;
Tensor src = lazy.src; Tensor src = lazy.src;
if (dst.valid()) { if (dst.valid()) {
continue; continue;
...@@ -108,7 +108,8 @@ public: ...@@ -108,7 +108,8 @@ public:
dst = Tensor::allocate(lazy.shape, lazy.type, lazy.device); dst = Tensor::allocate(lazy.shape, lazy.type, lazy.device);
if (!src.valid() && !checkFlag(param.flags, ParamFlags::Optional)) { if (!src.valid() && !checkFlag(param.flags, ParamFlags::Optional)) {
throw std::runtime_error(spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key)); throw std::runtime_error(
spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
} }
m->loadParam(key, dst, src); m->loadParam(key, dst, src);
} }
...@@ -127,14 +128,10 @@ public: ...@@ -127,14 +128,10 @@ public:
}); });
} }
void setLazyLoad(bool val) { void setLazyLoad(bool val) {
traverse([val](Module *m) { traverse([val](Module *m) { m->enabledLazyLoad = val; });
m->enabledLazyLoad = val;
});
} }
void setAutoCastFP16(bool val) { void setAutoCastFP16(bool val) {
traverse([val](Module *m) { traverse([val](Module *m) { m->enabledAutoCastFP16 = val; });
m->enabledAutoCastFP16 = val;
});
} }
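Both setters now funnel through one-liner traverse lambdas that walk the whole module tree, so per-tree toggles stay single calls (sketch, assuming a root Module m):

    m.setLazyLoad(true);       // defer LazyLoad-flagged weights until first use
    m.setAutoCastFP16(false);  // forbid FP16<->BF16 conversion on load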
protected: protected:
...@@ -143,7 +140,8 @@ protected: ...@@ -143,7 +140,8 @@ protected:
Tensor::FP16, Tensor::FP16,
Tensor::BF16, Tensor::BF16,
}; };
if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) && whitelist.contains(src.scalar_type())) { if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) &&
whitelist.contains(src.scalar_type())) {
copyWithCast(dst, src); copyWithCast(dst, src);
} else { } else {
dst.copy_(src); dst.copy_(src);
...@@ -159,7 +157,7 @@ protected: ...@@ -159,7 +157,7 @@ protected:
}; };
ChildrenRegisterHelper registerChildren(Module &module, std::string name) { ChildrenRegisterHelper registerChildren(Module &module, std::string name) {
module.parent = this; module.parent = this;
module.name = name; module.name = name;
children.push_back(&module); children.push_back(&module);
return ChildrenRegisterHelper(*this); return ChildrenRegisterHelper(*this);
} }
...@@ -174,13 +172,13 @@ protected: ...@@ -174,13 +172,13 @@ protected:
ParamsRegisterHelper registerParams(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) { ParamsRegisterHelper registerParams(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) {
if (param.valid()) { if (param.valid()) {
params[name].tensor = &param; params[name].tensor = &param;
params[name].flags = flags; params[name].flags = flags;
if (checkFlag(flags, ParamFlags::LazyLoad) && param.valid()) { if (checkFlag(flags, ParamFlags::LazyLoad) && param.valid()) {
TensorLazyLoadInfo &lazy = params[name].lazyInfo; TensorLazyLoadInfo &lazy = params[name].lazyInfo;
lazy.shape = param.shape; lazy.shape = param.shape;
lazy.type = param.dtype(); lazy.type = param.dtype();
lazy.device = param.device(); lazy.device = param.device();
} }
} }
return ParamsRegisterHelper(*this); return ParamsRegisterHelper(*this);
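registerParams returns a ParamsRegisterHelper whose operator() forwards back into registerParams, which is what makes the chained call sites above, e.g. registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales"), legal. Generic sketch of the pattern in isolation (not the exact helper):

    // Fluent registration: each operator() records one entry and returns
    // the helper so calls chain left to right.
    struct RegisterChain {
        RegisterChain operator()(Tensor &t, const std::string &name) {
            // record (name -> &t) in the owning module's param map
            return *this;
        }
    };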
...@@ -204,12 +202,12 @@ private: ...@@ -204,12 +202,12 @@ private:
void copyWithCast(Tensor dst, Tensor src); void copyWithCast(Tensor dst, Tensor src);
public: public:
Module *parent = nullptr; Module *parent = nullptr;
std::string name = ""; std::string name = "";
std::vector<Module *> children; std::vector<Module *> children;
std::map<std::string, Param> params; std::map<std::string, Param> params;
bool enabledLazyLoad = false; bool enabledLazyLoad = false;
bool enabledAutoCastFP16 = true; bool enabledAutoCastFP16 = true;
}; };
...@@ -226,12 +224,11 @@ struct LayerOffloadHelper { ...@@ -226,12 +224,11 @@ struct LayerOffloadHelper {
std::unique_ptr<CUDAEventWrapper> eventComputeDone; std::unique_ptr<CUDAEventWrapper> eventComputeDone;
std::unique_ptr<CUDAEventWrapper> eventLoadDone; std::unique_ptr<CUDAEventWrapper> eventLoadDone;
LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload) LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
: offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload) : offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload) {
{
if (offload) { if (offload) {
streamCompute = std::make_unique<CUDAStreamWrapper>(); streamCompute = std::make_unique<CUDAStreamWrapper>();
streamLoad = std::make_unique<CUDAStreamWrapper>(); streamLoad = std::make_unique<CUDAStreamWrapper>();
needWorkaround = checkWorkaround(); needWorkaround = checkWorkaround();
if (needWorkaround) { if (needWorkaround) {
...@@ -280,7 +277,7 @@ private: ...@@ -280,7 +277,7 @@ private:
} }
eventComputeDone = std::move(nextComputeDone); eventComputeDone = std::move(nextComputeDone);
eventLoadDone = std::move(nextLoadDone); eventLoadDone = std::move(nextLoadDone);
workaroundSynchronize(); workaroundSynchronize();
} }
...@@ -304,12 +301,12 @@ private: ...@@ -304,12 +301,12 @@ private:
return false; return false;
} }
} }
#ifdef _WIN32 #ifdef _WIN32
return true; return true;
#else #else
return false; return false;
#endif #endif
} }
void workaroundFlush() { void workaroundFlush() {
if (!needWorkaround) { if (!needWorkaround) {
...@@ -323,4 +320,4 @@ private: ...@@ -323,4 +320,4 @@ private:
} }
checkCUDA(cudaEventSynchronize(eventComputeDone->event)); checkCUDA(cudaEventSynchronize(eventComputeDone->event));
} }
}; };
\ No newline at end of file
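LayerOffloadHelper pipelines execution when offload is on: layer i computes on streamCompute while layer i+1's weights stream in on streamLoad, with the two CUDA events fencing each handoff (plus the extra synchronize on Windows, per checkWorkaround). Sketch of the resulting per-layer schedule, assuming the func_t hooks take a layer index:

    // Conceptual schedule only; the real code records and waits on
    // eventComputeDone / eventLoadDone around these calls.
    for (int i = 0; i < numLayers; i++) {
        if (i + 1 < numLayers)
            funcLoad(i + 1);   // H2D copies for the next layer on streamLoad
        funcCompute(i);        // this layer's kernels on streamCompute
        funcUnload(i);         // release weights once compute has consumed them
    }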
...@@ -10,18 +10,11 @@ ...@@ -10,18 +10,11 @@
using spdlog::fmt_lib::format; using spdlog::fmt_lib::format;
using namespace nunchaku; using namespace nunchaku;
SanaLinearAttention::SanaLinearAttention(
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) : int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device)
dim(dim), : dim(dim), dim_pad(ceilDiv(dim, 128) * 128), qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
dim_pad(ceilDiv(dim, 128) * 128), out_proj(dim_pad, dim, bias, use_fp4, dtype, device), pag_to_v(std::nullopt) {
qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device), registerChildren(qkv_proj, "qkv_proj")(out_proj, "out_proj");
out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
pag_to_v(std::nullopt)
{
registerChildren
(qkv_proj, "qkv_proj")
(out_proj, "out_proj")
;
if (pag) { if (pag) {
pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device); pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
...@@ -33,8 +26,8 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) { ...@@ -33,8 +26,8 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
constexpr int HEAD_DIM = 32; constexpr int HEAD_DIM = 32;
assert(x.ndims() == 3); assert(x.ndims() == 3);
const int batch_size = x.shape[0]; const int batch_size = x.shape[0];
const int num_tokens = x.shape[1]; const int num_tokens = x.shape[1];
const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256; const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
assert(x.shape[2] == dim); assert(x.shape[2] == dim);
...@@ -54,24 +47,38 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) { ...@@ -54,24 +47,38 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
auto qact = qkv_proj.quantize(x, false); auto qact = qkv_proj.quantize(x, false);
Tensor q = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device()); Tensor q = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device()); Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());
kernels::gemm_w4a4( kernels::gemm_w4a4(qact.act,
qact.act, qkv_proj.qweight,
qkv_proj.qweight, {},
{}, {},
{}, qact.ascales,
qact.ascales, qkv_proj.wscales,
qkv_proj.wscales, {},
{}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {}, {},
vk, q, qact.lora_act,
qact.is_unsigned, qkv_proj.lora_scales, false, qkv_proj.lora_up,
qkv_proj.use_fp4, {},
*qkv_proj.wtscale.data_ptr<float>(), {},
qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{}, {},
{}, {}, {}, 0 {},
); {},
qkv_proj.bias,
{},
vk,
q,
qact.is_unsigned,
qkv_proj.lora_scales,
false,
qkv_proj.use_fp4,
*qkv_proj.wtscale.data_ptr<float>(),
qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
{},
{},
{},
0);
debug("vk", vk); debug("vk", vk);
debug("q", q); debug("q", q);
...@@ -88,7 +95,6 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) { ...@@ -88,7 +95,6 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
q = q_unpad; q = q_unpad;
} }
// kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales); // kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales);
// return out_proj.forward(q); // return out_proj.forward(q);
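The vk buffer is [batch, heads, HEAD_DIM+1, HEAD_DIM] because the fused kernel appears to stack the linear-attention normalizer under the value-key product: rows 0..HEAD_DIM-1 accumulate sum_j v_j k_j^T and the extra row accumulates sum_j k_j^T. Under that assumption, the per-query epilogue reduces to (sketch, HEAD_DIM = 32):

    // Assumed epilogue of the fused lite linear attention:
    //   num = vk[0:32] * q, den = vk[32] . q, out = num / (den + eps)
    void lite_la_epilogue(const float vk[33][32], const float q[32],
                          float out[32], float eps = 1e-6f) {
        float den = eps;
        for (int e = 0; e < 32; ++e) den += vk[32][e] * q[e];
        for (int d = 0; d < 32; ++d) {
            float acc = 0.f;
            for (int e = 0; e < 32; ++e) acc += vk[d][e] * q[e];
            out[d] = acc / den;
        }
    }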
...@@ -109,14 +115,14 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) { ...@@ -109,14 +115,14 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
if (cfg) { if (cfg) {
assert(batch_size % 3 == 0); assert(batch_size % 3 == 0);
x_org = x.slice(0, 0, batch_size * 2 / 3); x_org = x.slice(0, 0, batch_size * 2 / 3);
x_ptb = x.slice(0, batch_size * 2 / 3, batch_size); x_ptb = x.slice(0, batch_size * 2 / 3, batch_size);
out_org = out.slice(0, 0, batch_size * 2 / 3); out_org = out.slice(0, 0, batch_size * 2 / 3);
out_ptb = out.slice(0, batch_size * 2 / 3, batch_size); out_ptb = out.slice(0, batch_size * 2 / 3, batch_size);
} else { } else {
assert(batch_size % 2 == 0); assert(batch_size % 2 == 0);
x_org = x.slice(0, 0, batch_size / 2); x_org = x.slice(0, 0, batch_size / 2);
x_ptb = x.slice(0, batch_size / 2, batch_size); x_ptb = x.slice(0, batch_size / 2, batch_size);
out_org = out.slice(0, 0, batch_size / 2); out_org = out.slice(0, 0, batch_size / 2);
out_ptb = out.slice(0, batch_size / 2, batch_size); out_ptb = out.slice(0, batch_size / 2, batch_size);
} }
...@@ -129,17 +135,13 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) { ...@@ -129,17 +135,13 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
return out; return out;
} }
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) : MultiHeadCrossAttention::MultiHeadCrossAttention(
num_heads(num_heads), head_dim(head_dim), int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device)
q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device), : num_heads(num_heads), head_dim(head_dim),
kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device), q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device) kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
{ out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device) {
registerChildren registerChildren(q_linear, "q_linear")(kv_linear, "kv_linear")(out_proj, "out_proj");
(q_linear, "q_linear")
(kv_linear, "kv_linear")
(out_proj, "out_proj")
;
} }
Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt) { Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt) {
...@@ -155,22 +157,28 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens ...@@ -155,22 +157,28 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
assert(cu_seqlens_img.shape[0] == batch_size + 1); assert(cu_seqlens_img.shape[0] == batch_size + 1);
assert(cu_seqlens_txt.shape[0] == batch_size + 1); assert(cu_seqlens_txt.shape[0] == batch_size + 1);
Tensor q = q_linear.forward(x).view({batch_size * num_tokens_img, num_heads, head_dim}); Tensor q = q_linear.forward(x).view({batch_size * num_tokens_img, num_heads, head_dim});
Tensor kv = kv_linear.forward(cond).view({num_tokens_txt, num_heads * 2, head_dim}); Tensor kv = kv_linear.forward(cond).view({num_tokens_txt, num_heads * 2, head_dim});
Tensor k = kv.slice(1, 0, num_heads); Tensor k = kv.slice(1, 0, num_heads);
Tensor v = kv.slice(1, num_heads, num_heads * 2); Tensor v = kv.slice(1, num_heads, num_heads * 2);
Tensor attn_output = mha_varlen_fwd( Tensor attn_output = mha_varlen_fwd(q,
q, k, v, k,
cu_seqlens_img, cu_seqlens_txt, v,
num_tokens_img, num_tokens_txt, cu_seqlens_img,
0.0f, cu_seqlens_txt,
pow(q.shape[-1], (-0.5)), num_tokens_img,
false, false, num_tokens_txt,
-1, -1, 0.0f,
false pow(q.shape[-1], (-0.5)),
).front().view({batch_size, num_tokens_img, num_heads * head_dim}); false,
false,
-1,
-1,
false)
.front()
.view({batch_size, num_tokens_img, num_heads * head_dim});
// Tensor attn_output = mha_fwd(q, k, v, // Tensor attn_output = mha_fwd(q, k, v,
// 0.0f, // 0.0f,
...@@ -181,17 +189,13 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens ...@@ -181,17 +189,13 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
return out_proj.forward(attn_output); return out_proj.forward(attn_output);
} }
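Both cu_seqlens tensors follow the FlashAttention varlen convention: a [batch_size + 1] prefix sum of per-sample token counts, which is why the asserts above check shape[0] == batch_size + 1. Example construction (sketch):

    // Two samples with text lengths 77 and 32 -> {0, 77, 109}; the image
    // side with 4096 tokens each would be {0, 4096, 8192}.
    std::vector<int32_t> cu_seqlens_txt{0};
    for (int len : {77, 32})
        cu_seqlens_txt.push_back(cu_seqlens_txt.back() + len);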
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) : SanaGLUMBConv::SanaGLUMBConv(
in_features(in_features), hidden_features(hidden_features), int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device)
inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device), : in_features(in_features), hidden_features(hidden_features),
depth_conv(hidden_features * 2, true, dtype, device), inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
point_conv(hidden_features, in_features, false, use_fp4, dtype, device) depth_conv(hidden_features * 2, true, dtype, device),
{ point_conv(hidden_features, in_features, false, use_fp4, dtype, device) {
registerChildren registerChildren(inverted_conv, "inverted_conv")(depth_conv, "depth_conv")(point_conv, "point_conv");
(inverted_conv, "inverted_conv")
(depth_conv, "depth_conv")
(point_conv, "point_conv")
;
} }
Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) { Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
...@@ -203,33 +207,39 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) { ...@@ -203,33 +207,39 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
debug("inverted_conv_output", x); debug("inverted_conv_output", x);
x = depth_conv.forward(x); x = depth_conv.forward(x);
debug("depth_conv_output", x); debug("depth_conv_output", x);
x = x.view({x.shape[0], H * W, x.shape[-1]}); x = x.view({x.shape[0], H * W, x.shape[-1]});
auto qact = point_conv.quantize(x, true); auto qact = point_conv.quantize(x, true);
return point_conv.forward_quant(qact); return point_conv.forward_quant(qact);
} }
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) : SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size,
hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads), int intermediate_size,
attn(hidden_size, false, pag, use_fp4, dtype, device), int num_cross_attention_heads,
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device), bool pag,
ff(hidden_size, intermediate_size, use_fp4, dtype, device), bool use_fp4,
norm1(hidden_size, 1e-6, false, dtype, device), Tensor::ScalarType dtype,
norm2(hidden_size, 1e-6, false, dtype, device) Device device)
{ : hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
attn(hidden_size, false, pag, use_fp4, dtype, device),
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
ff(hidden_size, intermediate_size, use_fp4, dtype, device), norm1(hidden_size, 1e-6, false, dtype, device),
norm2(hidden_size, 1e-6, false, dtype, device) {
this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device); this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device);
registerChildren registerChildren(attn, "attn")(cross_attn, "cross_attn")(ff, "ff");
(attn, "attn")
(cross_attn, "cross_attn")
(ff, "ff")
;
registerParams registerParams(this->scale_shift_table, "scale_shift_table");
(this->scale_shift_table, "scale_shift_table")
;
} }
Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) { Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor timestep,
Tensor cu_seqlens_img,
Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg) {
nvtxRangePushA("SanaLinearTransformerBlock"); nvtxRangePushA("SanaLinearTransformerBlock");
...@@ -257,7 +267,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_ ...@@ -257,7 +267,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
{ {
nvtxRangePushA("LinearAttention"); nvtxRangePushA("LinearAttention");
Tensor residual = hidden_states; Tensor residual = hidden_states;
Tensor norm_hidden_states = norm1.forward(hidden_states); Tensor norm_hidden_states = norm1.forward(hidden_states);
kernels::mul_add_batch(norm_hidden_states, scale_msa, true, 1, shift_msa, true); kernels::mul_add_batch(norm_hidden_states, scale_msa, true, 1, shift_msa, true);
debug("norm_hidden_states_la", norm_hidden_states); debug("norm_hidden_states_la", norm_hidden_states);
...@@ -311,9 +321,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_ ...@@ -311,9 +321,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
return hidden_states; return hidden_states;
} }
SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) : SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) : config(config) {
config(config)
{
const int inner_dim = config.num_attention_heads * config.attention_head_dim; const int inner_dim = config.num_attention_heads * config.attention_head_dim;
for (int i = 0; i < config.num_layers; i++) { for (int i = 0; i < config.num_layers; i++) {
transformer_blocks.push_back(std::make_unique<SanaLinearTransformerBlock>( transformer_blocks.push_back(std::make_unique<SanaLinearTransformerBlock>(
...@@ -322,20 +330,34 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) ...@@ -322,20 +330,34 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
config.num_cross_attention_heads, config.num_cross_attention_heads,
std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(), std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
config.use_fp4, config.use_fp4,
dtype, device dtype,
)); device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i)); registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
} }
} }
Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) { Tensor SanaModel::forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor timestep,
Tensor cu_seqlens_img,
Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg,
bool skip_first_layer) {
for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) { for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
auto &&block = transformer_blocks[i]; auto &&block = transformer_blocks[i];
hidden_states = block->forward( hidden_states = block->forward(hidden_states,
hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W, encoder_hidden_states,
pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(), timestep,
cfg cu_seqlens_img,
); cu_seqlens_txt,
H,
W,
pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) !=
config.pag_layers.end(),
cfg);
} }
return hidden_states; return hidden_states;
} }
...@@ -35,7 +35,7 @@ public: ...@@ -35,7 +35,7 @@ public:
private: private:
GEMM_W4A4 q_linear; GEMM_W4A4 q_linear;
GEMM_F16 kv_linear; GEMM_F16 kv_linear;
GEMM_W4A4 out_proj; GEMM_W4A4 out_proj;
}; };
...@@ -57,9 +57,23 @@ private: ...@@ -57,9 +57,23 @@ private:
class SanaLinearTransformerBlock : public Module { class SanaLinearTransformerBlock : public Module {
public: public:
SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device); SanaLinearTransformerBlock(int hidden_size,
int intermediate_size,
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg); int num_cross_attention_heads,
bool pag,
bool use_fp4,
Tensor::ScalarType dtype,
Device device);
Tensor forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor timestep,
Tensor cu_seqlens_img,
Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg);
public: public:
const int hidden_size; const int hidden_size;
...@@ -89,11 +103,20 @@ struct SanaConfig { ...@@ -89,11 +103,20 @@ struct SanaConfig {
class SanaModel : public Module { class SanaModel : public Module {
public: public:
SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device); SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer); Tensor forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor timestep,
Tensor cu_seqlens_img,
Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg,
bool skip_first_layer);
public: public:
const SanaConfig config; const SanaConfig config;
public: public:
std::vector<std::unique_ptr<SanaLinearTransformerBlock>> transformer_blocks; std::vector<std::unique_ptr<SanaLinearTransformerBlock>> transformer_blocks;
}; };
\ No newline at end of file
...@@ -3,14 +3,13 @@ ...@@ -3,14 +3,13 @@
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include <mio/mmap.hpp> #include <mio/mmap.hpp>
using json = nlohmann::json; using json = nlohmann::json;
using spdlog::fmt_lib::format; using spdlog::fmt_lib::format;
class SafeTensors::MMapImpl { class SafeTensors::MMapImpl {
public: public:
virtual ~MMapImpl() {} virtual ~MMapImpl() {}
virtual size_t size() = 0; virtual size_t size() = 0;
virtual const char *data() = 0; virtual const char *data() = 0;
}; };
...@@ -55,7 +54,7 @@ private: ...@@ -55,7 +54,7 @@ private:
std::unique_ptr<Buffer> buffer; std::unique_ptr<Buffer> buffer;
}; };
#ifdef __linux__ #ifdef __linux__
#include <unistd.h> #include <unistd.h>
#include <fcntl.h> #include <fcntl.h>
...@@ -97,7 +96,7 @@ private: ...@@ -97,7 +96,7 @@ private:
void *ptr; void *ptr;
}; };
#else #else
class SafeTensors::MMapImplPrivate : public SafeTensors::MMapImpl { class SafeTensors::MMapImplPrivate : public SafeTensors::MMapImpl {
public: public:
@@ -117,33 +116,34 @@ public:
SafeTensors::SafeTensors(const std::string &filename) {
    this->hostRegistered = false;
    this->memoryPinned   = false;

    auto methodPrivate = [&]() {
        this->mapped = std::make_unique<MMapImplPrivate>(filename);
        checkCUDA(
            cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
        this->hostRegistered = true;
        this->memoryPinned   = true;
    };
    auto methodMio = [&]() {
        this->mapped = std::make_unique<MMapImplMio>(filename);
        checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()),
                                   this->mapped->size(),
                                   cudaHostRegisterPortable | cudaHostRegisterReadOnly));
        this->hostRegistered = true;
        this->memoryPinned   = true;
    };
    auto methodRead = [&]() {
        this->mapped       = std::make_unique<MMapImplRead>(filename, true);
        this->memoryPinned = true;
    };
    auto methodReadNopin = [&]() { this->mapped = std::make_unique<MMapImplRead>(filename, false); };

    const std::map<std::string, std::function<void()>> methods = {
        {"PRIVATE", methodPrivate},
        {"MIO", methodMio},
        {"READ", methodRead},
        {"READNOPIN", methodReadNopin},
    };

    auto tryMethod = [&](std::string name) {

@@ -168,7 +168,6 @@ SafeTensors::SafeTensors(const std::string &filename) {
#else
        tryMethod("MIO") || tryMethod("READ") || tryMethod("READNOPIN");
#endif
    }

    if (!this->mapped) {
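Note: the constructor above tries progressively more conservative mapping strategies (PRIVATE, then MIO, then READ, then READNOPIN) and keeps the first that does not throw. A self-contained sketch of that fallback pattern, separate from the exact tryMethod lambda in this file:

#include <functional>
#include <vector>

// Illustrative only: run each strategy in order, stop at the first success.
bool firstThatWorks(const std::vector<std::function<void()>> &strategies) {
    for (const auto &s : strategies) {
        try {
            s();
            return true;  // first successful strategy wins
        } catch (...) {
            // fall through to the next, more conservative strategy
        }
    }
    return false;
}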
@@ -192,19 +191,20 @@ SafeTensors::~SafeTensors() {
void SafeTensors::parseHeader() {
    static const std::unordered_map<std::string, Tensor::ScalarType> mapDType = {
        {"BF16", Tensor::BF16},
        {"F16", Tensor::FP16},
        {"F32", Tensor::FP32},
        {"I8", Tensor::INT8},
        {"I32", Tensor::INT32},
        {"I64", Tensor::INT64},
        {"F8_E4M3", Tensor::FP8_E4M3},
        {"F8_E5M2", Tensor::FP8_E5M2},
    };

    auto check = [](bool cond, std::source_location location = std::source_location::current()) {
        if (!cond) {
            throw std::runtime_error(
                format("Safetensors check failed at {}:{}", location.file_name(), location.line()));
        }
    };
@@ -222,8 +222,9 @@ void SafeTensors::parseHeader() {
            continue;
        }

        auto dtype        = mapDType.at(info["dtype"].get<std::string>());
        auto shape        = info["shape"].get<std::vector<int>>();
        auto data_offsets = info["data_offsets"].get<std::vector<uint64_t>>();
        check(data_offsets.size() == 2);
@@ -235,8 +236,8 @@ void SafeTensors::parseHeader() {
        }

        TensorInfo tinfo;
        tinfo.type   = dtype;
        tinfo.shape  = TensorShape(shape);
        tinfo.length = data_offsets[1] - data_offsets[0];
        tinfo.offset = 8 + sizeHeader + data_offsets[0];
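Note: the `8 + sizeHeader + data_offsets[0]` arithmetic reflects the safetensors on-disk layout: an 8-byte little-endian header length, then the JSON header, then the packed tensor data. A minimal sketch of reading that length prefix (assumes a little-endian host):

#include <cstdint>
#include <cstring>

// Reads the 8-byte little-endian JSON-header length at the start of a
// safetensors file; tensor data then begins at 8 + headerSize + data_offsets[0].
uint64_t headerSize(const char *fileStart) {
    uint64_t n;
    std::memcpy(&n, fileStart, sizeof(n));
    return n;
}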
@@ -258,15 +259,15 @@ Tensor SafeTensors::getTensor(const std::string &key) {
    std::shared_ptr<BufferMMap> buffer = info.buffer.lock();
    if (!buffer) {
        buffer = std::make_shared<BufferMMap>(
            const_cast<char *>(this->mapped->data() + info.offset), info.length, shared_from_this());
        info.buffer = buffer;
    }

    Tensor result;
    result.shape      = info.shape;
    result.scalarType = info.type;
    result.buffer     = buffer;
    return result;
}
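Note: a hedged usage sketch of this loader; the file path and tensor key below are hypothetical. SafeTensors must be owned by a shared_ptr, because getTensor calls shared_from_this:

auto st = std::make_shared<SafeTensors>("model.safetensors");  // hypothetical path
if (st->contains("blocks.0.qkv.weight")) {                     // hypothetical key
    Tensor w = st->getTensor("blocks.0.qkv.weight");
    // w aliases the memory-mapped file; BufferMMap's parent pointer keeps
    // `st` (and thus the mapping) alive for as long as the tensor exists.
}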
@@ -6,15 +6,15 @@
class BufferMMap : public Buffer {
public:
    BufferMMap(void *ptr, size_t size, std::shared_ptr<void> parent) : parent(parent) {
        this->size        = size;
        this->device.type = Device::CPU;
        this->ptr         = ptr;
        // auto ret = cudaHostRegister(ptr, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
        // if (ret == cudaSuccess) {
        //     this->registered = true;
        // } else {
        //     log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size,
        //         cudaGetErrorString(cudaGetLastError())));
        //     this->registered = false;
        // }
    }
    virtual ~BufferMMap() {
@@ -22,6 +22,7 @@ public:
        // checkCUDA(cudaHostUnregister(ptr));
        // }
    }

public:
    std::shared_ptr<void> parent;
    // bool registered;
@@ -32,7 +33,7 @@ public:
    SafeTensors(const std::string &filename);
    ~SafeTensors();

    virtual bool contains(const std::string &key) const override {
        return tensors.contains(key);
    }
    virtual Tensor getTensor(const std::string &key) override;
@@ -57,4 +58,4 @@ private:
    std::unique_ptr<MMapImpl> mapped;
    bool hostRegistered, memoryPinned;
};
\ No newline at end of file
@@ -3,13 +3,10 @@
#include "common.h"

struct Device {
    enum Type { INVALID_DEVICE_TYPE = 0, CPU, CUDA };

    Type type = INVALID_DEVICE_TYPE;
    int idx   = 0;

    static constexpr Device cpu(int idx = 0) {
        return Device{CPU, idx};
@@ -23,21 +20,29 @@ struct Device {
class Buffer : public std::enable_shared_from_this<Buffer> {
public:
    virtual ~Buffer() {}

    void *getPtr() {
        return ptr;
    }
    template<typename T>
    T *getPtr() {
        return reinterpret_cast<T *>(ptr);
    }
    size_t getSize() {
        return size;
    }
    Device getDevice() {
        return device;
    }

    virtual bool isAsyncBuffer() {
        return false;
    }

protected:
    template<typename Derived>
    std::shared_ptr<Derived> shared_from_base() {
        return std::static_pointer_cast<Derived>(shared_from_this());
    }
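Note: shared_from_base is the usual trick for recovering a correctly typed shared_ptr to a derived object from a base that inherits enable_shared_from_this. An illustrative (hypothetical) subclass:

// Hypothetical subclass for illustration only; calling self() is valid only
// once the object is already managed by a std::shared_ptr.
class ScratchBuffer : public Buffer {
public:
    std::shared_ptr<ScratchBuffer> self() {
        return shared_from_base<ScratchBuffer>();  // static downcast, no dynamic_cast
    }
};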
@@ -55,9 +60,9 @@ protected:
class BufferMalloc : public Buffer {
public:
    BufferMalloc(size_t size) {
        this->size        = size;
        this->device.type = Device::CPU;
        this->ptr         = malloc(size);
    }
    virtual ~BufferMalloc() {
        free(this->ptr);
@@ -67,7 +72,7 @@ public:
class BufferHost : public Buffer {
public:
    BufferHost(size_t size) {
        this->size        = size;
        this->device.type = Device::CPU;
        checkCUDA(cudaHostAlloc(&this->ptr, size, cudaHostAllocPortable));
    }
@@ -79,7 +84,7 @@ public:
class BufferCUDA : public Buffer {
public:
    BufferCUDA(size_t size) {
        this->size        = size;
        this->device.type = Device::CUDA;
        // checkCUDA(cudaGetDevice(&this->device.idx));
        this->device.idx = CUDADeviceContext::getDevice();

@@ -96,7 +101,7 @@ public:
        }
        checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
    }

    virtual bool isAsyncBuffer() override {
        return true;
    }
};
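Note: BufferCUDA frees with cudaFreeAsync, i.e. stream-ordered allocation (CUDA 11.2+), so the free is sequenced after previously enqueued work on the same stream. A standalone sketch of the pattern, independent of this class:

#include <cuda_runtime.h>

// Illustrative only: allocate, use, and free scratch memory entirely in
// stream order; nothing here blocks the host.
void streamOrderedScratch(cudaStream_t stream, size_t bytes) {
    void *p = nullptr;
    cudaMallocAsync(&p, bytes, stream);
    // ... enqueue kernels that read/write p on `stream` here ...
    cudaFreeAsync(p, stream);  // ordered after the kernels above
}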
@@ -104,7 +109,7 @@ public:
class BufferCUDASync : public Buffer {
public:
    BufferCUDASync(size_t size) {
        this->size        = size;
        this->device.type = Device::CUDA;
        checkCUDA(cudaGetDevice(&this->device.idx));
        checkCUDA(cudaMalloc(&this->ptr, size));
@@ -118,8 +123,8 @@ class BufferView : public Buffer {
public:
    BufferView(std::shared_ptr<Buffer> reference, size_t offset, size_t size) : reference(reference) {
        assert(offset + size <= reference->getSize());
        this->ptr    = (void *)((std::uint8_t *)reference->getPtr() + offset);
        this->size   = size;
        this->device = reference->getDevice();
    }
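Note: the `reference` member is what makes a BufferView safe: the view owns a shared_ptr to its parent, so the underlying storage cannot be freed while any view is alive. Sketch:

std::shared_ptr<Buffer> whole = std::make_shared<BufferMalloc>(1024);
std::shared_ptr<Buffer> part  = std::make_shared<BufferView>(whole, 256, 128);
whole.reset();  // the 1 KiB allocation survives: `part` still references it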
@@ -213,23 +218,31 @@ struct TensorShape {
    }
};

class Tensor {
public:
    enum ScalarType {
        INVALID_SCALAR_TYPE,
        INT8,
        INT16,
        INT32,
        INT64,
        FP16,
        FP32,
        BF16,
        FP8_E4M3,
        FP8_E5M2,
    };

    struct TensorOptions {
        Device device_;
        ScalarType dtype_;

        Device device() const {
            return device_;
        }
        ScalarType dtype() const {
            return dtype_;
        }

        TensorOptions device(Device dev) const {
            TensorOptions result(*this);

@@ -244,56 +257,95 @@ public:
    };

    static const std::map<ScalarType, size_t> scalarSize;
public:
    TensorShape shape;
    ScalarType scalarType;
    std::shared_ptr<Buffer> buffer;

public:
    bool valid() const {
        return shape.dataExtent.size() > 0;
    }
    int size(int dim) const {
        return shape[dim];
    }
    bool is_contiguous() const {
        return shape.is_contiguous();
    }
    std::vector<int> sizes() const {
        return shape.dataExtent;
    }

    bool is_cuda() const {
        return device().type == Device::CUDA;
    }

    TensorOptions options() const {
        return TensorOptions{device(), dtype()};
    }
    int get_device() const {
        return device().idx;
    }

    template<typename T>
    T *data_ptr() {
        return reinterpret_cast<T *>(data_ptr());
    }
    template<typename T>
    const T *data_ptr() const {
        return reinterpret_cast<const T *>(data_ptr());
    }

    const void *data_ptr() const {
        return buffer->getPtr<char>() + shape.offset * scalar_size();
    }
    void *data_ptr() {
        return buffer->getPtr<char>() + shape.offset * scalar_size();
    }

    Device device() const {
        return buffer->getDevice();
    }
    ScalarType scalar_type() const {
        return scalarType;
    }
    ScalarType dtype() const {
        return scalar_type();
    }

    size_t stride(int dim) const {
        return shape.stride(dim);
    }
    size_t numel() const {
        return shape.size();
    }
    size_t ndims() const {
        return shape.ndims();
    }
    size_t dim() const {
        return ndims();
    }
    size_t scalar_size() const {
        return scalarSize.at(scalarType);
    }

    Tensor operator[](int idx) const {
        assert(ndims() > 1);
        Tensor result;
        result.shape      = std::vector<int>(this->shape.dataExtent.begin() + 1, this->shape.dataExtent.end());
        size_t size       = stride(0) * scalar_size();
        result.buffer     = std::make_shared<BufferView>(this->buffer, idx * size, size);
        result.scalarType = this->scalarType;
        return result;
    }
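Note: operator[] is a zero-copy row view: it drops the leading dimension and wraps the same storage in a BufferView. A hedged sketch; the `ones` signature and `Device::cuda()` naming are assumptions about the surrounding API, not taken from this hunk:

Tensor batch = Tensor::ones({8, 4096}, Tensor::FP16, Device::cuda());  // assumed helper
Tensor row0  = batch[0];  // shape {4096}; shares batch's storage
row0.zero_();             // clears only the first row of `batch`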
    template<typename T>
    const T &at(const std::vector<int> &idx) const {
        assert(ndims() == idx.size());
        int64_t offset = 0;
        for (size_t i = 0; i < ndims(); i++) {

@@ -304,17 +356,17 @@ public:
    }

    template<typename T>
    T &at(const std::vector<int> &idx) {
        return const_cast<T &>(const_cast<const Tensor *>(this)->at<T>(idx));
    }
    Tensor slice(int dim, int from, int to) const {
        assert(from <= to);
        Tensor result;
        result.buffer     = this->buffer;
        result.scalarType = this->scalarType;
        result.shape      = TensorShape(this->shape.dataExtent);
        result.shape[dim] = to - from;
        result.shape.dataStride.resize(result.shape.ndims());
        for (int i = 0; i < result.shape.ndims(); i++) {

@@ -326,7 +378,7 @@ public:
    }

    Tensor transpose(int dim1, int dim2) const {
        Tensor result;
        result.buffer     = this->buffer;
        result.scalarType = this->scalarType;
        result.shape      = TensorShape(this->shape.dataExtent);

@@ -346,9 +398,9 @@ public:
        assert(shape.size() == this->shape.size());
        assert(this->is_contiguous());
        Tensor result;
        result.buffer       = this->buffer;
        result.scalarType   = this->scalarType;
        result.shape        = shape;
        result.shape.offset = this->shape.offset;
        return result;
    }
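Note: slice, transpose, and view above all return aliases: the buffer pointer is shared and only the shape metadata (extent, stride, offset) changes. A hedged sketch, with the same assumed helpers as above:

Tensor t  = Tensor::ones({4, 6}, Tensor::FP32, Device::cpu());  // assumed helper
Tensor s  = t.slice(1, 2, 5);   // shape {4, 3}; strided, so not contiguous
Tensor tr = t.transpose(0, 1);  // shape {6, 4}; strides swapped, no data moved
Tensor v  = t.view({24});       // flat alias; requires t.is_contiguous()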
@@ -363,7 +415,8 @@ public:
    Tensor &zero_() {
        assert(this->is_contiguous());
        checkCUDA(cudaMemsetAsync(
            data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
        return *this;
    }
    Tensor &copy_(Tensor other) {

@@ -380,23 +433,17 @@ public:
        }

        if (this->device().type == Device::CPU && other.device().type == Device::CPU) {
            memcpy(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size());
            return *this;
        }

        lockBuffer(this->buffer, getCurrentCUDAStream());
        lockBuffer(other.buffer, getCurrentCUDAStream());

        checkCUDA(cudaMemcpyAsync(data_ptr<char>(),
                                  other.data_ptr<char>(),
                                  shape.size() * scalar_size(),
                                  getCopyKind(this->device(), other.device()),
                                  getCurrentCUDAStream()));
        return *this;
    }
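Note: getCopyKind is not shown in this hunk; a plausible mapping from the (dst, src) device pair to a cudaMemcpyKind would look like the following. This is an assumption for illustration, not the repo's definition:

#include <cuda_runtime.h>

// Hypothetical reconstruction; the real getCopyKind lives elsewhere in the repo.
static cudaMemcpyKind getCopyKindSketch(Device dst, Device src) {
    if (dst.type == Device::CUDA && src.type == Device::CUDA)
        return cudaMemcpyDeviceToDevice;
    if (dst.type == Device::CUDA)
        return cudaMemcpyHostToDevice;
    if (src.type == Device::CUDA)
        return cudaMemcpyDeviceToHost;
    return cudaMemcpyHostToHost;
}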
@@ -425,14 +472,15 @@ public:
            assert(false);
        }
        result.scalarType = scalarType;
        result.shape      = shape;

        if (fill) {
            if (device.type == Device::CPU) {
                memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
            } else if (device.type == Device::CUDA) {
                CUDADeviceContext ctx(device.idx);
                checkCUDA(
                    cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
            }
        }
@@ -450,11 +498,12 @@ public:
        checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentCUDAStream()));
        return result;
    }

    static Tensor
    allocate_view(TensorShape shape, ScalarType scalarType, std::shared_ptr<Buffer> buffer, size_t offset = 0) {
        Tensor result;
        result.buffer     = std::make_shared<BufferView>(buffer, offset, shape.size() * scalarSize.at(scalarType));
        result.scalarType = scalarType;
        result.shape      = shape;
        return result;
    }
@@ -468,13 +517,16 @@ public:
        // lockBuffer(this->buffer, getCurrentCUDAStream());
        // lockBuffer(result.buffer, getCurrentCUDAStream());
        // checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault,
        //           getCurrentCUDAStream()));
        // if (this->device().type == Device::CPU && device.type == Device::CUDA) {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
        //               cudaMemcpyHostToDevice, getCurrentCUDAStream()));
        // } else if (this->device().type == Device::CUDA && device.type == Device::CPU) {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
        //               cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
        // } else {
        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
        //               cudaMemcpyDefault, getCurrentCUDAStream()));
        // }
        return result;
    }
@@ -516,9 +568,10 @@ private:
    // }

    static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;

public:
    // Before launching an async operation, lock the buffer so it cannot be
    // freed before the GPU completes.
    static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
        if (!buffer->isAsyncBuffer()) {
            lockedBuffers[stream].insert(buffer);
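Note: the per-stream lock exists because cudaMemcpyAsync only enqueues work; without it, the last host-side shared_ptr could die while the GPU is still reading the buffer. A self-contained sketch of the hold-until-drained idea (not this class's exact private API):

#include <cuda_runtime.h>
#include <map>
#include <memory>
#include <set>

// Illustrative only: keep owners alive until the stream has drained.
static std::map<cudaStream_t, std::set<std::shared_ptr<void>>> g_locked;

void lockUntilDrained(std::shared_ptr<void> owner, cudaStream_t stream) {
    g_locked[stream].insert(std::move(owner));  // keep alive past the enqueue
}

void drainAndRelease(cudaStream_t stream) {
    cudaStreamSynchronize(stream);  // all queued copies/kernels finished
    g_locked.erase(stream);         // now safe to drop the references
}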
@@ -558,5 +611,5 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
struct TensorsProvider {
    virtual ~TensorsProvider() {}

    virtual bool contains(const std::string &key) const = 0;
    virtual Tensor getTensor(const std::string &key)    = 0;
};
\ No newline at end of file
@@ -22,13 +22,15 @@ Tensor GELU::forward(Tensor x) {
//     return out;
// }

// Tensor SiluAndMulQuant::forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer,
//                                              Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
//     Tensor out = SiluAndMul::forward(x);
//     invoke_quant_fuse_sum(quantized_mlp_act_buffer, out, quantized_sum_buffer, quantized_scale_buffer);
//     return out;
// }

// Tensor SiluAndMulQuant::forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer,
//                                            Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
//     Tensor out = SiluAndMul::forward(x);
//     invoke_quant(quantized_mlp_act_buffer, out, quantized_scale_buffer, {});
//     return out;
...