Unverified Commit 57e50f8d authored by Muyang Li, committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
@@ -11,13 +11,13 @@ public:
    void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
        SanaConfig cfg{
            .num_layers = config["num_layers"].cast<int>(),
            .num_attention_heads = config["num_attention_heads"].cast<int>(),
            .attention_head_dim = config["attention_head_dim"].cast<int>(),
            .num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
            .expand_ratio = config["mlp_ratio"].cast<double>(),
            .pag_layers = pag_layers,
            .use_fp4 = use_fp4,
        };
        ModuleWrapper::init(deviceId);
@@ -25,39 +25,37 @@ public:
        net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }

-   torch::Tensor forward(
-       torch::Tensor hidden_states,
-       torch::Tensor encoder_hidden_states,
-       torch::Tensor timestep,
-       torch::Tensor cu_seqlens_img,
-       torch::Tensor cu_seqlens_txt,
-       int H,
-       int W,
-       bool pag,
-       bool cfg,
-       bool skip_first_layer = false)
-   {
+   torch::Tensor forward(torch::Tensor hidden_states,
+                         torch::Tensor encoder_hidden_states,
+                         torch::Tensor timestep,
+                         torch::Tensor cu_seqlens_img,
+                         torch::Tensor cu_seqlens_txt,
+                         int H,
+                         int W,
+                         bool pag,
+                         bool cfg,
+                         bool skip_first_layer = false) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedSanaModel forward");

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        timestep = timestep.contiguous();
        cu_seqlens_img = cu_seqlens_img.contiguous();
        cu_seqlens_txt = cu_seqlens_txt.contiguous();

-       Tensor result = net->forward(
-           from_torch(hidden_states),
-           from_torch(encoder_hidden_states),
-           from_torch(timestep),
-           from_torch(cu_seqlens_img),
-           from_torch(cu_seqlens_txt),
-           H, W,
-           pag, cfg,
-           skip_first_layer
-       );
+       Tensor result = net->forward(from_torch(hidden_states),
+                                    from_torch(encoder_hidden_states),
+                                    from_torch(timestep),
+                                    from_torch(cu_seqlens_img),
+                                    from_torch(cu_seqlens_txt),
+                                    H,
+                                    W,
+                                    pag,
+                                    cfg,
+                                    skip_first_layer);

        torch::Tensor output = to_torch(result);
        // Tensor::synchronizeDevice();
@@ -65,42 +63,40 @@ public:
        return output;
    }

-   torch::Tensor forward_layer(
-       int64_t idx,
-       torch::Tensor hidden_states,
-       torch::Tensor encoder_hidden_states,
-       torch::Tensor timestep,
-       torch::Tensor cu_seqlens_img,
-       torch::Tensor cu_seqlens_txt,
-       int H,
-       int W,
-       bool pag,
-       bool cfg)
-   {
+   torch::Tensor forward_layer(int64_t idx,
+                               torch::Tensor hidden_states,
+                               torch::Tensor encoder_hidden_states,
+                               torch::Tensor timestep,
+                               torch::Tensor cu_seqlens_img,
+                               torch::Tensor cu_seqlens_txt,
+                               int H,
+                               int W,
+                               bool pag,
+                               bool cfg) {
        checkModel();
        CUDADeviceContext ctx(deviceId);

        spdlog::debug("QuantizedSanaModel forward_layer {}", idx);

        hidden_states = hidden_states.contiguous();
        encoder_hidden_states = encoder_hidden_states.contiguous();
        timestep = timestep.contiguous();
        cu_seqlens_img = cu_seqlens_img.contiguous();
        cu_seqlens_txt = cu_seqlens_txt.contiguous();

-       Tensor result = net->transformer_blocks.at(idx)->forward(
-           from_torch(hidden_states),
-           from_torch(encoder_hidden_states),
-           from_torch(timestep),
-           from_torch(cu_seqlens_img),
-           from_torch(cu_seqlens_txt),
-           H, W,
-           pag, cfg
-       );
+       Tensor result = net->transformer_blocks.at(idx)->forward(from_torch(hidden_states),
+                                                                from_torch(encoder_hidden_states),
+                                                                from_torch(timestep),
+                                                                from_torch(cu_seqlens_img),
+                                                                from_torch(cu_seqlens_txt),
+                                                                H,
+                                                                W,
+                                                                pag,
+                                                                cfg);

        torch::Tensor output = to_torch(result);
        // Tensor::synchronizeDevice();
        return output;
    }
};
\ No newline at end of file
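
A minimal Python sketch of driving this binding, assuming the class is exposed as `nunchaku._C.QuantizedSanaModel` (the import path used later in this diff) and is default-constructible; the config keys mirror the fields read in `init()` above, and all values here are illustrative, not from a real checkpoint:

```python
from nunchaku._C import QuantizedSanaModel  # compiled pybind11 extension

config = {
    "num_layers": 28,                  # illustrative values only
    "num_attention_heads": 36,
    "attention_head_dim": 32,
    "num_cross_attention_heads": 20,
    "mlp_ratio": 2.5,
}
m = QuantizedSanaModel()
m.init(config, [], False, True, 0)     # pag_layers=[], use_fp4=False, bf16=True, device 0
# forward(hidden_states, encoder_hidden_states, timestep, cu_seqlens_img,
#         cu_seqlens_txt, H, W, pag, cfg, skip_first_layer=False) returns a
# torch.Tensor; quantized weights must be loaded first (checkModel() guards it).
```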
@@ -6,34 +6,34 @@
namespace nunchaku::utils {

void set_cuda_stack_limit(int64_t newval) {
    size_t val = 0;
    checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
    checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
    spdlog::debug("Stack={}", val);
}

void disable_memory_auto_release() {
    int device;
    checkCUDA(cudaGetDevice(&device));
    cudaMemPool_t mempool;
    checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
    uint64_t threshold = UINT64_MAX;
    checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
}

void trim_memory() {
    int device;
    checkCUDA(cudaGetDevice(&device));
    cudaMemPool_t mempool;
    checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
    size_t bytesToKeep = 0;
    checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}

void set_faster_i2f_mode(std::string mode) {
    spdlog::info("Set fasteri2f mode to {}", mode);
    kernels::set_faster_i2f_mode(mode);
}

-};
+}; // namespace nunchaku::utils
\ No newline at end of file
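
A hedged sketch of calling these helpers from Python. The compiled module is imported elsewhere in this diff as `from ..._C import utils as cutils`; that every function below is bound under its C++ name is an assumption of this example:

```python
from nunchaku._C import utils as cutils  # pybind11 bindings for the helpers above

cutils.disable_memory_auto_release()  # raise the pool release threshold to UINT64_MAX
cutils.set_cuda_stack_limit(2 ** 13)  # per-thread device stack size, in bytes
# ... run inference ...
cutils.trim_memory()                  # trim the default CUDA memory pool down to 0 bytes
```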
from .diffusers_converter import to_diffusers
from .nunchaku_converter import convert_to_nunchaku_flux_lowrank_dict, to_nunchaku
from .utils import is_nunchaku_format

+__all__ = ["to_diffusers", "to_nunchaku", "convert_to_nunchaku_flux_lowrank_dict", "is_nunchaku_format"]
@@ -7,10 +7,10 @@ import torch
from safetensors.torch import save_file
from tqdm import tqdm

+from ...utils import filter_state_dict, load_state_dict_in_safetensors
from .diffusers_converter import to_diffusers
from .packer import NunchakuWeightPacker
from .utils import is_nunchaku_format, pad
-from ...utils import filter_state_dict, load_state_dict_in_safetensors

logger = logging.getLogger(__name__)
...
# Copy the packer from https://github.com/mit-han-lab/deepcompressor/
import torch

-from .utils import pad
from ...utils import ceil_divide
+from .utils import pad

class MmaWeightPackerBase:
...
from .text_encoders.t5_encoder import NunchakuT5EncoderModel
from .transformers import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel

+__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model_and_transforms

+__all__ = ["create_model_and_transforms", "OPENAI_DATASET_MEAN", "OPENAI_DATASET_STD"]
@@ -14,8 +14,8 @@ try:
except ImportError:
    from timm.layers import drop_path, to_2tuple, trunc_normal_

-from .transformer import PatchDropout
from .rope import VisionRotaryEmbeddingFast
+from .transformer import PatchDropout

if os.getenv("ENV_TYPE") == "deepspeed":
    try:
@@ -26,7 +26,7 @@ else:
    from torch.utils.checkpoint import checkpoint

try:
-    import xformers
+    import xformers  # noqa: F401
    import xformers.ops as xops

    XFORMERS_IS_AVAILBLE = True
...
@@ -9,7 +9,7 @@ from typing import Optional, Tuple, Union
import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
-from .model import CLIP, convert_to_custom_text_state_dict, CustomCLIP, get_cast_dtype
+from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype
from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model
from .transform import image_transform
from .utils import resize_clip_pos_embed, resize_eva_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed
...
@@ -11,7 +11,7 @@ from torch import TensorType
try:
    import transformers
-    from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
+    from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, AutoTokenizer, PretrainedConfig
except ImportError:
    transformers = None
...
@@ -16,9 +16,9 @@ try:
    from .hf_model import HFTextEncoder
except ImportError:
    HFTextEncoder = None

-from .modified_resnet import ModifiedResNet
from .eva_vit_model import EVAVisionTransformer
-from .transformer import LayerNorm, QuickGELU, VisionTransformer, TextTransformer
+from .modified_resnet import ModifiedResNet
+from .transformer import LayerNorm, QuickGELU, TextTransformer, VisionTransformer

try:
    from apex.normalization import FusedLayerNorm
...
@@ -26,4 +26,4 @@
    "xattn": false,
    "fusedLN": true
  }
}
\ No newline at end of file
@@ -4,8 +4,6 @@ import torch
from torch import nn
from torch.nn import functional as F

-from .utils import freeze_batch_norm_2d

class Bottleneck(nn.Module):
    expansion = 4
...
@@ -2,7 +2,7 @@ import logging
import math
import os
from collections import OrderedDict
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional

import torch
from torch import nn
@@ -11,7 +11,7 @@ from torch.nn import functional as F
try:
    from timm.models.layers import trunc_normal_
except ImportError:
-    from timm.layers import trunc_normal_
+    from timm.layers import trunc_normal_  # noqa: F401

from .utils import to_2tuple
...
@@ -3,7 +3,6 @@ import logging
import math
from itertools import repeat

-import numpy as np
import torch
import torch.nn.functional as F
from torch import nn as nn
...
# Adapted from https://github.com/ToTheBeginning/PuLID
-import torch
+import logging
from typing import Any, Dict, Optional, Union

+import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
-import logging

logger = logging.getLogger(__name__)
...
@@ -4,8 +4,8 @@
import torch
import torch.nn as nn

-from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
from ..._C.ops import gemm_awq, gemv_awq
+from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight

__all__ = ["W4Linear"]
...
from .transformer_flux import NunchakuFluxTransformer2dModel
from .transformer_sana import NunchakuSanaTransformer2DModel

+__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel"]
@@ -12,11 +12,12 @@ from packaging.version import Version
from safetensors.torch import load_file
from torch import nn

-from .utils import NunchakuModelLoaderMixin, pad_tensor
-from ..._C import QuantizedFluxModel, utils as cutils
+from ..._C import QuantizedFluxModel
+from ..._C import utils as cutils
from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
from ...lora.flux.utils import is_nunchaku_format
from ...utils import get_precision, load_state_dict_in_safetensors
+from .utils import NunchakuModelLoaderMixin, pad_tensor

SVD_RANK = 32
@@ -77,7 +78,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        self.id_embeddings = id_embeddings
        self.id_weight = id_weight
        self.pulid_ca_idx = 0
-        if self.id_embeddings is not None :
+        if self.id_embeddings is not None:
            self.set_residual_callback()

        original_dtype = hidden_states.dtype
@@ -122,13 +123,12 @@ class NunchakuFluxTransformerBlocks(nn.Module):
            rotary_emb_single,
            controlnet_block_samples,
            controlnet_single_block_samples,
-            skip_first_layer
+            skip_first_layer,
        )
-        if self.id_embeddings is not None :
+        if self.id_embeddings is not None:
            self.reset_residual_callback()

        hidden_states = hidden_states.to(original_dtype).to(original_device)

        encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
@@ -191,20 +191,25 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)
        return encoder_hidden_states, hidden_states

    def set_residual_callback(self):
        id_embeddings = self.id_embeddings
        pulid_ca = self.pulid_ca
        pulid_ca_idx = [self.pulid_ca_idx]
        id_weight = self.id_weight

        def callback(hidden_states):
            ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states.to("cuda"))
            pulid_ca_idx[0] += 1
            return ip

        self.callback_holder = callback
        self.m.set_residual_callback(callback)

    def reset_residual_callback(self):
        self.callback_holder = None
        self.m.set_residual_callback(None)

    def __del__(self):
        self.m.reset()
@@ -477,10 +482,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
        if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
            self._unquantized_part_loras = unquantized_part_loras

-            self._unquantized_part_sd = {
-                k: v for k, v in self._unquantized_part_sd.items()
-                if "pulid_ca" not in k
-            }
+            self._unquantized_part_sd = {k: v for k, v in self._unquantized_part_sd.items() if "pulid_ca" not in k}
            self._update_unquantized_part_lora_params(1)

        quantized_part_vectors = {}
...
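
One detail worth noting in `set_residual_callback` above: `pulid_ca_idx` is wrapped in a one-element list so the nested `callback` can advance the index by mutating the cell rather than rebinding a local (which would require `nonlocal`), and that state survives repeated invocations from the C++ side. A self-contained toy illustration (not repo code):

```python
def make_counter():
    idx = [0]  # mutable cell shared with the closure, like pulid_ca_idx

    def callback(x):
        idx[0] += 1  # mutate the cell; a plain `idx += 1` would need nonlocal
        return x + idx[0]

    return callback

cb = make_counter()
assert cb(10) == 11
assert cb(10) == 12  # state persists across calls
```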
@@ -8,9 +8,10 @@ from safetensors.torch import load_file
from torch import nn
from torch.nn import functional as F

-from .utils import NunchakuModelLoaderMixin
+from ..._C import QuantizedSanaModel
+from ..._C import utils as cutils
from ...utils import get_precision
-from ..._C import QuantizedSanaModel, utils as cutils
+from .utils import NunchakuModelLoaderMixin

SVD_RANK = 32
@@ -130,9 +131,11 @@ class NunchakuSanaTransformerBlocks(nn.Module):
            .to(original_dtype)
            .to(original_device)
        )

    def __del__(self):
        self.m.reset()


class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
    @classmethod
    @utils.validate_hf_hub_args
...