Commit b1b44398 authored by Samuel Tesfai

Fixing merges

parents 004e4e31 4b9c2e03
@@ -23,6 +23,5 @@ image = pipe(
     guidance_scale=5.0,
     pag_scale=2.0,
     num_inference_steps=20,
-    generator=torch.Generator().manual_seed(42),
 ).images[0]
 image.save("sana_1600m_pag.png")
-__version__ = "0.0.2beta6"
+__version__ = "0.1.3"
@@ -9,9 +9,9 @@
 class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
 public:
-    void init(bool bf16, int8_t deviceId) {
+    void init(bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedFluxModel");
-        net = std::make_unique<FluxModel>(bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        net = std::make_unique<FluxModel>(use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
     torch::Tensor forward(
...
@@ -8,7 +8,7 @@
 class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
 public:
-    void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
+    void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedGEMM");
         size_t val = 0;
@@ -16,7 +16,7 @@ public:
         checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
         spdlog::debug("Stack={}", val);
-        net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
     torch::Tensor forward(torch::Tensor x) {
...
@@ -29,7 +29,10 @@ namespace nunchaku::ops {
        std::optional<torch::Tensor> out_linearattn, // linear [B, (M), N / 3]
        bool act_unsigned,
        std::vector<float> lora_scales,
-        bool fuse_silu
+        bool fuse_silu,
+        bool fp4,
+        float alpha,
+        std::optional<torch::Tensor> wcscales
     ) {
         spdlog::trace("running gemm_w4a4: ");
@@ -64,7 +67,10 @@
            getTensor(out_linearattn),
            act_unsigned,
            lora_scales,
-            fuse_silu
+            fuse_silu,
+            fp4,
+            alpha,
+            getTensor(wcscales)
         );
         Tensor::synchronizeDevice();
     }
...
@@ -14,6 +14,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
         .def(py::init<>())
         .def("init", &QuantizedFluxModel::init,
+            py::arg("use_fp4"),
            py::arg("bf16"),
            py::arg("deviceId")
        )
@@ -36,6 +37,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("init", &QuantizedSanaModel::init,
            py::arg("config"),
            py::arg("pag_layers"),
+            py::arg("use_fp4"),
            py::arg("bf16"),
            py::arg("deviceId")
        )
...
@@ -8,7 +8,7 @@
 class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
 public:
-    void init(pybind11::dict config, std::vector<int> pag_layers, bool bf16, int8_t deviceId) {
+    void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedSanaModel");
         SanaConfig cfg{
             .num_layers = config["num_layers"].cast<int>(),
@@ -17,6 +17,7 @@ public:
             .num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
             .expand_ratio = config["mlp_ratio"].cast<double>(),
             .pag_layers = pag_layers,
+            .use_fp4 = use_fp4,
         };
         net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
...
 # convert the comfyui lora to diffusers format
+import argparse
 import os
 import torch
@@ -8,7 +9,7 @@ from ...utils import load_state_dict_in_safetensors
 def comfyui2diffusers(
-    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
+    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None, min_rank: int | None = None
 ) -> dict[str, torch.Tensor]:
     if isinstance(input_lora, str):
         tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
@@ -16,7 +17,7 @@ def comfyui2diffusers(
         tensors = input_lora
     new_tensors = {}
+    max_rank = 0
     for k, v in tensors.items():
         if "alpha" in k:
             continue
@@ -29,7 +30,10 @@ def comfyui2diffusers(
                 # Copy the tensor
                 new_k = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
                 new_k = new_k.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
-                new_tensors[new_k] = v.clone()
+                rank = v.shape[0]
+                alpha = tensors[k.replace("lora_down.weight", "alpha")]
+                new_tensors[new_k] = v.clone() * alpha / rank
+                max_rank = max(max_rank, rank)
             else:
                 assert "lora_B" in new_k
                 assert v.shape[0] % 3 == 0
@@ -58,7 +62,10 @@ def comfyui2diffusers(
                     new_k1 = new_k.replace("_linear1", ".proj_mlp")
                 else:
                     new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
-                new_tensors[new_k1] = v.clone()
+                rank = v.shape[0]
+                alpha = tensors[k.replace("lora_down.weight", "alpha")]
+                new_tensors[new_k1] = v.clone() * alpha / rank
+                max_rank = max(max_rank, rank)
             else:
                 if p == "i":
                     new_k1 = new_k.replace("_linear1", ".proj_mlp")
@@ -70,10 +77,43 @@ def comfyui2diffusers(
         else:
             new_k = new_k.replace("_linear2", ".proj_out")
         new_k = new_k.replace("_modulation_lin", ".norm.linear")
+        if "lora_down" in k:
+            rank = v.shape[0]
+            alpha = tensors[k.replace("lora_down.weight", "alpha")]
+            v = v * alpha / rank
+            max_rank = max(max_rank, rank)
         new_tensors[new_k] = v
+
+    if min_rank is not None:
+        for k in new_tensors.keys():
+            v = new_tensors[k]
+            if "lora_A" in k:
+                rank = v.shape[0]
+                if rank < min_rank:
+                    new_v = torch.zeros(min_rank, v.shape[1], dtype=v.dtype, device=v.device)
+                    new_v[:rank] = v
+                    new_tensors[k] = new_v
+            else:
+                assert "lora_B" in k
+                rank = v.shape[1]
+                if rank < min_rank:
+                    new_v = torch.zeros(v.shape[0], min_rank, dtype=v.dtype, device=v.device)
+                    new_v[:, :rank] = v
+                    new_tensors[k] = new_v
+
     if output_path is not None:
         output_dir = os.path.dirname(os.path.abspath(output_path))
         os.makedirs(output_dir, exist_ok=True)
         save_file(new_tensors, output_path)
     return new_tensors
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensor file")
+    parser.add_argument(
+        "-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensor file"
+    )
+    parser.add_argument("--min-rank", type=int, default=None, help="minimum rank for the LoRA weights")
+    args = parser.parse_args()
+    comfyui2diffusers(args.input_path, args.output_path, min_rank=args.min_rank)
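For orientation, a hypothetical invocation of the converter above (the module path and file names are assumptions, not part of this commit): it folds each alpha into lora_down as v * alpha / rank and, with min_rank set, zero-pads smaller adapters up to that rank before saving.

# hypothetical usage sketch; package/module path and file names are assumed
from nunchaku.lora.flux.comfyui_converter import comfyui2diffusers

converted = comfyui2diffusers(
    "my_comfyui_lora.safetensors",                # assumed input file
    output_path="my_diffusers_lora.safetensors",  # assumed output file
    min_rank=32,                                  # pad every LoRA branch up to rank 32
)
print(f"converted {len(converted)} tensors")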
@@ -6,6 +6,7 @@ from safetensors.torch import save_file
 from .comfyui_converter import comfyui2diffusers
 from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
+from .utils import detect_format
 from .xlab_converter import xlab2diffusers
 from ...utils import filter_state_dict, load_state_dict_in_safetensors
@@ -21,8 +22,8 @@ if __name__ == "__main__":
     parser.add_argument(
         "--lora-format",
         type=str,
-        default="diffusers",
-        choices=["comfyui", "diffusers", "xlab"],
+        default="auto",
+        choices=["auto", "comfyui", "diffusers", "xlab"],
         help="format of the LoRA weights",
     )
     parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
@@ -37,8 +38,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     if not args.output_root:
-        # output to the parent directory of the quantized model safetensor file
-        args.output_root = os.path.dirname(args.quant_path)
+        # output to the parent directory of the lora safetensor file
+        args.output_root = os.path.dirname(args.lora_path)
     if args.lora_name is None:
         base_name = os.path.basename(args.lora_path)
         lora_name = base_name.rsplit(".", 1)[0]
@@ -53,6 +54,13 @@ if __name__ == "__main__":
     orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
     lora_format = args.lora_format
+    if lora_format == "auto":
+        lora_format = detect_format(args.lora_path)
+        print(f"Detected LoRA format: {lora_format}")
+
+    if lora_format == "svdquant":
+        print("Already in SVDQuant format, no conversion needed.")
+        exit(0)
     if lora_format == "diffusers":
         extra_lora_dict = load_state_dict_in_safetensors(args.lora_path)
     else:
...
 # convert the diffusers lora to nunchaku format
 """Convert LoRA weights to Nunchaku format."""
 import typing as tp
 import torch
@@ -215,8 +214,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
         update_state_dict(
             converted,
             {
-                "lora_down": lora[0],
-                "lora_up": reorder_adanorm_lora_up(lora[1], splits=3),
+                "lora_down": pad(lora[0], divisor=16, dim=0),
+                "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
             },
             prefix=converted_local_name,
         )
@@ -224,8 +223,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
         update_state_dict(
             converted,
             {
-                "lora_down": lora[0],
-                "lora_up": reorder_adanorm_lora_up(lora[1], splits=6),
+                "lora_down": pad(lora[0], divisor=16, dim=0),
+                "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
             },
             prefix=converted_local_name,
         )
@@ -263,6 +262,22 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
+
+    for component in ["lora_A", "lora_B"]:
+        fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
+        fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
     return convert_to_nunchaku_transformer_block_lowrank_dict(
         orig_state_dict=orig_state_dict,
         extra_lora_dict=extra_lora_dict,
@@ -347,6 +362,28 @@ def convert_to_nunchaku_flux_lowrank_dict(
     else:
         extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")
+
+    for k in extra_lora_dict.keys():
+        fc1_k = k
+        if "ff.net.0.proj" in k:
+            fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
+        elif "ff_context.net.0.proj" in k:
+            fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
+        else:
+            continue
+        assert fc2_k in extra_lora_dict
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
     block_names: set[str] = set()
     for param_name in orig_state_dict.keys():
         if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
@@ -370,4 +407,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
             ),
             prefix=block_name,
         )
     return converted
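The rank-alignment loops above call the converter's pad helper to bring the FC1 and FC2 LoRA branches to a common rank. A rough standalone sketch of that zero-padding idea (an illustration under assumed semantics, not the repository's pad implementation):

import torch

def pad_to_multiple(t: torch.Tensor, divisor: int, dim: int) -> torch.Tensor:
    # Zero-pad dimension `dim` up to the next multiple of `divisor`; with divisor set to
    # the larger rank, a smaller lora_A (dim=0) or lora_B (dim=1) grows to match it.
    size = t.shape[dim]
    target = -(-size // divisor) * divisor  # ceiling division
    if target == size:
        return t
    pad_shape = list(t.shape)
    pad_shape[dim] = target - size
    return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=dim)

# a rank-8 lora_A grown to rank 16 so both MLP projections share one rank
print(pad_to_multiple(torch.randn(8, 3072), divisor=16, dim=0).shape)  # torch.Size([16, 3072])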
+import torch
+
+from ...utils import load_state_dict_in_safetensors
+
+
+def detect_format(lora: str | dict[str, torch.Tensor]) -> str:
+    if isinstance(lora, str):
+        tensors = load_state_dict_in_safetensors(lora, device="cpu")
+    else:
+        tensors = lora
+    for k in tensors.keys():
+        if "lora_unet_double_blocks_" in k or "lora_unet_single_blocks" in k:
+            return "comfyui"
+        elif "mlp_fc" in k or "mlp_context_fc1" in k:
+            return "svdquant"
+        elif "double_blocks." in k or "single_blocks." in k:
+            return "xlab"
+        elif "transformer." in k:
+            return "diffusers"
+    raise ValueError("Unknown format, please provide the format explicitly.")
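A small sketch of how detect_format can be exercised on its own; the import path and the toy state dicts below are assumptions chosen only to trigger the key patterns the function checks:

import torch
from nunchaku.lora.flux.utils import detect_format  # assumed module location

comfyui_like = {"lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight": torch.zeros(16, 3072)}
diffusers_like = {"transformer.transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 3072)}

print(detect_format(comfyui_like))    # "comfyui"
print(detect_format(diffusers_like))  # "diffusers"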
@@ -108,13 +108,12 @@ class EmbedND(nn.Module):
         return emb.unsqueeze(1)

-def load_quantized_module(path: str, device: str | torch.device = "cuda") -> QuantizedFluxModel:
+def load_quantized_module(path: str, device: str | torch.device = "cuda", use_fp4: bool = False) -> QuantizedFluxModel:
     device = torch.device(device)
     assert device.type == "cuda"
     m = QuantizedFluxModel()
     cutils.disable_memory_auto_release()
-    m.init(True, 0 if device.index is None else device.index)
+    m.init(use_fp4, True, 0 if device.index is None else device.index)
     m.load(path)
     return m
@@ -153,8 +152,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
     @utils.validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
+        precision = kwargs.get("precision", "int4")
+        assert precision in ["int4", "fp4"]
         transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
-        m = load_quantized_module(transformer_block_path, device=device)
+        m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4")
         transformer.inject_quantized_module(m, device)
         return transformer
...
@@ -124,9 +124,13 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
         pag_layers = kwargs.get("pag_layers", [])
+        precision = kwargs.get("precision", "int4")
+        assert precision in ["int4", "fp4"]
         transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
         transformer.config["num_layers"] = transformer.original_num_layers
-        m = load_quantized_module(transformer, transformer_block_path, device=device, pag_layers=pag_layers)
+        m = load_quantized_module(
+            transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+        )
         transformer.inject_quantized_module(m, device)
         return transformer
@@ -140,6 +144,7 @@ def load_quantized_module(
     path: str,
     device: str | torch.device = "cuda",
     pag_layers: int | list[int] | None = None,
+    use_fp4: bool = False,
 ) -> QuantizedSanaModel:
     if pag_layers is None:
         pag_layers = []
@@ -150,7 +155,7 @@ def load_quantized_module(
     m = QuantizedSanaModel()
     cutils.disable_memory_auto_release()
-    m.init(net.config, pag_layers, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
+    m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
     m.load(path)
     return m
...
@@ -4,7 +4,13 @@ from diffusers import FluxPipeline
 from .models.transformer_flux import NunchakuFluxTransformer2dModel

 if __name__ == "__main__":
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+    capability = torch.cuda.get_device_capability(0)
+    sm = f"{capability[0]}{capability[1]}"
+    precision = "fp4" if sm == "120" else "int4"
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/svdq-{precision}-flux.1-schnell", precision=precision
+    )
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
     ).to("cuda")
...
 import os
+import re
+import subprocess
+import sys
 import setuptools
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+from packaging import version as packaging_version
+from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension

 class CustomBuildExtension(BuildExtension):
     def build_extensions(self):
@@ -16,10 +22,49 @@ class CustomBuildExtension(BuildExtension):
                 ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
         super().build_extensions()
+def get_sm_targets() -> list[str]:
+    nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
+    try:
+        nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
+        match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
+        if match:
+            nvcc_version = match.group(2)
+        else:
+            raise Exception("nvcc version not found")
+        print(f"Found nvcc version: {nvcc_version}")
+    except:
+        raise Exception("nvcc not found")
+
+    support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")
+
+    install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
+    if install_mode == "FAST":
+        ret = []
+        for i in range(torch.cuda.device_count()):
+            capability = torch.cuda.get_device_capability(i)
+            sm = f"{capability[0]}{capability[1]}"
+            if sm == "120" and support_sm120:
+                sm = "120a"
+            assert sm in ["80", "86", "89", "120a"], f"Unsupported SM {sm}"
+            if sm not in ret:
+                ret.append(sm)
+    else:
+        assert install_mode == "ALL"
+        ret = ["80", "86", "89"]
+        if support_sm120:
+            ret.append("120a")
+    return ret
 if __name__ == "__main__":
     fp = open("nunchaku/__version__.py", "r").read()
     version = eval(fp.strip().split()[-1])
+    torch_version = torch.__version__.split("+")[0]
+    torch_major_minor_version = ".".join(torch_version.split(".")[:2])
+    version = version + "+torch" + torch_major_minor_version

     ROOT_DIR = os.path.dirname(__file__)
     INCLUDE_DIRS = [
@@ -54,12 +99,6 @@ if __name__ == "__main__":
     NVCC_FLAGS = [
         "-DENABLE_BF16=1",
         "-DBUILD_NUNCHAKU=1",
-        "-gencode",
-        "arch=compute_86,code=sm_86",
-        "-gencode",
-        "arch=compute_89,code=sm_89",
-        # "-gencode",
-        # "arch=compute_89,code=sm_120a",
         "-g",
         "-std=c++20",
         "-UNDEBUG",
@@ -74,13 +113,23 @@ if __name__ == "__main__":
         "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
         "-U__CUDA_NO_BFLOAT162_OPERATORS__",
         "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-        "--threads=2",
+        "--threads=3",
         "--expt-relaxed-constexpr",
         "--expt-extended-lambda",
-        "--generate-line-info",
         "--ptxas-options=--allow-expensive-optimizations=true",
     ]

+    # https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487
+    if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
+        NVCC_FLAGS.append("--generate-line-info")
+
+    sm_targets = get_sm_targets()
+    print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
+    assert len(sm_targets) > 0, "No SM targets found"
+
+    for target in sm_targets:
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]
+
     NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]

     nunchaku_extension = CUDAExtension(
...
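The build script now reads two environment variables. A brief sketch of setting them before a source build (the variable names and defaults come from the diff; the surrounding workflow is only an assumption):

import os

# Compile for all supported SMs rather than only the GPUs detected locally.
os.environ["NUNCHAKU_INSTALL_MODE"] = "ALL"   # default "FAST" probes torch.cuda devices
# When producing distributable wheels, "1" leaves out --generate-line-info.
os.environ["NUNCHAKU_BUILD_WHEELS"] = "1"     # default "0" keeps line info for profiling
# A subsequent editable install (e.g. pip install -e ., run separately) would then
# pick these values up in get_sm_targets() and the NVCC flag construction above.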
@@ -259,19 +259,19 @@ void Attention::setForceFP16(Module *module, bool value) {
     });
 }
-FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device) :
+FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_head(attention_head_dim / num_attention_heads),
     num_heads(num_attention_heads),
     mlp_hidden_dim(dim * mlp_ratio),
     norm(dim, dtype, device),
-    mlp_fc1(dim, mlp_hidden_dim, true, dtype, device),
-    mlp_fc2(mlp_hidden_dim, dim, true, dtype, device),
-    qkv_proj(dim, dim * 3, true, dtype, device),
+    mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
+    mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
+    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
     norm_q(dim_head, 1e-6, false, dtype, device),
     norm_k(dim_head, 1e-6, false, dtype, device),
     attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, dtype, device)
+    out_proj(dim, dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (norm, "norm")
@@ -327,28 +327,28 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     return hidden_states;
 }
-JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device) :
+JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_head(attention_head_dim / num_attention_heads),
     num_heads(num_attention_heads),
     context_pre_only(context_pre_only),
     norm1(dim, false, dtype, device),
     norm1_context(dim, context_pre_only, dtype, device),
-    qkv_proj(dim, dim * 3, true, dtype, device),
-    qkv_proj_context(dim, dim * 3, true, dtype, device),
+    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
+    qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device),
     norm_q(dim_head, 1e-6, false, dtype, device),
     norm_k(dim_head, 1e-6, false, dtype, device),
     norm_added_q(dim_head, 1e-6, false, dtype, device),
     norm_added_k(dim_head, 1e-6, false, dtype, device),
     attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, dtype, device),
-    out_proj_context(dim, dim, true, dtype, device),
+    out_proj(dim, dim, true, use_fp4, dtype, device),
+    out_proj_context(dim, dim, true, use_fp4, dtype, device),
     norm2(dim, 1e-6, false, dtype, device),
     norm2_context(dim, 1e-6, false, dtype, device),
-    mlp_fc1(dim, dim * 4, true, dtype, device),
-    mlp_fc2(dim * 4, dim, true, dtype, device),
-    mlp_context_fc1(dim, dim * 4, true, dtype, device),
-    mlp_context_fc2(dim * 4, dim, true, dtype, device)
+    mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device),
+    mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
+    mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
+    mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (norm1, "norm1")
@@ -607,13 +607,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     return { hidden_states, encoder_hidden_states };
 }
-FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) {
+FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
     for (int i = 0; i < 19; i++) {
-        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, dtype, device));
+        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
     }
     for (int i = 0; i < 38; i++) {
-        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, dtype, Device::cuda()));
+        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
         registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
     }
 }
...
@@ -77,7 +77,7 @@ public:
     static constexpr bool USE_4BIT = true;
     using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
-    FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device);
+    FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device);
     Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);
 public:
@@ -101,7 +101,7 @@ public:
     static constexpr bool USE_4BIT = true;
     using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
-    JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device);
+    JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device);
     std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);
 public:
@@ -128,7 +128,7 @@ private:
 class FluxModel : public Module {
 public:
-    FluxModel(Tensor::ScalarType dtype, Device device);
+    FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device);
     Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
 public:
...
@@ -96,23 +96,33 @@ Tensor GEMV_AWQ::forward(Tensor x) {
 #define NO_LORA_FUSION 0
-GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) :
+GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     in_features(in_features), out_features(out_features),
     in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
+    use_fp4(use_fp4),
     lora_rank(0), dtype(dtype)
 {
     this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
-    this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
+    if (use_fp4) {
+        this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
+    } else {
+        this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
+    }
     this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};
     this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
     this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
+    // TODO: smooth factor in FC1+FC2 fusion
     // TODO: smooth factor in non-Lora fusion
     this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);
+    // FIXME: reset wtscale and wcscales to default values when reloading the weights
+    this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
+    *this->wtscale.data_ptr<float>() = 1.0f;
+    this->wcscales = Tensor::allocate({0}, dtype, device, true);
     registerParams
         (qweight, "qweight")
         (wscales, "wscales")
@@ -120,6 +130,8 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::Scala
         (lora_down, "lora_down", ParamFlags::Optional)
         (lora_up, "lora_up", ParamFlags::Optional)
         (smooth, "smooth")
+        (wtscale, "wtscale", ParamFlags::Optional)
+        (wcscales, "wcscales", ParamFlags::Optional)
     ;
 #if NO_LORA_FUSION
@@ -137,6 +149,21 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
         } else {
             dst.copy_(src);
         }
+    } else if (key == "wcscales") {
+        assert(src.ndims() == 1);
+        assert(src.shape[0] == out_features_pad);
+        dst = src.copy(this->qweight.device());
+    } else if (key == "wtscale") {
+        assert(src.numel() == 1);
+        if (src.dtype() == Tensor::BF16) {
+            *dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
+        } else if (src.dtype() == Tensor::FP16) {
+            *dst.data_ptr<float>() = float(*src.data_ptr<half>());
+        } else if (src.dtype() == Tensor::FP32) {
+            dst.copy_(src);
+        } else {
+            assert(false);
+        }
     } else {
         Module::loadParam(key, dst, src);
     }
@@ -167,7 +194,10 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
     debug("gemm.nolora.out", out);
 #endif
-    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false);
+    kernels::gemm_w4a4(
+        qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
+        use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{}
+    );
     debug("gemm.out", out);
 #else
@@ -215,9 +245,13 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
         out = Tensor::allocate(shape, dtype, qweight.device());
     } else {
         qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
-        qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
+        if (use_fp4) {
+            qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+        } else {
+            qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
+        }
         qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
-        qout.is_unsigned = true;
+        qout.is_unsigned = !use_fp4;
         qout.actShape = qact.actShape;
         next_lora = nextGEMM->lora_down;
@@ -241,7 +275,10 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
     }
 #endif
-    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU);
+    kernels::gemm_w4a4(
+        qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
+        use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{}
+    );
     if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
         debug("gemm.out", out);
@@ -327,7 +364,11 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     QuantizedActivation qact;
     qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
-    qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
+    if (use_fp4) {
+        qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+    } else {
+        qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
+    }
     qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
     qact.is_unsigned = false;
     qact.actShape = x.shape.dataExtent;
@@ -336,7 +377,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     debug("quantize.x", x);
     debug("quantize.smooth", this->smooth);
-    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu);
+    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);
     debug("quantize.qact", qact.act);
     debug("quantize.ascales", qact.ascales);
...
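To make the allocation changes above concrete: the INT4 path keeps one 16-bit scale per 64 input channels, while the FP4 path keeps one FP8 (E4M3) scale per 16 channels plus the per-tensor wtscale. A tiny sketch that just re-derives those wscales shapes (illustrative arithmetic, not repository code):

def wscales_shape(in_features_pad: int, out_features_pad: int, use_fp4: bool) -> tuple[int, int]:
    # group size 16 for FP4 (FP8_E4M3 scales), 64 for INT4 (BF16/FP16 scales)
    group = 16 if use_fp4 else 64
    return (in_features_pad // group, out_features_pad)

print(wscales_shape(3072, 3072, use_fp4=False))  # (48, 3072)
print(wscales_shape(3072, 3072, use_fp4=True))   # (192, 3072)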
@@ -64,7 +64,7 @@ public:
     };
 public:
-    GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
+    GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device);
     Tensor forward(Tensor x);
     Tensor forward_silu(Tensor x);
     std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
@@ -80,6 +80,7 @@ public:
     const int out_features;
     const int in_features_pad;
     const int out_features_pad;
+    const bool use_fp4;
     int lora_rank;
     std::vector<float> lora_scales; // every 16 ranks share a scale
@@ -99,6 +100,9 @@ public:
     Tensor smooth;
+    Tensor wtscale;
+    Tensor wcscales;
     cublasHandle_t handle;
 };
...