Unverified commit ad8097b9 authored by Muyang Li, committed by GitHub

Release v0.2.0

Ready to release v0.2.0
parents 804a6d30 998192ca
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=True
) # set offload to False if you want to disable offloading
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_sequential_cpu_offload() # remove this line if you want to disable the CPU offloading
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=True
) # set offload to False if you want to disable offloading
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
text_encoder_2=text_encoder_2,
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.enable_sequential_cpu_offload() # remove this line if you want to disable the CPU offloading
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
text_encoder_2=text_encoder_2,
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
import torch
from diffusers import SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
variant="bf16",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
apply_cache_on_pipe(pipe, residual_diff_threshold=0.25)
# WarmUp
prompt = "A cute 🐼 eating 🎋, ink drawing style"
image = pipe(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=4.5,
num_inference_steps=20,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m-int4.png")
......@@ -23,4 +23,4 @@ image = pipe(
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m.png")
image.save("sana_1600m-int4.png")
......@@ -24,4 +24,4 @@ image = pipe(
pag_scale=2.0,
num_inference_steps=20,
).images[0]
image.save("sana_1600m_pag.png")
image.save("sana_1600m_pag-int4.png")
__version__ = "0.1.4"
__version__ = "0.2.0"
from diffusers import DiffusionPipeline
def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
assert isinstance(pipe, DiffusionPipeline)
pipe_cls_name = pipe.__class__.__name__
if pipe_cls_name.startswith("Flux"):
from .flux import apply_cache_on_pipe as apply_cache_on_pipe_fn
elif pipe_cls_name.startswith("Sana"):
from .sana import apply_cache_on_pipe as apply_cache_on_pipe_fn
else:
raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
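A minimal usage sketch for this dispatcher with the FLUX.1-schnell pipeline from the examples above; the model ids match the earlier examples and 0.12 is the flux adapter's default threshold, everything else is illustrative.
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
# First-block caching: when the first block's residual changes by less than the
# threshold between steps, the cached output of the remaining blocks is reused.
apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]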
import functools
import unittest
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torch import nn
from ...caching import utils
def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12):
if getattr(transformer, "_is_cached", False):
return transformer
cached_transformer_blocks = nn.ModuleList(
[
utils.FluxCachedTransformerBlocks(
transformer=transformer,
residual_diff_threshold=residual_diff_threshold,
return_hidden_states_first=False,
)
]
)
dummy_single_transformer_blocks = nn.ModuleList()
original_forward = transformer.forward
@functools.wraps(original_forward)
def new_forward(self, *args, **kwargs):
with (
unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks),
unittest.mock.patch.object(self, "single_transformer_blocks", dummy_single_transformer_blocks),
):
return original_forward(*args, **kwargs)
transformer.forward = new_forward.__get__(transformer)
transformer._is_cached = True
return transformer
def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
if not getattr(pipe, "_is_cached", False):
original_call = pipe.__class__.__call__
@functools.wraps(original_call)
def new_call(self, *args, **kwargs):
with utils.cache_context(utils.create_cache_context()):
return original_call(self, *args, **kwargs)
pipe.__class__.__call__ = new_call
pipe.__class__._is_cached = True
if not shallow_patch:
apply_cache_on_transformer(pipe.transformer, **kwargs)
return pipe
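A sketch of the shallow_patch path above, assuming the flux adapter is importable as nunchaku.caching.diffusers_adapters.flux: shallow_patch=True only installs the per-call cache context, so the transformer blocks are wrapped explicitly afterwards.
from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe, apply_cache_on_transformer

# `pipe` is the FluxPipeline built in the sketch after the dispatcher above.
apply_cache_on_pipe(pipe, shallow_patch=True)  # installs only the per-call cache context
apply_cache_on_transformer(pipe.transformer, residual_diff_threshold=0.12)  # wraps the blocks explicitly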
import functools
import unittest
import torch
from diffusers import DiffusionPipeline, SanaTransformer2DModel
from ...caching import utils
def apply_cache_on_transformer(transformer: SanaTransformer2DModel, *, residual_diff_threshold=0.12):
if getattr(transformer, "_is_cached", False):
return transformer
cached_transformer_blocks = torch.nn.ModuleList(
[
utils.SanaCachedTransformerBlocks(
transformer=transformer,
residual_diff_threshold=residual_diff_threshold,
)
]
)
original_forward = transformer.forward
@functools.wraps(original_forward)
def new_forward(self, *args, **kwargs):
with unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks):
return original_forward(*args, **kwargs)
transformer.forward = new_forward.__get__(transformer)
transformer._is_cached = True
return transformer
def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
if not getattr(pipe, "_is_cached", False):
original_call = pipe.__class__.__call__
@functools.wraps(original_call)
def new_call(self, *args, **kwargs):
with utils.cache_context(utils.create_cache_context()):
return original_call(self, *args, **kwargs)
pipe.__class__.__call__ = new_call
pipe.__class__._is_cached = True
if not shallow_patch:
apply_cache_on_transformer(pipe.transformer, **kwargs)
return pipe
# This caching functionality is largely adapted from https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict, Optional
import torch
from torch import nn
@dataclasses.dataclass
class CacheContext:
buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))
def get_incremental_name(self, name=None):
if name is None:
name = "default"
idx = self.incremental_name_counters[name]
self.incremental_name_counters[name] += 1
return f"{name}_{idx}"
def reset_incremental_name(self):
self.incremental_name_counters.clear()
# @torch.compiler.disable  # note: torch.compiler.disable targets torch.compile (Dynamo), not TorchScript
def get_buffer(self, name: str):
return self.buffers.get(name)
def set_buffer(self, name, buffer):
self.buffers[name] = buffer
def clear_buffers(self):
self.buffers.clear()
@torch.compiler.disable
def get_buffer(name):
cache_context = get_current_cache_context()
assert cache_context is not None, "cache_context must be set before"
return cache_context.get_buffer(name)
@torch.compiler.disable
def set_buffer(name, buffer):
cache_context = get_current_cache_context()
assert cache_context is not None, "cache_context must be set before"
cache_context.set_buffer(name, buffer)
_current_cache_context = None
def create_cache_context():
return CacheContext()
def get_current_cache_context():
return _current_cache_context
@contextlib.contextmanager
def cache_context(cache_context):
global _current_cache_context
old_cache_context = _current_cache_context
_current_cache_context = cache_context
try:
yield
finally:
_current_cache_context = old_cache_context
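# Usage sketch for the helpers above: buffers only live while a cache context
# is active, e.g.
#
#     with cache_context(create_cache_context()):
#         set_buffer("hidden_states_residual", torch.zeros(2, 16))
#         get_buffer("hidden_states_residual")    # returns the stored tensor
#     get_current_cache_context()                 # None again outside the block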
@torch.compiler.disable
def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff = mean_diff / mean_t1
return diff.item() < threshold
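# Worked example (illustrative numbers): for t1 = [2, 2, 2, 2] and
# t2 = [2.1, 1.9, 2.0, 2.0], mean_diff = 0.05 and mean_t1 = 2.0, so
# diff = 0.025; with the default threshold of 0.12 this counts as similar
# and the cached residuals from the previous step are reused.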
@torch.compiler.disable
def apply_prev_hidden_states_residual(
hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states_residual = get_buffer("hidden_states_residual")
assert hidden_states_residual is not None, "hidden_states_residual must be set before"
hidden_states = hidden_states_residual + hidden_states
hidden_states = hidden_states.contiguous()
if encoder_hidden_states is not None:
encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
encoder_hidden_states = encoder_hidden_states.contiguous()
return hidden_states, encoder_hidden_states
@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
prev_first_hidden_states_residual,
first_hidden_states_residual,
threshold=threshold,
parallelized=parallelized,
)
return can_use_cache
class SanaCachedTransformerBlocks(nn.Module):
def __init__(
self,
*,
transformer=None,
residual_diff_threshold,
verbose: bool = False,
):
super().__init__()
self.transformer = transformer
self.transformer_blocks = transformer.transformer_blocks
self.residual_diff_threshold = residual_diff_threshold
self.verbose = verbose
def forward(
self,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask=None,
timestep=None,
post_patch_height=None,
post_patch_width=None,
):
batch_size = hidden_states.shape[0]
if self.residual_diff_threshold <= 0.0 or batch_size > 2:
if batch_size > 2:
print("Batch size > 2 (for SANA CFG)" " currently not supported")
first_transformer_block = self.transformer_blocks[0]
hidden_states = first_transformer_block(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
height=post_patch_height,
width=post_patch_width,
skip_first_layer=False,
)
return hidden_states
original_hidden_states = hidden_states
first_transformer_block = self.transformer_blocks[0]
hidden_states = first_transformer_block.forward_layer_at(
0,
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
height=post_patch_height,
width=post_patch_width,
)
first_hidden_states_residual = hidden_states - original_hidden_states
del original_hidden_states
can_use_cache = get_can_use_cache(
first_hidden_states_residual,
threshold=self.residual_diff_threshold,
parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
)
torch._dynamo.graph_break()
if can_use_cache:
del first_hidden_states_residual
if self.verbose:
print("Cache hit!!!")
hidden_states, _ = apply_prev_hidden_states_residual(hidden_states, None)
else:
if self.verbose:
print("Cache miss!!!")
set_buffer("first_hidden_states_residual", first_hidden_states_residual)
del first_hidden_states_residual
hidden_states, hidden_states_residual = self.call_remaining_transformer_blocks(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
post_patch_height=post_patch_height,
post_patch_width=post_patch_width,
)
set_buffer("hidden_states_residual", hidden_states_residual)
torch._dynamo.graph_break()
return hidden_states
def call_remaining_transformer_blocks(
self,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask=None,
timestep=None,
post_patch_height=None,
post_patch_width=None,
):
first_transformer_block = self.transformer_blocks[0]
original_hidden_states = hidden_states
hidden_states = first_transformer_block(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
height=post_patch_height,
width=post_patch_width,
skip_first_layer=True,
)
hidden_states_residual = hidden_states - original_hidden_states
return hidden_states, hidden_states_residual
class FluxCachedTransformerBlocks(nn.Module):
def __init__(
self,
*,
transformer=None,
residual_diff_threshold,
return_hidden_states_first=True,
return_hidden_states_only=False,
verbose: bool = False,
):
super().__init__()
self.transformer = transformer
self.transformer_blocks = transformer.transformer_blocks
self.single_transformer_blocks = transformer.single_transformer_blocks
self.residual_diff_threshold = residual_diff_threshold
self.return_hidden_states_first = return_hidden_states_first
self.return_hidden_states_only = return_hidden_states_only
self.verbose = verbose
def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
batch_size = hidden_states.shape[0]
if self.residual_diff_threshold <= 0.0 or batch_size > 1:
if batch_size > 1:
print("Batch size > 1 currently not supported")
first_transformer_block = self.transformer_blocks[0]
encoder_hidden_states, hidden_states = first_transformer_block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, *args, **kwargs
)
return (
hidden_states
if self.return_hidden_states_only
else (
(hidden_states, encoder_hidden_states)
if self.return_hidden_states_first
else (encoder_hidden_states, hidden_states)
)
)
original_hidden_states = hidden_states
first_transformer_block = self.transformer_blocks[0]
encoder_hidden_states, hidden_states = first_transformer_block.forward_layer_at(
0, hidden_states, encoder_hidden_states, *args, **kwargs
)
first_hidden_states_residual = hidden_states - original_hidden_states
del original_hidden_states
can_use_cache = get_can_use_cache(
first_hidden_states_residual,
threshold=self.residual_diff_threshold,
parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
)
torch._dynamo.graph_break()
if can_use_cache:
del first_hidden_states_residual
if self.verbose:
print("Cache hit!!!")
hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
hidden_states, encoder_hidden_states
)
else:
if self.verbose:
print("Cache miss!!!")
set_buffer("first_hidden_states_residual", first_hidden_states_residual)
del first_hidden_states_residual
(
hidden_states,
encoder_hidden_states,
hidden_states_residual,
encoder_hidden_states_residual,
) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
set_buffer("hidden_states_residual", hidden_states_residual)
set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
torch._dynamo.graph_break()
return (
hidden_states
if self.return_hidden_states_only
else (
(hidden_states, encoder_hidden_states)
if self.return_hidden_states_first
else (encoder_hidden_states, hidden_states)
)
)
def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
first_transformer_block = self.transformer_blocks[0]
original_hidden_states = hidden_states
original_encoder_hidden_states = encoder_hidden_states
encoder_hidden_states, hidden_states = first_transformer_block.forward(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
skip_first_layer=True,
*args,
**kwargs,
)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
hidden_states_residual = hidden_states - original_hidden_states
encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
......@@ -10,22 +10,37 @@
class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
public:
void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedFluxModel");
spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
if (!bf16) {
spdlog::info("Use FP16 model");
}
if (offload) {
spdlog::info("Layer offloading enabled");
}
ModuleWrapper::init(deviceId);
CUDADeviceContext ctx(this->deviceId);
net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
bool isBF16() {
checkModel();
return net->dtype == Tensor::BF16;
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
torch::Tensor rotary_emb_single)
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
torch::Tensor rotary_emb_single,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
bool skip_first_layer = false)
{
checkModel();
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward");
......@@ -42,7 +57,10 @@ public:
from_torch(temb),
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
from_torch(rotary_emb_single)
from_torch(rotary_emb_single),
controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
skip_first_layer
);
torch::Tensor output = to_torch(result);
......@@ -53,12 +71,16 @@ public:
std::tuple<torch::Tensor, torch::Tensor> forward_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context)
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
{
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
hidden_states = hidden_states.contiguous();
......@@ -67,17 +89,19 @@ public:
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
auto &&[result_img, result_txt] = net->transformer_blocks.at(idx)->forward(
auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
idx,
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(temb),
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
0.0f
controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
);
hidden_states = to_torch(result_img);
encoder_hidden_states = to_torch(result_txt);
hidden_states = to_torch(hidden_states_);
encoder_hidden_states = to_torch(encoder_hidden_states_);
Tensor::synchronizeDevice();
return { hidden_states, encoder_hidden_states };
......@@ -85,10 +109,12 @@ public:
torch::Tensor forward_single_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor temb,
torch::Tensor hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_single)
{
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
hidden_states = hidden_states.contiguous();
......@@ -115,6 +141,8 @@ public:
throw std::invalid_argument("skipRanks must be multiples of 16");
}
CUDADeviceContext ctx(deviceId);
spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
net->traverse([&](Module *module) {
......@@ -131,8 +159,20 @@ public:
});
}
void forceFP16Attention(bool enable) {
Attention::setForceFP16(net.get(), enable);
void setAttentionImpl(std::string name) {
if (name.empty() || name == "default") {
name = "flashattn2";
}
spdlog::info("Set attention implementation to {}", name);
if (name == "flashattn2") {
net->setAttentionImpl(AttentionImpl::FlashAttention2);
} else if (name == "nunchaku-fp16") {
net->setAttentionImpl(AttentionImpl::NunchakuFP16);
} else {
throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
}
}
};
\ No newline at end of file
......@@ -10,7 +10,7 @@ class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
public:
void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedGEMM");
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
......@@ -27,7 +27,7 @@ public:
x = x.contiguous();
Tensor result = net->forward(from_torch(x));
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
......@@ -48,7 +48,7 @@ public:
const int M = x.shape[0];
const int K = x.shape[1] * 2;
assert(x.dtype() == Tensor::INT8);
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
......@@ -67,7 +67,7 @@ public:
const int offset = ((bm * (K / WARP_K) + bn) * NUM_WARPS + warpId) * WARP_M_TILES * WARP_SIZE * 4;
for (int i = 0; i < 16; i++) {
assert(offset + i < x.numel() / 4);
assert(static_cast<size_t>(offset + i) < x.numel() / 4);
uint32_t val = x.data_ptr<uint32_t>()[offset + i];
ss << "{";
for (int j = 0; j < 8; j++) {
......@@ -83,7 +83,7 @@ public:
}
}
}
ss << std::endl;
return ss.str();
}
......@@ -99,7 +99,7 @@ public:
from_torch(x),
fuse_glu
);
Tensor act = qout.act.copy(Device::cpu());
Tensor ascales = qout.ascales.copy(Device::cpu());
Tensor lora_act = qout.lora_act.copy(Device::cpu());
......@@ -110,4 +110,4 @@ public:
spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
}
};
\ No newline at end of file
};
......@@ -9,7 +9,12 @@
template<typename M>
class ModuleWrapper {
public:
void init(int deviceId) {
this->deviceId = deviceId;
}
void reset() {
CUDADeviceContext ctx(this->deviceId);
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
......@@ -20,6 +25,7 @@ public:
void load(std::string path, bool partial = false) {
checkModel();
CUDADeviceContext ctx(this->deviceId);
spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
......@@ -30,6 +36,19 @@ public:
spdlog::info("Done.");
}
void loadDict(std::map<std::string, torch::Tensor> dict, bool partial = false) {
checkModel();
CUDADeviceContext ctx(this->deviceId);
spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
net->loadParams(*provider, partial);
Tensor::synchronizeDevice();
spdlog::info("Done.");
}
void startDebug() {
debugContext = std::make_unique<DebugContext>();
}
......@@ -38,6 +57,8 @@ public:
}
auto getDebugResults() {
CUDADeviceContext ctx(this->deviceId);
std::map<std::string, torch::Tensor> result;
if (debugContext) {
......@@ -59,4 +80,6 @@ protected:
protected:
std::unique_ptr<M> net;
std::unique_ptr<DebugContext> debugContext;
int deviceId = -1;
};
\ No newline at end of file
......@@ -32,7 +32,11 @@ namespace nunchaku::ops {
bool fuse_silu,
bool fp4,
float alpha,
std::optional<torch::Tensor> wcscales
std::optional<torch::Tensor> wcscales,
std::optional<torch::Tensor> out_q, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_k, // packed attention [B, H, M, D]
std::optional<torch::Tensor> out_v, // packed attention [B, H, M, D]
int attn_tokens
) {
spdlog::trace("running gemm_w4a4: ");
......@@ -70,11 +74,31 @@ namespace nunchaku::ops {
fuse_silu,
fp4,
alpha,
getTensor(wcscales)
getTensor(wcscales),
getTensor(out_q),
getTensor(out_k),
getTensor(out_v),
attn_tokens
);
// Tensor::synchronizeDevice();
}
void attention_fp16(
torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale
) {
nunchaku::kernels::attention_fp16(
from_torch(q),
from_torch(k),
from_torch(v),
from_torch(o),
scale
);
}
torch::Tensor gemv_awq(
torch::Tensor _in_feats,
torch::Tensor _kernel,
......@@ -122,6 +146,36 @@ namespace nunchaku::ops {
return output;
}
void test_rmsnorm_rope(
torch::Tensor input,
torch::Tensor output,
torch::Tensor norm_q,
torch::Tensor norm_k,
torch::Tensor rotary_emb)
{
nunchaku::kernels::test_rmsnorm_rope(
from_torch(input),
from_torch(output),
from_torch(norm_q),
from_torch(norm_k),
from_torch(rotary_emb)
);
}
void test_pack_qkv(
torch::Tensor input,
torch::Tensor out_q,
torch::Tensor out_k,
torch::Tensor out_v,
int numTokens)
{
nunchaku::kernels::test_pack_qkv(
from_torch(input),
from_torch(out_q),
from_torch(out_k),
from_torch(out_v),
numTokens
);
}
};
\ No newline at end of file
......@@ -18,18 +18,42 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("deviceId")
)
.def("reset", &QuantizedFluxModel::reset)
.def("load", &QuantizedFluxModel::load,
.def("load", &QuantizedFluxModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("forward", &QuantizedFluxModel::forward)
.def("forward_layer", &QuantizedFluxModel::forward_layer)
.def("loadDict", &QuantizedFluxModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("forward", &QuantizedFluxModel::forward,
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
py::arg("temb"),
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("rotary_emb_single"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none(),
py::arg("skip_first_layer") = false
)
.def("forward_layer", &QuantizedFluxModel::forward_layer,
py::arg("idx"),
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
py::arg("temb"),
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none()
)
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
.def("startDebug", &QuantizedFluxModel::startDebug)
.def("stopDebug", &QuantizedFluxModel::stopDebug)
.def("getDebugResults", &QuantizedFluxModel::getDebugResults)
.def("setLoraScale", &QuantizedFluxModel::setLoraScale)
.def("forceFP16Attention", &QuantizedFluxModel::forceFP16Attention)
.def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
.def("isBF16", &QuantizedFluxModel::isBF16)
;
py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
.def(py::init<>())
......@@ -41,10 +65,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("deviceId")
)
.def("reset", &QuantizedSanaModel::reset)
.def("load", &QuantizedSanaModel::load,
.def("load", &QuantizedSanaModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("loadDict", &QuantizedSanaModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("forward", &QuantizedSanaModel::forward)
.def("forward_layer", &QuantizedSanaModel::forward_layer)
.def("startDebug", &QuantizedSanaModel::startDebug)
......@@ -74,15 +102,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
;
m.def_submodule("ops")
.def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
.def("attention_fp16", nunchaku::ops::attention_fp16)
.def("gemm_awq", nunchaku::ops::gemm_awq)
.def("gemv_awq", nunchaku::ops::gemv_awq)
.def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
.def("test_pack_qkv", nunchaku::ops::test_pack_qkv)
;
m.def_submodule("utils")
.def("set_log_level", [](const std::string &level) {
spdlog::set_level(spdlog::level::from_str(level));
})
.def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
.def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
.def("trim_memory", nunchaku::utils::trim_memory)
.def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode)
;
}
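A hedged sketch of calling the new utils bindings from Python; the import path nunchaku._C is an assumption (the module is named by TORCH_EXTENSION_NAME at build time), and only functions bound above are used.
from nunchaku import _C  # assumed import path for this compiled extension

_C.utils.set_log_level("debug")      # any level string accepted by spdlog::level::from_str
_C.utils.set_cuda_stack_limit(8192)  # forwarded to cudaDeviceSetLimit(cudaLimitStackSize, ...)
Per-model attention selection moved from forceFP16Attention to setAttentionImpl, which accepts "flashattn2" (also the "default" alias) and "nunchaku-fp16".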
......@@ -9,7 +9,7 @@
class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
public:
void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedSanaModel");
spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
SanaConfig cfg{
.num_layers = config["num_layers"].cast<int>(),
.num_attention_heads = config["num_attention_heads"].cast<int>(),
......@@ -19,21 +19,26 @@ public:
.pag_layers = pag_layers,
.use_fp4 = use_fp4,
};
ModuleWrapper::init(deviceId);
CUDADeviceContext ctx(this->deviceId);
net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg)
bool pag,
bool cfg,
bool skip_first_layer = false)
{
checkModel();
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedSanaModel forward");
......@@ -50,7 +55,8 @@ public:
from_torch(cu_seqlens_img),
from_torch(cu_seqlens_txt),
H, W,
pag, cfg
pag, cfg,
skip_first_layer
);
torch::Tensor output = to_torch(result);
......@@ -61,17 +67,18 @@ public:
torch::Tensor forward_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg)
bool pag,
bool cfg)
{
checkModel();
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedSanaModel forward_layer {}", idx);
......
......@@ -2,9 +2,17 @@
#include "common.h"
#include "Tensor.h"
#include "kernels/zgemm/zgemm.h"
namespace nunchaku::utils {
void set_cuda_stack_limit(int64_t newval) {
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
}
void disable_memory_auto_release() {
int device;
checkCUDA(cudaGetDevice(&device));
......@@ -23,4 +31,9 @@ namespace nunchaku::utils {
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}
void set_faster_i2f_mode(std::string mode) {
spdlog::info("Set fasteri2f mode to {}", mode);
kernels::set_faster_i2f_mode(mode);
}
};
\ No newline at end of file