Commit b1b44398 authored by Samuel Tesfai's avatar Samuel Tesfai
Browse files

Fixing merges

parents 004e4e31 4b9c2e03
......@@ -23,6 +23,5 @@ image = pipe(
guidance_scale=5.0,
pag_scale=2.0,
num_inference_steps=20,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m_pag.png")
__version__ = "0.0.2beta6"
__version__ = "0.1.3"
......@@ -9,9 +9,9 @@
class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
public:
void init(bool bf16, int8_t deviceId) {
void init(bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedFluxModel");
net = std::make_unique<FluxModel>(bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<FluxModel>(use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(
......
......@@ -8,7 +8,7 @@
class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
public:
void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedGEMM");
size_t val = 0;
......@@ -16,7 +16,7 @@ public:
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(torch::Tensor x) {
......
......@@ -29,7 +29,10 @@ namespace nunchaku::ops {
std::optional<torch::Tensor> out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales,
bool fuse_silu
bool fuse_silu,
bool fp4,
float alpha,
std::optional<torch::Tensor> wcscales
) {
spdlog::trace("running gemm_w4a4: ");
......@@ -64,7 +67,10 @@ namespace nunchaku::ops {
getTensor(out_linearattn),
act_unsigned,
lora_scales,
fuse_silu
fuse_silu,
fp4,
alpha,
getTensor(wcscales)
);
Tensor::synchronizeDevice();
}
......
......@@ -14,6 +14,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
.def(py::init<>())
.def("init", &QuantizedFluxModel::init,
py::arg("use_fp4"),
py::arg("bf16"),
py::arg("deviceId")
)
......@@ -36,6 +37,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("init", &QuantizedSanaModel::init,
py::arg("config"),
py::arg("pag_layers"),
py::arg("use_fp4"),
py::arg("bf16"),
py::arg("deviceId")
)
......
......@@ -8,7 +8,7 @@
class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
public:
void init(pybind11::dict config, std::vector<int> pag_layers, bool bf16, int8_t deviceId) {
void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedSanaModel");
SanaConfig cfg{
.num_layers = config["num_layers"].cast<int>(),
......@@ -17,6 +17,7 @@ public:
.num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
.expand_ratio = config["mlp_ratio"].cast<double>(),
.pag_layers = pag_layers,
.use_fp4 = use_fp4,
};
net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
......
# convert the comfyui lora to diffusers format
import argparse
import os
import torch
......@@ -8,7 +9,7 @@ from ...utils import load_state_dict_in_safetensors
def comfyui2diffusers(
input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
input_lora: str | dict[str, torch.Tensor], output_path: str | None = None, min_rank: int | None = None
) -> dict[str, torch.Tensor]:
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
......@@ -16,7 +17,7 @@ def comfyui2diffusers(
tensors = input_lora
new_tensors = {}
max_rank = 0
for k, v in tensors.items():
if "alpha" in k:
continue
......@@ -29,7 +30,10 @@ def comfyui2diffusers(
# Copy the tensor
new_k = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
new_k = new_k.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
new_tensors[new_k] = v.clone()
rank = v.shape[0]
alpha = tensors[k.replace("lora_down.weight", "alpha")]
new_tensors[new_k] = v.clone() * alpha / rank
max_rank = max(max_rank, rank)
else:
assert "lora_B" in new_k
assert v.shape[0] % 3 == 0
......@@ -58,7 +62,10 @@ def comfyui2diffusers(
new_k1 = new_k.replace("_linear1", ".proj_mlp")
else:
new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
new_tensors[new_k1] = v.clone()
rank = v.shape[0]
alpha = tensors[k.replace("lora_down.weight", "alpha")]
new_tensors[new_k1] = v.clone() * alpha / rank
max_rank = max(max_rank, rank)
else:
if p == "i":
new_k1 = new_k.replace("_linear1", ".proj_mlp")
......@@ -70,10 +77,43 @@ def comfyui2diffusers(
else:
new_k = new_k.replace("_linear2", ".proj_out")
new_k = new_k.replace("_modulation_lin", ".norm.linear")
if "lora_down" in k:
rank = v.shape[0]
alpha = tensors[k.replace("lora_down.weight", "alpha")]
v = v * alpha / rank
max_rank = max(max_rank, rank)
new_tensors[new_k] = v
if min_rank is not None:
for k in new_tensors.keys():
v = new_tensors[k]
if "lora_A" in k:
rank = v.shape[0]
if rank < min_rank:
new_v = torch.zeros(min_rank, v.shape[1], dtype=v.dtype, device=v.device)
new_v[:rank] = v
new_tensors[k] = new_v
else:
assert "lora_B" in k
rank = v.shape[1]
if rank < min_rank:
new_v = torch.zeros(v.shape[0], min_rank, dtype=v.dtype, device=v.device)
new_v[:, :rank] = v
new_tensors[k] = new_v
if output_path is not None:
output_dir = os.path.dirname(os.path.abspath(output_path))
os.makedirs(output_dir, exist_ok=True)
save_file(new_tensors, output_path)
return new_tensors
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensor file")
parser.add_argument(
"-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensor file"
)
parser.add_argument("--min-rank", type=int, default=None, help="minimum rank for the LoRA weights")
args = parser.parse_args()
comfyui2diffusers(args.input_path, args.output_path, min_rank=args.min_rank)
......@@ -6,6 +6,7 @@ from safetensors.torch import save_file
from .comfyui_converter import comfyui2diffusers
from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
from .utils import detect_format
from .xlab_converter import xlab2diffusers
from ...utils import filter_state_dict, load_state_dict_in_safetensors
......@@ -21,8 +22,8 @@ if __name__ == "__main__":
parser.add_argument(
"--lora-format",
type=str,
default="diffusers",
choices=["comfyui", "diffusers", "xlab"],
default="auto",
choices=["auto", "comfyui", "diffusers", "xlab"],
help="format of the LoRA weights",
)
parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
......@@ -37,8 +38,8 @@ if __name__ == "__main__":
args = parser.parse_args()
if not args.output_root:
# output to the parent directory of the quantized model safetensor file
args.output_root = os.path.dirname(args.quant_path)
# output to the parent directory of the lora safetensor file
args.output_root = os.path.dirname(args.lora_path)
if args.lora_name is None:
base_name = os.path.basename(args.lora_path)
lora_name = base_name.rsplit(".", 1)[0]
......@@ -53,6 +54,13 @@ if __name__ == "__main__":
orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
lora_format = args.lora_format
if lora_format == "auto":
lora_format = detect_format(args.lora_path)
print(f"Detected LoRA format: {lora_format}")
if lora_format == "svdquant":
print("Already in SVDQuant format, no conversion needed.")
exit(0)
if lora_format == "diffusers":
extra_lora_dict = load_state_dict_in_safetensors(args.lora_path)
else:
......
# convert the diffusers lora to nunchaku format
"""Convert LoRA weights to Nunchaku format."""
import typing as tp
import torch
......@@ -215,8 +214,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
update_state_dict(
converted,
{
"lora_down": lora[0],
"lora_up": reorder_adanorm_lora_up(lora[1], splits=3),
"lora_down": pad(lora[0], divisor=16, dim=0),
"lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
},
prefix=converted_local_name,
)
......@@ -224,8 +223,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
update_state_dict(
converted,
{
"lora_down": lora[0],
"lora_up": reorder_adanorm_lora_up(lora[1], splits=6),
"lora_down": pad(lora[0], divisor=16, dim=0),
"lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
},
prefix=converted_local_name,
)
......@@ -263,6 +262,22 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
for component in ["lora_A", "lora_B"]:
fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
fc1_v = extra_lora_dict[fc1_k]
fc2_v = extra_lora_dict[fc2_k]
dim = 0 if "lora_A" in fc1_k else 1
fc1_rank = fc1_v.shape[dim]
fc2_rank = fc2_v.shape[dim]
if fc1_rank != fc2_rank:
rank = max(fc1_rank, fc2_rank)
if fc1_rank < rank:
extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
if fc2_rank < rank:
extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
return convert_to_nunchaku_transformer_block_lowrank_dict(
orig_state_dict=orig_state_dict,
extra_lora_dict=extra_lora_dict,
......@@ -347,6 +362,28 @@ def convert_to_nunchaku_flux_lowrank_dict(
else:
extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")
for k in extra_lora_dict.keys():
fc1_k = k
if "ff.net.0.proj" in k:
fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
elif "ff_context.net.0.proj" in k:
fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
else:
continue
assert fc2_k in extra_lora_dict
fc1_v = extra_lora_dict[fc1_k]
fc2_v = extra_lora_dict[fc2_k]
dim = 0 if "lora_A" in fc1_k else 1
fc1_rank = fc1_v.shape[dim]
fc2_rank = fc2_v.shape[dim]
if fc1_rank != fc2_rank:
rank = max(fc1_rank, fc2_rank)
if fc1_rank < rank:
extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
if fc2_rank < rank:
extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
block_names: set[str] = set()
for param_name in orig_state_dict.keys():
if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
......@@ -370,4 +407,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
),
prefix=block_name,
)
return converted
import torch
from ...utils import load_state_dict_in_safetensors
def detect_format(lora: str | dict[str, torch.Tensor]) -> str:
if isinstance(lora, str):
tensors = load_state_dict_in_safetensors(lora, device="cpu")
else:
tensors = lora
for k in tensors.keys():
if "lora_unet_double_blocks_" in k or "lora_unet_single_blocks" in k:
return "comfyui"
elif "mlp_fc" in k or "mlp_context_fc1" in k:
return "svdquant"
elif "double_blocks." in k or "single_blocks." in k:
return "xlab"
elif "transformer." in k:
return "diffusers"
raise ValueError("Unknown format, please provide the format explicitly.")
......@@ -108,13 +108,12 @@ class EmbedND(nn.Module):
return emb.unsqueeze(1)
def load_quantized_module(path: str, device: str | torch.device = "cuda") -> QuantizedFluxModel:
def load_quantized_module(path: str, device: str | torch.device = "cuda", use_fp4: bool = False) -> QuantizedFluxModel:
device = torch.device(device)
assert device.type == "cuda"
m = QuantizedFluxModel()
cutils.disable_memory_auto_release()
m.init(True, 0 if device.index is None else device.index)
m.init(use_fp4, True, 0 if device.index is None else device.index)
m.load(path)
return m
......@@ -153,8 +152,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
device = kwargs.get("device", "cuda")
precision = kwargs.get("precision", "int4")
assert precision in ["int4", "fp4"]
transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
m = load_quantized_module(transformer_block_path, device=device)
m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4")
transformer.inject_quantized_module(m, device)
return transformer
......
......@@ -124,9 +124,13 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
device = kwargs.get("device", "cuda")
pag_layers = kwargs.get("pag_layers", [])
precision = kwargs.get("precision", "int4")
assert precision in ["int4", "fp4"]
transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
transformer.config["num_layers"] = transformer.original_num_layers
m = load_quantized_module(transformer, transformer_block_path, device=device, pag_layers=pag_layers)
m = load_quantized_module(
transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
)
transformer.inject_quantized_module(m, device)
return transformer
......@@ -140,6 +144,7 @@ def load_quantized_module(
path: str,
device: str | torch.device = "cuda",
pag_layers: int | list[int] | None = None,
use_fp4: bool = False,
) -> QuantizedSanaModel:
if pag_layers is None:
pag_layers = []
......@@ -150,7 +155,7 @@ def load_quantized_module(
m = QuantizedSanaModel()
cutils.disable_memory_auto_release()
m.init(net.config, pag_layers, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
m.load(path)
return m
......
......@@ -4,7 +4,13 @@ from diffusers import FluxPipeline
from .models.transformer_flux import NunchakuFluxTransformer2dModel
if __name__ == "__main__":
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
capability = torch.cuda.get_device_capability(0)
sm = f"{capability[0]}{capability[1]}"
precision = "fp4" if sm == "120" else "int4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", precision=precision
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
......
import os
import re
import subprocess
import sys
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
class CustomBuildExtension(BuildExtension):
def build_extensions(self):
......@@ -16,10 +22,49 @@ class CustomBuildExtension(BuildExtension):
ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
super().build_extensions()
def get_sm_targets() -> list[str]:
nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
try:
nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
if match:
nvcc_version = match.group(2)
else:
raise Exception("nvcc version not found")
print(f"Found nvcc version: {nvcc_version}")
except:
raise Exception("nvcc not found")
support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")
install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
if install_mode == "FAST":
ret = []
for i in range(torch.cuda.device_count()):
capability = torch.cuda.get_device_capability(i)
sm = f"{capability[0]}{capability[1]}"
if sm == "120" and support_sm120:
sm = "120a"
assert sm in ["80", "86", "89", "120a"], f"Unsupported SM {sm}"
if sm not in ret:
ret.append(sm)
else:
assert install_mode == "ALL"
ret = ["80", "86", "89"]
if support_sm120:
ret.append("120a")
return ret
if __name__ == "__main__":
fp = open("nunchaku/__version__.py", "r").read()
version = eval(fp.strip().split()[-1])
torch_version = torch.__version__.split("+")[0]
torch_major_minor_version = ".".join(torch_version.split(".")[:2])
version = version + "+torch" + torch_major_minor_version
ROOT_DIR = os.path.dirname(__file__)
INCLUDE_DIRS = [
......@@ -54,12 +99,6 @@ if __name__ == "__main__":
NVCC_FLAGS = [
"-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1",
"-gencode",
"arch=compute_86,code=sm_86",
"-gencode",
"arch=compute_89,code=sm_89",
# "-gencode",
# "arch=compute_89,code=sm_120a",
"-g",
"-std=c++20",
"-UNDEBUG",
......@@ -74,13 +113,23 @@ if __name__ == "__main__":
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--threads=2",
"--threads=3",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--generate-line-info",
"--ptxas-options=--allow-expensive-optimizations=true",
]
# https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487
if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
NVCC_FLAGS.append("--generate-line-info")
sm_targets = get_sm_targets()
print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
assert len(sm_targets) > 0, "No SM targets found"
for target in sm_targets:
NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]
NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]
nunchaku_extension = CUDAExtension(
......
......@@ -259,19 +259,19 @@ void Attention::setForceFP16(Module *module, bool value) {
});
}
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device) :
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads),
mlp_hidden_dim(dim * mlp_ratio),
norm(dim, dtype, device),
mlp_fc1(dim, mlp_hidden_dim, true, dtype, device),
mlp_fc2(mlp_hidden_dim, dim, true, dtype, device),
qkv_proj(dim, dim * 3, true, dtype, device),
mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
norm_q(dim_head, 1e-6, false, dtype, device),
norm_k(dim_head, 1e-6, false, dtype, device),
attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
out_proj(dim, dim, true, dtype, device)
out_proj(dim, dim, true, use_fp4, dtype, device)
{
registerChildren
(norm, "norm")
......@@ -327,28 +327,28 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
return hidden_states;
}
JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device) :
JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads),
context_pre_only(context_pre_only),
norm1(dim, false, dtype, device),
norm1_context(dim, context_pre_only, dtype, device),
qkv_proj(dim, dim * 3, true, dtype, device),
qkv_proj_context(dim, dim * 3, true, dtype, device),
qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device),
norm_q(dim_head, 1e-6, false, dtype, device),
norm_k(dim_head, 1e-6, false, dtype, device),
norm_added_q(dim_head, 1e-6, false, dtype, device),
norm_added_k(dim_head, 1e-6, false, dtype, device),
attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
out_proj(dim, dim, true, dtype, device),
out_proj_context(dim, dim, true, dtype, device),
out_proj(dim, dim, true, use_fp4, dtype, device),
out_proj_context(dim, dim, true, use_fp4, dtype, device),
norm2(dim, 1e-6, false, dtype, device),
norm2_context(dim, 1e-6, false, dtype, device),
mlp_fc1(dim, dim * 4, true, dtype, device),
mlp_fc2(dim * 4, dim, true, dtype, device),
mlp_context_fc1(dim, dim * 4, true, dtype, device),
mlp_context_fc2(dim * 4, dim, true, dtype, device)
mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device),
mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device)
{
registerChildren
(norm1, "norm1")
......@@ -607,13 +607,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
return { hidden_states, encoder_hidden_states };
}
FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) {
FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
for (int i = 0; i < 19; i++) {
transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, dtype, device));
transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
}
for (int i = 0; i < 38; i++) {
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, dtype, Device::cuda()));
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
}
}
......
......@@ -77,7 +77,7 @@ public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device);
FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);
public:
......@@ -101,7 +101,7 @@ public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device);
JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device);
std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);
public:
......@@ -128,7 +128,7 @@ private:
class FluxModel : public Module {
public:
FluxModel(Tensor::ScalarType dtype, Device device);
FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
public:
......
......@@ -96,23 +96,33 @@ Tensor GEMV_AWQ::forward(Tensor x) {
#define NO_LORA_FUSION 0
GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) :
GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features),
in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
use_fp4(use_fp4),
lora_rank(0), dtype(dtype)
{
this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
if (use_fp4) {
this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
} else {
this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
}
this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};
this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
// TODO: smooth factor in FC1+FC2 fusion
// TODO: smooth factor in non-Lora fusion
this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);
// FIXME: reset wtscale and wcscales to default values when reloading the weights
this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
*this->wtscale.data_ptr<float>() = 1.0f;
this->wcscales = Tensor::allocate({0}, dtype, device, true);
registerParams
(qweight, "qweight")
(wscales, "wscales")
......@@ -120,6 +130,8 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::Scala
(lora_down, "lora_down", ParamFlags::Optional)
(lora_up, "lora_up", ParamFlags::Optional)
(smooth, "smooth")
(wtscale, "wtscale", ParamFlags::Optional)
(wcscales, "wcscales", ParamFlags::Optional)
;
#if NO_LORA_FUSION
......@@ -137,6 +149,21 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else {
dst.copy_(src);
}
} else if (key == "wcscales") {
assert(src.ndims() == 1);
assert(src.shape[0] == out_features_pad);
dst = src.copy(this->qweight.device());
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
*dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
} else if (src.dtype() == Tensor::FP16) {
*dst.data_ptr<float>() = float(*src.data_ptr<half>());
} else if (src.dtype() == Tensor::FP32) {
dst.copy_(src);
} else {
assert(false);
}
} else {
Module::loadParam(key, dst, src);
}
......@@ -167,7 +194,10 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
debug("gemm.nolora.out", out);
#endif
kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false);
kernels::gemm_w4a4(
qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{}
);
debug("gemm.out", out);
#else
......@@ -215,9 +245,13 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
out = Tensor::allocate(shape, dtype, qweight.device());
} else {
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
if (use_fp4) {
qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
} else {
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
}
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qout.is_unsigned = true;
qout.is_unsigned = !use_fp4;
qout.actShape = qact.actShape;
next_lora = nextGEMM->lora_down;
......@@ -241,7 +275,10 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
}
#endif
kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU);
kernels::gemm_w4a4(
qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{}
);
if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
debug("gemm.out", out);
......@@ -327,7 +364,11 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
QuantizedActivation qact;
qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
if (use_fp4) {
qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
} else {
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
}
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qact.is_unsigned = false;
qact.actShape = x.shape.dataExtent;
......@@ -336,7 +377,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
debug("quantize.x", x);
debug("quantize.smooth", this->smooth);
kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu);
kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);
debug("quantize.qact", qact.act);
debug("quantize.ascales", qact.ascales);
......
......@@ -64,7 +64,7 @@ public:
};
public:
GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x);
Tensor forward_silu(Tensor x);
std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
......@@ -80,6 +80,7 @@ public:
const int out_features;
const int in_features_pad;
const int out_features_pad;
const bool use_fp4;
int lora_rank;
std::vector<float> lora_scales; // every 16 ranks share a scale
......@@ -99,6 +100,9 @@ public:
Tensor smooth;
Tensor wtscale;
Tensor wcscales;
cublasHandle_t handle;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment