Unverified commit ad8097b9, authored by Muyang Li and committed by GitHub

Release v0.2.0

Ready to release v0.2.0
parents 804a6d30 998192ca
# Example: SVDQuant FP4 FLUX.1-schnell (the FP4 checkpoints target NVIDIA Blackwell GPUs, e.g. the RTX 50 series).
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
# Example: SVDQuant INT4 FLUX.1-schnell with transformer offloading and sequential CPU offload for low-VRAM GPUs.
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-schnell", offload=True
)  # set offload to False if you want to disable offloading
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
# Example: SVDQuant INT4 FLUX.1-schnell with the quantized T5 text encoder and CPU offloading.
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-schnell", offload=True
)  # set offload to False if you want to disable offloading
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder_2=text_encoder_2,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)  # do not move the pipeline to CUDA here; sequential CPU offloading manages device placement
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
# Example: SVDQuant INT4 FLUX.1-schnell with the quantized T5 text encoder, fully on GPU.
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder_2=text_encoder_2,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
# Example: SVDQuant INT4 SANA 1600M with first-block caching enabled via apply_cache_on_pipe.
import torch
from diffusers import SanaPipeline

from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,
    variant="bf16",
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

apply_cache_on_pipe(pipe, residual_diff_threshold=0.25)

prompt = "A cute 🐼 eating 🎋, ink drawing style"
image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.5,
    num_inference_steps=20,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m-int4.png")
@@ -23,4 +23,4 @@ image = pipe(
     generator=torch.Generator().manual_seed(42),
 ).images[0]
-image.save("sana_1600m.png")
+image.save("sana_1600m-int4.png")
@@ -24,4 +24,4 @@ image = pipe(
     pag_scale=2.0,
     num_inference_steps=20,
 ).images[0]
-image.save("sana_1600m_pag.png")
+image.save("sana_1600m_pag-int4.png")
-__version__ = "0.1.4"
+__version__ = "0.2.0"
# Caching adapter dispatcher (nunchaku.caching.diffusers_adapters): route to the Flux or Sana implementation by pipeline class name.
from diffusers import DiffusionPipeline


def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    assert isinstance(pipe, DiffusionPipeline)
    pipe_cls_name = pipe.__class__.__name__
    if pipe_cls_name.startswith("Flux"):
        from .flux import apply_cache_on_pipe as apply_cache_on_pipe_fn
    elif pipe_cls_name.startswith("Sana"):
        from .sana import apply_cache_on_pipe as apply_cache_on_pipe_fn
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
    return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
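
# Illustrative sketch (not part of this commit): combining the dispatcher above with the
# FLUX.1-schnell example from earlier in this release. apply_cache_on_pipe resolves to the
# Flux adapter because the pipeline class name starts with "Flux"; the 0.12 threshold mirrors
# the adapter's default and is an assumption you may tune.
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Patch the pipeline in place; a larger threshold caches more aggressively (faster, lossier).
apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)

image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell-cached.png")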
# Flux caching adapter (nunchaku.caching.diffusers_adapters.flux): first-block cache for Flux transformers.
import functools
import unittest.mock

from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torch import nn

from ...caching import utils


def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = nn.ModuleList(
        [
            utils.FluxCachedTransformerBlocks(
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
                return_hidden_states_first=False,
            )
        ]
    )
    dummy_single_transformer_blocks = nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with (
            unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks),
            unittest.mock.patch.object(self, "single_transformer_blocks", dummy_single_transformer_blocks),
        ):
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
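
# Illustrative sketch (not part of this commit): with shallow_patch=True the pipeline __call__
# is wrapped in a per-call cache context but the transformer itself is left untouched, so the
# blocks can be patched explicitly afterwards (e.g. after swapping the transformer). `pipeline`
# here stands for any FluxPipeline built as in the examples earlier in this release.
from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe, apply_cache_on_transformer

apply_cache_on_pipe(pipeline, shallow_patch=True)  # only patches __call__ with a cache context
apply_cache_on_transformer(pipeline.transformer, residual_diff_threshold=0.12)  # patch blocks explicitly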
# Sana caching adapter (nunchaku.caching.diffusers_adapters.sana): first-block cache for Sana transformers.
import functools
import unittest.mock

import torch
from diffusers import DiffusionPipeline, SanaTransformer2DModel

from ...caching import utils


def apply_cache_on_transformer(transformer: SanaTransformer2DModel, *, residual_diff_threshold=0.12):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.SanaCachedTransformerBlocks(
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks):
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
# Caching utilities (nunchaku.caching.utils).
# This caching functionality is largely brought from https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict, Optional

import torch
from torch import nn


@dataclasses.dataclass
class CacheContext:
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))

    def get_incremental_name(self, name=None):
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_name(self):
        self.incremental_name_counters.clear()

    # @torch.compiler.disable  # this is a torch.compile feature
    def get_buffer(self, name: str):
        return self.buffers.get(name)

    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def clear_buffers(self):
        self.buffers.clear()


@torch.compiler.disable
def get_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_buffer(name)


@torch.compiler.disable
def set_buffer(name, buffer):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.set_buffer(name, buffer)


_current_cache_context = None


def create_cache_context():
    return CacheContext()


def get_current_cache_context():
    return _current_cache_context


@contextlib.contextmanager
def cache_context(cache_context):
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context
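
# Illustrative sketch (not part of this commit): buffers live only inside an active cache
# context, which is exactly what the patched pipeline __call__ in the adapters above sets up
# once per invocation.
import torch

with cache_context(create_cache_context()):
    set_buffer("hidden_states_residual", torch.zeros(1, 16, 64))
    assert get_buffer("hidden_states_residual") is not None
# Outside the context, get_buffer()/set_buffer() would assert, since no cache context is active.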
@torch.compiler.disable
def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold


@torch.compiler.disable
def apply_prev_hidden_states_residual(
    hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before"
    hidden_states = hidden_states_residual + hidden_states
    hidden_states = hidden_states.contiguous()

    if encoder_hidden_states is not None:
        encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
        assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
        encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
        encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states


@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
    can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
        prev_first_hidden_states_residual,
        first_hidden_states_residual,
        threshold=threshold,
        parallelized=parallelized,
    )
    return can_use_cache
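
# Illustrative numeric sketch (not part of this commit): the cache is reused when the first
# block's residual changes by less than residual_diff_threshold in relative mean-absolute terms.
import torch

prev = torch.tensor([1.0, 2.0, 3.0, 4.0])
curr = torch.tensor([1.05, 1.95, 3.1, 4.0])
rel_diff = (prev - curr).abs().mean() / prev.abs().mean()  # 0.02 here
print(rel_diff.item() < 0.12)  # True -> the remaining blocks would be skipped this step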
class SanaCachedTransformerBlocks(nn.Module):
    def __init__(
        self,
        *,
        transformer=None,
        residual_diff_threshold,
        verbose: bool = False,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer.transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.verbose = verbose

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states,
        encoder_attention_mask=None,
        timestep=None,
        post_patch_height=None,
        post_patch_width=None,
    ):
        batch_size = hidden_states.shape[0]
        if self.residual_diff_threshold <= 0.0 or batch_size > 2:
            if batch_size > 2:
                print("Batch size > 2 (for SANA CFG) currently not supported")

            first_transformer_block = self.transformer_blocks[0]
            hidden_states = first_transformer_block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                height=post_patch_height,
                width=post_patch_width,
                skip_first_layer=False,
            )
            return hidden_states

        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        hidden_states = first_transformer_block.forward_layer_at(
            0,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            timestep=timestep,
            height=post_patch_height,
            width=post_patch_width,
        )
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            if self.verbose:
                print("Cache hit!!!")
            hidden_states, _ = apply_prev_hidden_states_residual(hidden_states, None)
        else:
            if self.verbose:
                print("Cache miss!!!")
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            hidden_states, hidden_states_residual = self.call_remaining_transformer_blocks(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                post_patch_height=post_patch_height,
                post_patch_width=post_patch_width,
            )
            set_buffer("hidden_states_residual", hidden_states_residual)
        torch._dynamo.graph_break()

        return hidden_states

    def call_remaining_transformer_blocks(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states,
        encoder_attention_mask=None,
        timestep=None,
        post_patch_height=None,
        post_patch_width=None,
    ):
        first_transformer_block = self.transformer_blocks[0]
        original_hidden_states = hidden_states
        hidden_states = first_transformer_block(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            timestep=timestep,
            height=post_patch_height,
            width=post_patch_width,
            skip_first_layer=True,
        )
        hidden_states_residual = hidden_states - original_hidden_states
        return hidden_states, hidden_states_residual
class FluxCachedTransformerBlocks(nn.Module):
    def __init__(
        self,
        *,
        transformer=None,
        residual_diff_threshold,
        return_hidden_states_first=True,
        return_hidden_states_only=False,
        verbose: bool = False,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer.transformer_blocks
        self.single_transformer_blocks = transformer.single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.return_hidden_states_first = return_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only
        self.verbose = verbose

    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        batch_size = hidden_states.shape[0]
        if self.residual_diff_threshold <= 0.0 or batch_size > 1:
            if batch_size > 1:
                print("Batch size > 1 currently not supported")

            first_transformer_block = self.transformer_blocks[0]
            encoder_hidden_states, hidden_states = first_transformer_block(
                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, *args, **kwargs
            )
            return (
                hidden_states
                if self.return_hidden_states_only
                else (
                    (hidden_states, encoder_hidden_states)
                    if self.return_hidden_states_first
                    else (encoder_hidden_states, hidden_states)
                )
            )

        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        encoder_hidden_states, hidden_states = first_transformer_block.forward_layer_at(
            0, hidden_states, encoder_hidden_states, *args, **kwargs
        )
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            if self.verbose:
                print("Cache hit!!!")
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states
            )
        else:
            if self.verbose:
                print("Cache miss!!!")
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
            set_buffer("hidden_states_residual", hidden_states_residual)
            set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
        torch._dynamo.graph_break()

        return (
            hidden_states
            if self.return_hidden_states_only
            else (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )
        )

    def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        first_transformer_block = self.transformer_blocks[0]
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        encoder_hidden_states, hidden_states = first_transformer_block.forward(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            skip_first_layer=True,
            *args,
            **kwargs,
        )
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        hidden_states_residual = hidden_states - original_hidden_states
        encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states

        return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
@@ -10,22 +10,37 @@
 class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
 public:
     void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
-        spdlog::info("Initializing QuantizedFluxModel");
+        spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
+        if (!bf16) {
+            spdlog::info("Use FP16 model");
+        }
         if (offload) {
             spdlog::info("Layer offloading enabled");
         }
+        ModuleWrapper::init(deviceId);
+
+        CUDADeviceContext ctx(this->deviceId);
         net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
+
+    bool isBF16() {
+        checkModel();
+        return net->dtype == Tensor::BF16;
+    }
+
     torch::Tensor forward(
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor temb,
         torch::Tensor rotary_emb_img,
         torch::Tensor rotary_emb_context,
-        torch::Tensor rotary_emb_single)
+        torch::Tensor rotary_emb_single,
+        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
+        bool skip_first_layer = false)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
+
         spdlog::debug("QuantizedFluxModel forward");
@@ -42,7 +57,10 @@ public:
             from_torch(temb),
             from_torch(rotary_emb_img),
             from_torch(rotary_emb_context),
-            from_torch(rotary_emb_single)
+            from_torch(rotary_emb_single),
+            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
+            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
+            skip_first_layer
         );
         torch::Tensor output = to_torch(result);
@@ -53,12 +71,16 @@ public:
     std::tuple<torch::Tensor, torch::Tensor> forward_layer(
         int64_t idx,
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor temb,
         torch::Tensor rotary_emb_img,
-        torch::Tensor rotary_emb_context)
+        torch::Tensor rotary_emb_context,
+        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
     {
+        CUDADeviceContext ctx(deviceId);
+
         spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
         hidden_states = hidden_states.contiguous();
@@ -67,17 +89,19 @@ public:
         rotary_emb_img = rotary_emb_img.contiguous();
         rotary_emb_context = rotary_emb_context.contiguous();
-        auto &&[result_img, result_txt] = net->transformer_blocks.at(idx)->forward(
+        auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
+            idx,
             from_torch(hidden_states),
             from_torch(encoder_hidden_states),
             from_torch(temb),
             from_torch(rotary_emb_img),
             from_torch(rotary_emb_context),
-            0.0f
+            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
+            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
         );
-        hidden_states = to_torch(result_img);
-        encoder_hidden_states = to_torch(result_txt);
+        hidden_states = to_torch(hidden_states_);
+        encoder_hidden_states = to_torch(encoder_hidden_states_);
         Tensor::synchronizeDevice();
         return { hidden_states, encoder_hidden_states };
@@ -85,10 +109,12 @@ public:
     torch::Tensor forward_single_layer(
         int64_t idx,
         torch::Tensor hidden_states,
         torch::Tensor temb,
         torch::Tensor rotary_emb_single)
     {
+        CUDADeviceContext ctx(deviceId);
+
         spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
         hidden_states = hidden_states.contiguous();
@@ -115,6 +141,8 @@ public:
             throw std::invalid_argument("skipRanks must be multiples of 16");
         }
+        CUDADeviceContext ctx(deviceId);
+
         spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
         net->traverse([&](Module *module) {
@@ -131,8 +159,20 @@ public:
         });
     }
-    void forceFP16Attention(bool enable) {
-        Attention::setForceFP16(net.get(), enable);
+    void setAttentionImpl(std::string name) {
+        if (name.empty() || name == "default") {
+            name = "flashattn2";
+        }
+        spdlog::info("Set attention implementation to {}", name);
+        if (name == "flashattn2") {
+            net->setAttentionImpl(AttentionImpl::FlashAttention2);
+        } else if (name == "nunchaku-fp16") {
+            net->setAttentionImpl(AttentionImpl::NunchakuFP16);
+        } else {
+            throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
+        }
     }
 };
\ No newline at end of file
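
# Hedged sketch (assumptions, not from this commit): exercising the new QuantizedFluxModel
# surface from Python. The extension module path ("nunchaku._C") and the checkpoint path are
# assumptions for illustration; init(use_fp4, offload, bf16, deviceId), load(), isBF16() and
# setAttentionImpl() are the bindings added or changed in this release.
from nunchaku._C import QuantizedFluxModel  # assumed module path

model = QuantizedFluxModel()
model.init(False, False, True, 0)   # INT4 weights, no offloading, BF16 activations, GPU 0
model.load("/path/to/svdq-int4-flux.1-schnell/transformer_blocks.safetensors")  # hypothetical path
print(model.isBF16())               # True for the BF16 build
model.setAttentionImpl("nunchaku-fp16")  # or "flashattn2" (the default)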
@@ -10,7 +10,7 @@ class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
 public:
     void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedGEMM");
         size_t val = 0;
         checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
         checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
@@ -27,7 +27,7 @@ public:
         x = x.contiguous();
         Tensor result = net->forward(from_torch(x));
         torch::Tensor output = to_torch(result);
         Tensor::synchronizeDevice();
@@ -48,7 +48,7 @@ public:
         const int M = x.shape[0];
         const int K = x.shape[1] * 2;
         assert(x.dtype() == Tensor::INT8);
         // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
@@ -67,7 +67,7 @@ public:
                 const int offset = ((bm * (K / WARP_K) + bn) * NUM_WARPS + warpId) * WARP_M_TILES * WARP_SIZE * 4;
                 for (int i = 0; i < 16; i++) {
-                    assert(offset + i < x.numel() / 4);
+                    assert(static_cast<size_t>(offset + i) < x.numel() / 4);
                     uint32_t val = x.data_ptr<uint32_t>()[offset + i];
                     ss << "{";
                     for (int j = 0; j < 8; j++) {
@@ -83,7 +83,7 @@ public:
                 }
             }
         }
         ss << std::endl;
         return ss.str();
     }
@@ -99,7 +99,7 @@ public:
             from_torch(x),
             fuse_glu
         );
         Tensor act = qout.act.copy(Device::cpu());
         Tensor ascales = qout.ascales.copy(Device::cpu());
         Tensor lora_act = qout.lora_act.copy(Device::cpu());
@@ -110,4 +110,4 @@ public:
         spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
     }
 };
\ No newline at end of file
@@ -9,7 +9,12 @@
 template<typename M>
 class ModuleWrapper {
 public:
+    void init(int deviceId) {
+        this->deviceId = deviceId;
+    }
+
     void reset() {
+        CUDADeviceContext ctx(this->deviceId);
         debugContext.reset();
         net.reset();
         Tensor::synchronizeDevice();
@@ -20,6 +25,7 @@ public:
     void load(std::string path, bool partial = false) {
         checkModel();
+        CUDADeviceContext ctx(this->deviceId);
         spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
@@ -30,6 +36,19 @@ public:
         spdlog::info("Done.");
     }
+    void loadDict(std::map<std::string, torch::Tensor> dict, bool partial = false) {
+        checkModel();
+        CUDADeviceContext ctx(this->deviceId);
+
+        spdlog::info("{} weights from pytorch", partial ? "Loading partial" : "Loading");
+
+        std::shared_ptr<TensorsProviderTorch> provider = std::make_shared<TensorsProviderTorch>(std::move(dict));
+        net->loadParams(*provider, partial);
+        Tensor::synchronizeDevice();
+
+        spdlog::info("Done.");
+    }
+
     void startDebug() {
         debugContext = std::make_unique<DebugContext>();
     }
@@ -38,6 +57,8 @@ public:
     }
     auto getDebugResults() {
+        CUDADeviceContext ctx(this->deviceId);
+
         std::map<std::string, torch::Tensor> result;
         if (debugContext) {
@@ -59,4 +80,6 @@ protected:
 protected:
     std::unique_ptr<M> net;
     std::unique_ptr<DebugContext> debugContext;
+
+    int deviceId = -1;
 };
\ No newline at end of file
@@ -32,7 +32,11 @@ namespace nunchaku::ops {
         bool fuse_silu,
         bool fp4,
         float alpha,
-        std::optional<torch::Tensor> wcscales
+        std::optional<torch::Tensor> wcscales,
+        std::optional<torch::Tensor> out_q,  // packed attention [B, H, M, D]
+        std::optional<torch::Tensor> out_k,  // packed attention [B, H, M, D]
+        std::optional<torch::Tensor> out_v,  // packed attention [B, H, M, D]
+        int attn_tokens
     ) {
         spdlog::trace("running gemm_w4a4: ");
@@ -70,11 +74,31 @@ namespace nunchaku::ops {
             fuse_silu,
             fp4,
             alpha,
-            getTensor(wcscales)
+            getTensor(wcscales),
+            getTensor(out_q),
+            getTensor(out_k),
+            getTensor(out_v),
+            attn_tokens
         );
         // Tensor::synchronizeDevice();
     }
+
+    void attention_fp16(
+        torch::Tensor q,  // packed [Batch, Head, TokensQ, HEAD_DIM]
+        torch::Tensor k,  // packed [Batch, Head, TokensKV, HEAD_DIM]
+        torch::Tensor v,  // packed [Batch, Head, TokensKV, HEAD_DIM]
+        torch::Tensor o,  // linear [Batch, TokensQ, Head * HEAD_DIM]
+        float scale
+    ) {
+        nunchaku::kernels::attention_fp16(
+            from_torch(q),
+            from_torch(k),
+            from_torch(v),
+            from_torch(o),
+            scale
+        );
+    }
+
     torch::Tensor gemv_awq(
         torch::Tensor _in_feats,
         torch::Tensor _kernel,
@@ -122,6 +146,36 @@ namespace nunchaku::ops {
         return output;
     }
+
+    void test_rmsnorm_rope(
+        torch::Tensor input,
+        torch::Tensor output,
+        torch::Tensor norm_q,
+        torch::Tensor norm_k,
+        torch::Tensor rotary_emb)
+    {
+        nunchaku::kernels::test_rmsnorm_rope(
+            from_torch(input),
+            from_torch(output),
+            from_torch(norm_q),
+            from_torch(norm_k),
+            from_torch(rotary_emb)
+        );
+    }
+
+    void test_pack_qkv(
+        torch::Tensor input,
+        torch::Tensor out_q,
+        torch::Tensor out_k,
+        torch::Tensor out_v,
+        int numTokens)
+    {
+        nunchaku::kernels::test_pack_qkv(
+            from_torch(input),
+            from_torch(out_q),
+            from_torch(out_k),
+            from_torch(out_v),
+            numTokens
+        );
+    }
 };
\ No newline at end of file
@@ -18,18 +18,42 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("deviceId")
         )
         .def("reset", &QuantizedFluxModel::reset)
         .def("load", &QuantizedFluxModel::load,
             py::arg("path"),
             py::arg("partial") = false
         )
-        .def("forward", &QuantizedFluxModel::forward)
-        .def("forward_layer", &QuantizedFluxModel::forward_layer)
+        .def("loadDict", &QuantizedFluxModel::loadDict,
+            py::arg("dict"),
+            py::arg("partial") = false
+        )
+        .def("forward", &QuantizedFluxModel::forward,
+            py::arg("hidden_states"),
+            py::arg("encoder_hidden_states"),
+            py::arg("temb"),
+            py::arg("rotary_emb_img"),
+            py::arg("rotary_emb_context"),
+            py::arg("rotary_emb_single"),
+            py::arg("controlnet_block_samples") = py::none(),
+            py::arg("controlnet_single_block_samples") = py::none(),
+            py::arg("skip_first_layer") = false
+        )
+        .def("forward_layer", &QuantizedFluxModel::forward_layer,
+            py::arg("idx"),
+            py::arg("hidden_states"),
+            py::arg("encoder_hidden_states"),
+            py::arg("temb"),
+            py::arg("rotary_emb_img"),
+            py::arg("rotary_emb_context"),
+            py::arg("controlnet_block_samples") = py::none(),
+            py::arg("controlnet_single_block_samples") = py::none()
+        )
         .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
         .def("startDebug", &QuantizedFluxModel::startDebug)
         .def("stopDebug", &QuantizedFluxModel::stopDebug)
         .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
         .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
-        .def("forceFP16Attention", &QuantizedFluxModel::forceFP16Attention)
+        .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
+        .def("isBF16", &QuantizedFluxModel::isBF16)
     ;
     py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
         .def(py::init<>())
@@ -41,10 +65,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
             py::arg("deviceId")
         )
         .def("reset", &QuantizedSanaModel::reset)
         .def("load", &QuantizedSanaModel::load,
             py::arg("path"),
             py::arg("partial") = false
         )
+        .def("loadDict", &QuantizedSanaModel::loadDict,
+            py::arg("dict"),
+            py::arg("partial") = false
+        )
         .def("forward", &QuantizedSanaModel::forward)
         .def("forward_layer", &QuantizedSanaModel::forward_layer)
         .def("startDebug", &QuantizedSanaModel::startDebug)
@@ -74,15 +102,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     ;
     m.def_submodule("ops")
+        .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
+        .def("attention_fp16", nunchaku::ops::attention_fp16)
        .def("gemm_awq", nunchaku::ops::gemm_awq)
        .def("gemv_awq", nunchaku::ops::gemv_awq)
+        .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
+        .def("test_pack_qkv", nunchaku::ops::test_pack_qkv)
     ;
     m.def_submodule("utils")
         .def("set_log_level", [](const std::string &level) {
             spdlog::set_level(spdlog::level::from_str(level));
         })
+        .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
         .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
         .def("trim_memory", nunchaku::utils::trim_memory)
+        .def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode)
     ;
 }
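
# Hedged sketch (assumptions, not from this commit): the utils submodule now exposes the CUDA
# stack-limit setter added in this release alongside the existing logging and memory helpers.
# The "nunchaku._C" module path is an assumption for illustration.
from nunchaku._C import utils  # assumed module path

utils.set_log_level("debug")      # spdlog level string
utils.set_cuda_stack_limit(8192)  # bytes per thread, mirroring QuantizedGEMM::init above
# utils.set_faster_i2f_mode("...")  # valid mode strings depend on the kernels and are not shown in this diff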
@@ -9,7 +9,7 @@
 class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
 public:
     void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
-        spdlog::info("Initializing QuantizedSanaModel");
+        spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
         SanaConfig cfg{
             .num_layers = config["num_layers"].cast<int>(),
             .num_attention_heads = config["num_attention_heads"].cast<int>(),
@@ -19,21 +19,26 @@ public:
             .pag_layers = pag_layers,
             .use_fp4 = use_fp4,
         };
+        ModuleWrapper::init(deviceId);
+
+        CUDADeviceContext ctx(this->deviceId);
         net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
     torch::Tensor forward(
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor timestep,
         torch::Tensor cu_seqlens_img,
         torch::Tensor cu_seqlens_txt,
         int H,
         int W,
         bool pag,
-        bool cfg)
+        bool cfg,
+        bool skip_first_layer = false)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedSanaModel forward");
@@ -50,7 +55,8 @@ public:
             from_torch(cu_seqlens_img),
             from_torch(cu_seqlens_txt),
             H, W,
-            pag, cfg
+            pag, cfg,
+            skip_first_layer
         );
         torch::Tensor output = to_torch(result);
@@ -61,17 +67,18 @@ public:
     torch::Tensor forward_layer(
         int64_t idx,
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor timestep,
         torch::Tensor cu_seqlens_img,
         torch::Tensor cu_seqlens_txt,
         int H,
         int W,
         bool pag,
         bool cfg)
     {
         checkModel();
+        CUDADeviceContext ctx(deviceId);
         spdlog::debug("QuantizedSanaModel forward_layer {}", idx);
@@ -2,9 +2,17 @@
 #include "common.h"
 #include "Tensor.h"
+#include "kernels/zgemm/zgemm.h"
+
 namespace nunchaku::utils {
+    void set_cuda_stack_limit(int64_t newval) {
+        size_t val = 0;
+        checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
+        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
+        spdlog::debug("Stack={}", val);
+    }
+
     void disable_memory_auto_release() {
         int device;
         checkCUDA(cudaGetDevice(&device));
@@ -23,4 +31,9 @@ namespace nunchaku::utils {
         checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
     }
+
+    void set_faster_i2f_mode(std::string mode) {
+        spdlog::info("Set fasteri2f mode to {}", mode);
+        kernels::set_faster_i2f_mode(mode);
+    }
 };
\ No newline at end of file