Unverified commit f060b8da authored by Muyang Li, committed by GitHub

[Major] Release v0.1.4

Support a 4-bit text encoder and per-layer CPU offloading, reducing FLUX's minimum memory requirement to just 4 GiB while maintaining a 2–3× speedup. Fix various issues related to resolution, LoRA, pinned memory, and runtime stability. Check out the release notes for full details!
parents f549dfc6 873a35be
from .text_encoders.t5_encoder import NunchakuT5EncoderModel
from .transformers import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel
# -*- coding: utf-8 -*-
"""TinyChat Quantized Linear Module"""
import torch
import torch.nn as nn
from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
from ..._C.ops import gemm_awq, gemv_awq
__all__ = ["W4Linear"]
class W4Linear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
group_size: int = 128,
dtype: torch.dtype = torch.float16,
device: str | torch.device = "cuda",
):
super().__init__()
assert dtype in (torch.float16, torch.bfloat16), f"Unsupported dtype: {dtype}"
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.weight_bits) == 0
self.ceil_num_groups = ceil_num_groups(
in_features=self.in_features,
group_size=self.group_size,
weight_bits=self.weight_bits,
)
assert out_features % (self.interleave) == 0
self.register_buffer(
"qweight",
torch.zeros(
(
self.out_features // self.interleave,
self.in_features // (16 // self.weight_bits) * self.interleave,
),
dtype=torch.int16,
device=device,
),
)
self.register_buffer(
"scales",
torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
)
self.register_buffer(
"scaled_zeros",
torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
)
if bias:
self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
else:
self.bias = None
@property
def weight_bits(self) -> int:
return 4
@property
def interleave(self) -> int:
return 4
@torch.no_grad()
def forward(self, x):
if x.numel() / x.shape[-1] < 8:
out = gemv_awq(
x,
self.qweight,
self.scales,
self.scaled_zeros,
x.numel() // x.shape[-1],
self.out_features,
self.in_features,
self.group_size,
)
else:
out = gemm_awq(x, self.qweight, self.scales, self.scaled_zeros)
out = out + self.bias if self.bias is not None else out
return out
@staticmethod
def from_linear(
linear: nn.Linear,
group_size: int,
init_only: bool = False,
weight: torch.Tensor | None = None,
scale: torch.Tensor | None = None,
zero: torch.Tensor | None = None,
zero_pre_scaled: bool = False,
) -> "W4Linear":
"""Convert a linear layer to a TinyChat 4-bit weight-only quantized linear layer.
Args:
linear (`nn.Linear`):
linear layer to be converted.
group_size (`int`):
quantization group size.
init_only (`bool`, *optional*, defaults to `False`):
whether to only initialize the quantized linear layer.
weight (`torch.Tensor`, *optional*, defaults to `None`):
weight tensor for the quantized linear layer.
scale (`torch.Tensor`, *optional*, defaults to `None`):
scale tensor for the quantized linear layer.
zero (`torch.Tensor`, *optional*, defaults to `None`):
zero point tensor for the quantized linear layer.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`W4Linear`:
quantized linear layer.
"""
assert isinstance(linear, nn.Linear)
weight = linear.weight.data if weight is None else weight.data
dtype, device = weight.dtype, weight.device
oc, ic = linear.out_features, linear.in_features
_linear = W4Linear(
in_features=ic,
out_features=oc,
bias=linear.bias is not None,
group_size=group_size,
dtype=dtype,
device=device,
)
if init_only:
return _linear
if linear.bias is not None:
_linear.bias.data.copy_(linear.bias.data)
if scale is None:
assert zero is None, "scale and zero point tensors should be provided together."
group_size = ic if group_size <= 0 else group_size
assert group_size <= ic, "group size should be less than or equal to input channel size."
assert ic % group_size == 0, "input channel size should be divisible by group size."
ng, gs = ic // group_size, group_size
weight = weight.to(dtype=torch.float32).view(oc, 1, ng, gs)
vmin, vmax = weight.amin(dim=-1, keepdim=True), weight.amax(dim=-1, keepdim=True)
scale = (vmax - vmin).div_(15)
scale[scale == 0] = 1.0
if zero_pre_scaled:
zero = vmin.neg_().div_(scale).round_().clamp_(0, 15)
weight = weight.div_(scale).add_(zero).round_().clamp_(0, 15).sub_(zero).mul_(scale)
else:
zero = vmin.neg_().clamp_min(0)
weight = weight.add_(zero).div_(scale).round_().clamp_(0, 15).mul_(scale).sub_(zero)
weight = weight.to(dtype=dtype).view(oc, ic)
scale = scale.to(dtype=dtype)
zero = zero.to(dtype=dtype)
weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight(
weight=weight,
scale=scale,
zero=zero,
group_size=group_size,
zero_pre_scaled=zero_pre_scaled,
)
_linear.qweight.data.copy_(weight)
_linear.scales.data.copy_(scale)
_linear.scaled_zeros.data.copy_(zero)
return _linear
def extra_repr(self) -> str:
return "in_features={}, out_features={}, bias={}, weight_bits={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.weight_bits,
self.group_size,
)
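# Illustrative usage sketch (not part of the original module): convert an fp16
# nn.Linear into a TinyChat 4-bit layer and run a forward pass. The layer
# sizes and group size below are arbitrary; a CUDA device and the compiled
# `nunchaku._C` ops are assumed to be available.
if __name__ == "__main__":
    fp16_linear = nn.Linear(4096, 4096, bias=True).half().cuda()
    w4_linear = W4Linear.from_linear(fp16_linear, group_size=128)
    x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
    y = w4_linear(x)  # same shape as fp16_linear(x), computed by the AWQ kernels
    print(y.shape, w4_linear)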
import os
import torch
from deepcompressor.backend.tinychat.linear import W4Linear
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from torch import nn
from transformers import PretrainedConfig, T5EncoderModel
from .linear import W4Linear
def quantize_t5_encoder(
t5_encoder: nn.Module,
......
# -*- coding: utf-8 -*-
"""TinyChat backend utilities."""
import torch
__all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"]
def ceil_divide(x: int, divisor: int) -> int:
"""Ceiling division.
Args:
x (`int`):
dividend.
divisor (`int`):
divisor.
Returns:
`int`:
ceiling division result.
"""
return (x + divisor - 1) // divisor
def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
"""Calculate the ceiling number of quantization groups.
Args:
in_features (`int`):
input channel size.
group_size (`int`):
quantization group size.
weight_bits (`int`, *optional*, defaults to `4`):
quantized weight bits.
Returns:
`int`:
ceiling number of quantization groups.
"""
assert in_features % group_size == 0, "input channel size should be divisible by group size."
num_groups = in_features // group_size
assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1."
pack_size = 32 // weight_bits # one INT32 contains `pack_size` elements of weights
num_packs = ceil_divide(num_groups, pack_size)
if group_size >= 128:
num_packs_factor = 1
elif group_size == 64:
num_packs_factor = 2
elif group_size == 32:
num_packs_factor = 4
else:
raise NotImplementedError
# make sure num_packs is a multiple of num_packs_factor
num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor
num_groups = num_packs * pack_size
return num_groups
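# Worked example (illustrative): ceil_num_groups(in_features=1536, group_size=64)
# -> num_groups = 1536 // 64 = 24, pack_size = 32 // 4 = 8,
#    num_packs = ceil(24 / 8) = 3, num_packs_factor = 2 (group_size == 64),
#    padded num_packs = 4, so the function returns 4 * 8 = 32 groups.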
def pack_w4(weight: torch.Tensor) -> torch.Tensor:
"""Pack 4-bit quantized weights (integer values in [0, 15], stored as int32) into the interleaved int16 layout used by the TinyChat AWQ GEMM/GEMV kernels.
Args:
weight (`torch.Tensor`):
quantized weight tensor of shape `(out_features, in_features)`.
Returns:
`torch.Tensor`:
packed int16 weight tensor of shape `(out_features // 4, in_features)`.
"""
assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
oc, ic = weight.shape
assert ic % 32 == 0, "input channel size should be divisible by 32."
# [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31]
weight = weight.view(-1, 4, 8)
weight = weight[:, 0] | (weight[:, 1] << 4) | (weight[:, 2] << 8) | (weight[:, 3] << 12)
weight = weight.view(oc // 4, 4, ic // 64, 16).permute(0, 2, 1, 3).reshape(oc // 4, ic)
return weight.to(torch.int16)
def convert_to_tinychat_w4x16y16_linear_weight(
weight: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor,
group_size: int = -1,
zero_pre_scaled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert a weight tensor to TinyChat W4-X16-Y16 linear weight format.
Args:
weight (`torch.Tensor`):
weight tensor to be converted.
scale (`torch.Tensor`):
scale tensor for the weight tensor.
zero (`torch.Tensor`):
zero point tensor for the weight tensor.
group_size (`int`, *optional*, defaults to `-1`):
quantization group size.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`tuple[torch.Tensor, torch.Tensor, torch.Tensor]`:
packed quantized weight tensor, scale tensor, and zero point tensor.
"""
dtype, device = weight.dtype, weight.device
assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."
assert scale is not None, "scale tensor is required for quantization."
assert zero is not None, "zero point tensor is required for quantization."
weight = weight.to(dtype=torch.float32)
scale = scale.to(dtype=torch.float32, device=device)
zero = zero.to(dtype=torch.float32, device=device)
if zero_pre_scaled:
zero = zero * scale
oc, ic = weight.shape
group_size = ic if group_size <= 0 else group_size
assert group_size <= ic, "group size should be less than or equal to input channel size."
assert ic % group_size == 0, "input channel size should be divisible by group size."
ng = ic // group_size
if scale.numel() == 1:
scale = scale.view(1, 1).expand(oc, ng)
scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1)
if zero.numel() == 1:
zero = zero.view(1, 1).expand(oc, ng)
zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1)
weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic)
assert weight.min() >= 0 and weight.max() <= 15, "quantized weight should be in [0, 15]."
_weight = pack_w4(weight.to(torch.int32))
_ng = ceil_num_groups(ic, group_size, weight_bits=4)
_scale = torch.zeros((_ng, oc), dtype=dtype, device=device)
_zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
_scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
_zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
return _weight, _scale, _zero
\ No newline at end of file
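# Layout note (for reference): `pack_w4` yields an int16 tensor of shape
# (out_features // 4, in_features); the scale and zero tensors are transposed
# and padded to (ceil_num_groups, out_features), with the zero points negated
# (and pre-multiplied by the scales when `zero_pre_scaled=True`) so they can be
# copied directly into `W4Linear`'s `scales` / `scaled_zeros` buffers.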
from .transformer_flux import NunchakuFluxTransformer2dModel
from .transformer_sana import NunchakuSanaTransformer2DModel
......@@ -8,9 +8,10 @@ from huggingface_hub import utils
from packaging.version import Version
from torch import nn
from nunchaku.utils import fetch_or_download
from .utils import NunchakuModelLoaderMixin, pad_tensor
from .._C import QuantizedFluxModel, utils as cutils
from ..utils import fetch_or_download
from ..._C import QuantizedFluxModel, utils as cutils
from ...utils import load_state_dict_in_safetensors
SVD_RANK = 32
......@@ -88,8 +89,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
else:
out = out.view(batch_size, -1, dim // 2, 1, 1)
# stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
# out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
......@@ -108,12 +107,14 @@ class EmbedND(nn.Module):
return emb.unsqueeze(1)
def load_quantized_module(path: str, device: str | torch.device = "cuda", use_fp4: bool = False) -> QuantizedFluxModel:
def load_quantized_module(
path: str, device: str | torch.device = "cuda", use_fp4: bool = False, offload: bool = False
) -> QuantizedFluxModel:
device = torch.device(device)
assert device.type == "cuda"
m = QuantizedFluxModel()
cutils.disable_memory_auto_release()
m.init(use_fp4, True, 0 if device.index is None else device.index)
m.init(use_fp4, offload, True, 0 if device.index is None else device.index)
m.load(path)
return m
......@@ -147,19 +148,49 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
guidance_embeds=guidance_embeds,
axes_dims_rope=axes_dims_rope,
)
self.unquantized_loras = {}
self.unquantized_state_dict = None
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
device = kwargs.get("device", "cuda")
precision = kwargs.get("precision", "int4")
offload = kwargs.get("offload", False)
assert precision in ["int4", "fp4"]
transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4")
m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4", offload=offload)
transformer.inject_quantized_module(m, device)
return transformer
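# Illustrative usage (a hedged sketch, not part of this file): the new
# `offload` flag enables per-layer CPU offloading of the quantized FLUX
# transformer. The transformer path below is a placeholder; the FLUX base
# model id is shown only for illustration.
#
#   transformer = NunchakuFluxTransformer2dModel.from_pretrained(
#       "<path-or-repo-of-the-int4-flux-transformer>", precision="int4", offload=True
#   )
#   pipeline = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
#   ).to("cuda")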
def update_unquantized_lora_params(self, strength: float = 1):
new_state_dict = {}
for k in self.unquantized_state_dict.keys():
v = self.unquantized_state_dict[k]
if k.replace(".weight", ".lora_B.weight") in self.unquantized_loras:
new_state_dict[k] = v + strength * (
self.unquantized_loras[k.replace(".weight", ".lora_B.weight")]
@ self.unquantized_loras[k.replace(".weight", ".lora_A.weight")]
)
else:
new_state_dict[k] = v
self.load_state_dict(new_state_dict, strict=True)
def update_lora_params(self, path: str):
state_dict = load_state_dict_in_safetensors(path)
unquantized_loras = {}
for k in state_dict.keys():
if "transformer_blocks" not in k:
unquantized_loras[k] = state_dict[k]
self.unquantized_loras = unquantized_loras
if len(unquantized_loras) > 0:
if self.unquantized_state_dict is None:
unquantized_state_dict = self.state_dict()
self.unquantized_state_dict = {k: v.cpu() for k, v in unquantized_state_dict.items()}
self.update_unquantized_lora_params(1)
path = fetch_or_download(path)
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
......@@ -169,6 +200,8 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.setLoraScale(SVD_RANK, strength)
if len(self.unquantized_loras) > 0:
self.update_unquantized_lora_params(strength)
def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
......
......@@ -9,7 +9,7 @@ from huggingface_hub import utils
from torch import nn
from .utils import NunchakuModelLoaderMixin
from .._C import QuantizedSanaModel, utils as cutils
from ..._C import QuantizedSanaModel, utils as cutils
SVD_RANK = 32
......
import torch
from diffusers import FluxPipeline
from .models.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
if __name__ == "__main__":
capability = torch.cuda.get_device_capability(0)
......
......@@ -5,11 +5,11 @@ import torch
from huggingface_hub import hf_hub_download
def fetch_or_download(path: str) -> str:
def fetch_or_download(path: str, repo_type: str = "model") -> str:
if not os.path.exists(path):
hf_repo_id = os.path.dirname(path)
filename = os.path.basename(path)
path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
path = hf_hub_download(repo_id=hf_repo_id, filename=filename, repo_type=repo_type)
return path
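# Usage sketch (hypothetical identifiers): a path that does not exist locally,
# e.g. "some-org/some-repo/model.safetensors", is split into
# repo_id="some-org/some-repo" and filename="model.safetensors" and downloaded
# from the Hugging Face Hub with the given `repo_type`; paths that already
# exist on disk are returned unchanged.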
......
......@@ -21,4 +21,4 @@ dependencies = [
"protobuf",
"huggingface_hub",
]
requires-python = ">=3.11, <3.13"
requires-python = ">=3.10, <3.13"
#!/bin/bash
# Modified from https://github.com/sgl-project/sglang/blob/main/sgl-kernel/build.sh
set -ex
PYTHON_VERSION=$1
TORCH_VERSION=$2
CUDA_VERSION=$3
MAX_JOBS=${4:-} # optional
PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}
docker run --rm \
-v "$(pwd)":/nunchaku \
pytorch/manylinux-builder:cuda${CUDA_VERSION} \
bash -c "
cd /nunchaku && \
rm -rf build && \
yum install -y devtoolset-11 && \
source scl_source enable devtoolset-11 && \
gcc --version && g++ --version && \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==${TORCH_VERSION} numpy --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
${PYTHON_ROOT_PATH}/bin/pip install build ninja wheel setuptools && \
export NUNCHAKU_INSTALL_MODE=ALL && \
export NUNCHAKU_BUILD_WHEELS=1 && \
export MAX_JOBS=${MAX_JOBS} && \
${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
"
\ No newline at end of file
#!/bin/bash
set -ex
#docker run --rm \
# -v "$(pwd)":/nunchaku \
# pytorch/manylinux-builder:cuda12.4 \
# bash -c "cd /nunchaku && rm -r *"
docker run --rm -it \
-v "$(pwd)":/nunchaku \
pytorch/manylinux-builder:cuda12.4 \
bash
\ No newline at end of file
......@@ -77,7 +77,7 @@ if __name__ == "__main__":
"third_party/Block-Sparse-Attention/csrc/block_sparse_attn",
]
INCLUDE_DIRS = [ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS]
INCLUDE_DIRS = [os.path.join(ROOT_DIR, dir) for dir in INCLUDE_DIRS]
DEBUG = False
......@@ -93,8 +93,13 @@ if __name__ == "__main__":
else:
return []
sm_targets = get_sm_targets()
print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
assert len(sm_targets) > 0, "No SM targets found"
GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG", "/Zc:__cplusplus"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG", "/Zc:__cplusplus", "/FS"]
NVCC_FLAGS = [
"-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1",
......@@ -112,7 +117,7 @@ if __name__ == "__main__":
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--threads=3",
f"--threads={len(sm_targets)}",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--ptxas-options=--allow-expensive-optimizations=true",
......@@ -121,15 +126,10 @@ if __name__ == "__main__":
if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
NVCC_FLAGS.append("--generate-line-info")
sm_targets = get_sm_targets()
print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
assert len(sm_targets) > 0, "No SM targets found"
for target in sm_targets:
NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]
NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]
NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus", "-Xcompiler", "/FS"]
nunchaku_extension = CUDAExtension(
name="nunchaku._C",
......@@ -164,6 +164,7 @@ if __name__ == "__main__":
"src/kernels/dwconv.cu",
"src/kernels/gemm_batched.cu",
"src/kernels/gemm_f16.cu",
"src/kernels/awq/gemm_awq.cu",
"src/kernels/awq/gemv_awq.cu",
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"),
......
......@@ -607,14 +607,22 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
return { hidden_states, encoder_hidden_states };
}
FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : offload(offload) {
for (int i = 0; i < 19; i++) {
transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
if (offload && i > 0) { // don't offload first block
transformer_blocks.back()->setLazyLoad(true);
transformer_blocks.back()->releaseLazyParams();
}
}
for (int i = 0; i < 38; i++) {
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
if (offload) {
single_transformer_blocks.back()->setLazyLoad(true);
single_transformer_blocks.back()->releaseLazyParams();
}
}
}
......@@ -626,22 +634,51 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
const int txt_tokens = encoder_hidden_states.shape[1];
const int img_tokens = hidden_states.shape[1];
for (auto &&block : transformer_blocks) {
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
}
const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
// txt first, same as diffusers
Tensor concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
for (int i = 0; i < batch_size; i++) {
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
}
hidden_states = concat;
encoder_hidden_states = {};
Tensor concat;
for (auto &&block : single_transformer_blocks) {
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
}
auto compute = [&](int layer) {
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
} else {
if (size_t(layer) == transformer_blocks.size()) {
// txt first, same as diffusers
concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
for (int i = 0; i < batch_size; i++) {
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
}
hidden_states = concat;
encoder_hidden_states = {};
}
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
}
};
auto load = [&](int layer) {
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
block->loadLazyParams();
} else {
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
block->loadLazyParams();
}
};
auto unload = [&](int layer) {
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
block->releaseLazyParams();
} else {
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
block->releaseLazyParams();
}
};
LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
helper.run();
return hidden_states;
}
\ No newline at end of file
......@@ -128,10 +128,13 @@ private:
class FluxModel : public Module {
public:
FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device);
FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
public:
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
private:
bool offload;
};
\ No newline at end of file
......@@ -16,7 +16,7 @@ GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::Sca
this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
registerParams
(weight, "weight")
(weight, "weight", ParamFlags::LazyLoad)
(bias, "bias")
;
}
......@@ -27,7 +27,7 @@ Tensor GEMM_F16::forward(Tensor x) {
}
GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f)
in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f), device(device)
{
this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
......@@ -39,7 +39,7 @@ GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::Sca
this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
registerParams
(qweight, "qweight")
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(wzeros, "wzeros")
(bias, "bias")
......@@ -52,7 +52,7 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->qweight.device());
dst = src.copy(this->device);
if (key == "lora_down") {
const int new_rank = dst.shape[0];
this->lora_rank = new_rank;
......@@ -100,7 +100,7 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
in_features(in_features), out_features(out_features),
in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
use_fp4(use_fp4),
lora_rank(0), dtype(dtype)
lora_rank(0), dtype(dtype), device(device)
{
this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
if (use_fp4) {
......@@ -124,7 +124,7 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
this->wcscales = Tensor::allocate({0}, dtype, device, true);
registerParams
(qweight, "qweight")
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(this->bias, "bias")
(lora_down, "lora_down", ParamFlags::Optional)
......@@ -143,7 +143,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->qweight.device());
dst = src.copy(this->device);
this->lora_rank = dst.shape[1];
this->lora_scales.resize(ceilDiv(this->lora_rank, 16), 1.0f);
} else {
......@@ -152,7 +152,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else if (key == "wcscales") {
assert(src.ndims() == 1);
assert(src.shape[0] == out_features_pad);
dst = src.copy(this->qweight.device());
dst = src.copy(this->device);
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
......@@ -242,15 +242,15 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
// shape[-1] = out_features;
auto shape = TensorShape(qact.actShape.dataExtent);
shape[-1] = out_features;
out = Tensor::allocate(shape, dtype, qweight.device());
out = Tensor::allocate(shape, dtype, device);
} else {
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, device);
if (use_fp4) {
qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, device);
} else {
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, device);
}
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
qout.is_unsigned = !use_fp4;
qout.actShape = qact.actShape;
......@@ -363,13 +363,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
// shape[-1] = in_features / 2;
QuantizedActivation qact;
qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, device);
if (use_fp4) {
qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, device);
} else {
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, device);
}
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
qact.is_unsigned = false;
qact.actShape = x.shape.dataExtent;
......@@ -420,7 +420,7 @@ GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::Scala
this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
registerParams
(qweight, "qweight")
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(this->bias, "bias")
;
......
......@@ -36,6 +36,7 @@ public:
int lora_rank;
float lora_scale;
const Device device;
public:
Tensor qweight;
Tensor wscales;
......@@ -86,6 +87,7 @@ public:
std::vector<float> lora_scales; // every 16 ranks share a scale
const Tensor::ScalarType dtype;
const Device device;
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) override;
......
......@@ -9,10 +9,20 @@ protected:
enum class ParamFlags : int {
None = 0,
Optional = 1,
LazyLoad = 2,
};
struct TensorLazyLoadInfo {
TensorShape shape;
Tensor::ScalarType type;
Device device;
Tensor src;
};
struct Param {
Tensor *tensor;
ParamFlags flags;
Tensor *tensor = nullptr;
ParamFlags flags = ParamFlags::None;
TensorLazyLoadInfo lazyInfo;
};
friend inline ParamFlags operator|(ParamFlags lhs, ParamFlags rhs) {
......@@ -21,6 +31,9 @@ protected:
friend inline ParamFlags operator&(ParamFlags lhs, ParamFlags rhs) {
return static_cast<ParamFlags>(static_cast<int>(lhs) & static_cast<int>(rhs));
}
static bool checkFlag(ParamFlags flags, ParamFlags target) {
return int(flags & target);
}
public:
std::string getFullName() const {
......@@ -35,6 +48,12 @@ public:
}
}
std::string getPrefix() const {
std::string fullName = getFullName();
std::string prefix = fullName.empty() ? "" : fullName + ".";
return prefix;
}
void traverse(std::function<void(Module *)> func) {
func(this);
for (Module *c : this->children) {
......@@ -46,8 +65,7 @@ public:
for (Module *c : children) {
c->loadParams(provider, partial);
}
std::string fullName = getFullName();
std::string prefix = fullName.empty() ? "" : fullName + ".";
std::string prefix = getPrefix();
for (auto &&[key, param] : params) {
Tensor src = provider.getTensor(prefix + key);
if (!src.valid()) {
......@@ -56,6 +74,13 @@ public:
}
throw std::runtime_error(spdlog::fmt_lib::format("Tensor {} not found", prefix + key));
}
if (enabledLazyLoad && checkFlag(param.flags, ParamFlags::LazyLoad)) {
param.lazyInfo.src = src;
if (!param.tensor->valid()) {
continue;
}
// param is still resident (not yet released), so fall through and load it eagerly
}
this->loadParam(key, *param.tensor, src);
// tensor->copy_(src);
}
......@@ -66,7 +91,46 @@ public:
this->name = std::move(name);
}
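// Lazy-load protocol: parameters registered with ParamFlags::LazyLoad remember
// their shape/dtype/device and keep a handle to the source tensor from the
// weight provider. releaseLazyParams() drops the device tensor, and
// loadLazyParams() reallocates it and reloads from the saved source, so a
// layer's weights only need to be resident while that layer executes.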
void loadLazyParams() {
traverse([](Module *m) {
for (auto &&[key, param] : m->params) {
if (!checkFlag(param.flags, ParamFlags::LazyLoad)) {
continue;
}
TensorLazyLoadInfo &lazy = param.lazyInfo;
Tensor &dst = *param.tensor;
Tensor src = lazy.src;
if (dst.valid()) {
continue;
}
dst = Tensor::allocate(lazy.shape, lazy.type, lazy.device);
if (!src.valid() && !checkFlag(param.flags, ParamFlags::Optional)) {
throw std::runtime_error(spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
}
m->loadParam(key, dst, src);
}
});
}
void releaseLazyParams() {
traverse([](Module *m) {
if (!m->enabledLazyLoad) {
return;
}
for (auto &&[key, param] : m->params) {
if (checkFlag(param.flags, ParamFlags::LazyLoad)) {
*param.tensor = Tensor{};
}
}
});
}
void setLazyLoad(bool val) {
traverse([val](Module *m) {
m->enabledLazyLoad = val;
});
}
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) {
......@@ -98,6 +162,13 @@ protected:
if (param.valid()) {
params[name].tensor = &param;
params[name].flags = flags;
if (checkFlag(flags, ParamFlags::LazyLoad) && param.valid()) {
TensorLazyLoadInfo &lazy = params[name].lazyInfo;
lazy.shape = param.shape;
lazy.type = param.dtype();
lazy.device = param.device();
}
}
return ParamsRegisterHelper(*this);
}
......@@ -121,4 +192,78 @@ public:
std::string name = "";
std::vector<Module *> children;
std::map<std::string, Param> params;
bool enabledLazyLoad = false;
};
struct LayerOffloadHelper {
using func_t = std::function<void(int)>;
const bool offload;
const int numLayers;
func_t funcCompute, funcLoad, funcUnload;
std::unique_ptr<CUDAStreamWrapper> streamCompute;
std::unique_ptr<CUDAStreamWrapper> streamLoad;
std::unique_ptr<CUDAEventWrapper> eventComputeDone;
std::unique_ptr<CUDAEventWrapper> eventLoadDone;
LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
: offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload)
{
if (offload) {
streamCompute = std::make_unique<CUDAStreamWrapper>();
streamLoad = std::make_unique<CUDAStreamWrapper>();
}
}
void run() {
for (int i = 0; i < numLayers; i++) {
run(i);
}
waitEvent(eventComputeDone.get());
funcUnload(numLayers - 1);
}
private:
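// Double-buffered schedule when offloading: while layer i runs on the compute
// stream, the load stream (once the previous layer's compute has finished)
// releases layer i-1's lazily loaded weights and prefetches layer i+1.
// CUDA events enforce the compute <-> copy ordering; with offload == false
// this reduces to a plain sequential loop over funcCompute.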
void run(int layer) {
if (!offload) {
funcCompute(layer);
} else {
std::unique_ptr<CUDAEventWrapper> nextComputeDone, nextLoadDone;
// issue the compute kernels first so that compute and memcpy can still overlap even when the host memory is not pinned
{
CUDAStreamContext ctx(streamCompute->stream);
waitEvent(eventLoadDone.get());
funcCompute(layer);
nextComputeDone = std::make_unique<CUDAEventWrapper>();
checkCUDA(cudaEventRecord(nextComputeDone->event, getCurrentCUDAStream()));
}
{
CUDAStreamContext ctx(streamLoad->stream);
waitEvent(eventComputeDone.get());
if (layer - 1 > 0) {
funcUnload(layer - 1);
}
if (layer + 1 < numLayers) {
funcLoad(layer + 1);
}
nextLoadDone = std::make_unique<CUDAEventWrapper>();
checkCUDA(cudaEventRecord(nextLoadDone->event, getCurrentCUDAStream()));
}
eventComputeDone = std::move(nextComputeDone);
eventLoadDone = std::move(nextLoadDone);
}
}
static void waitEvent(CUDAEventWrapper *event) {
if (!event) {
return;
}
checkCUDA(cudaStreamWaitEvent(getCurrentCUDAStream(), event->event));
}
};
\ No newline at end of file