Unverified Commit 57e50f8d authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
......@@ -11,13 +11,13 @@ public:
void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedSanaModel on device {}", deviceId);
SanaConfig cfg{
.num_layers = config["num_layers"].cast<int>(),
.num_attention_heads = config["num_attention_heads"].cast<int>(),
.attention_head_dim = config["attention_head_dim"].cast<int>(),
.num_layers = config["num_layers"].cast<int>(),
.num_attention_heads = config["num_attention_heads"].cast<int>(),
.attention_head_dim = config["attention_head_dim"].cast<int>(),
.num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
.expand_ratio = config["mlp_ratio"].cast<double>(),
.pag_layers = pag_layers,
.use_fp4 = use_fp4,
.expand_ratio = config["mlp_ratio"].cast<double>(),
.pag_layers = pag_layers,
.use_fp4 = use_fp4,
};
ModuleWrapper::init(deviceId);
......@@ -25,39 +25,37 @@ public:
net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg,
bool skip_first_layer = false)
{
torch::Tensor forward(torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg,
bool skip_first_layer = false) {
checkModel();
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedSanaModel forward");
hidden_states = hidden_states.contiguous();
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
timestep = timestep.contiguous();
cu_seqlens_img = cu_seqlens_img.contiguous();
cu_seqlens_txt = cu_seqlens_txt.contiguous();
timestep = timestep.contiguous();
cu_seqlens_img = cu_seqlens_img.contiguous();
cu_seqlens_txt = cu_seqlens_txt.contiguous();
Tensor result = net->forward(
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(timestep),
from_torch(cu_seqlens_img),
from_torch(cu_seqlens_txt),
H, W,
pag, cfg,
skip_first_layer
);
Tensor result = net->forward(from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(timestep),
from_torch(cu_seqlens_img),
from_torch(cu_seqlens_txt),
H,
W,
pag,
cfg,
skip_first_layer);
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
......@@ -65,42 +63,40 @@ public:
return output;
}
torch::Tensor forward_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg)
{
torch::Tensor forward_layer(int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor timestep,
torch::Tensor cu_seqlens_img,
torch::Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg) {
checkModel();
CUDADeviceContext ctx(deviceId);
spdlog::debug("QuantizedSanaModel forward_layer {}", idx);
hidden_states = hidden_states.contiguous();
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
timestep = timestep.contiguous();
cu_seqlens_img = cu_seqlens_img.contiguous();
cu_seqlens_txt = cu_seqlens_txt.contiguous();
timestep = timestep.contiguous();
cu_seqlens_img = cu_seqlens_img.contiguous();
cu_seqlens_txt = cu_seqlens_txt.contiguous();
Tensor result = net->transformer_blocks.at(idx)->forward(
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(timestep),
from_torch(cu_seqlens_img),
from_torch(cu_seqlens_txt),
H, W,
pag, cfg
);
Tensor result = net->transformer_blocks.at(idx)->forward(from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(timestep),
from_torch(cu_seqlens_img),
from_torch(cu_seqlens_txt),
H,
W,
pag,
cfg);
torch::Tensor output = to_torch(result);
// Tensor::synchronizeDevice();
return output;
}
};
\ No newline at end of file
};
......@@ -6,34 +6,34 @@
namespace nunchaku::utils {
/// Applies `newval` as the cudaLimitStackSize device limit and logs the value
/// that the runtime reports afterwards.
///
/// @param newval Requested limit. It is converted to size_t without any range
///               check before being passed to cudaDeviceSetLimit.
void set_cuda_stack_limit(int64_t newval) {
    const auto requested = static_cast<size_t>(newval);
    checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, requested));

    // Read the limit back so the log shows what is in effect, not just what was asked for.
    size_t effective = 0;
    checkCUDA(cudaDeviceGetLimit(&effective, cudaLimitStackSize));
    spdlog::debug("Stack={}", effective);
}
/// Applies `newval` as the cudaLimitStackSize device limit and logs the value
/// that the runtime reports afterwards.
///
/// @param newval Requested limit. It is converted to size_t without any range
///               check before being passed to cudaDeviceSetLimit.
void set_cuda_stack_limit(int64_t newval) {
    const auto requested = static_cast<size_t>(newval);
    checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, requested));

    // Read the limit back so the log shows what is in effect, not just what was asked for.
    size_t effective = 0;
    checkCUDA(cudaDeviceGetLimit(&effective, cudaLimitStackSize));
    spdlog::debug("Stack={}", effective);
}
void disable_memory_auto_release() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
}
void disable_memory_auto_release() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
}
void trim_memory() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
size_t bytesToKeep = 0;
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}
void trim_memory() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
size_t bytesToKeep = 0;
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}
/// Logs the requested mode at info level, then forwards it to
/// kernels::set_faster_i2f_mode.
///
/// @param mode Mode name; it is passed through unchanged and not validated here.
void set_faster_i2f_mode(std::string mode) {
    // Log before forwarding so the request is recorded even if the kernel-side setter rejects it.
    spdlog::info("Set fasteri2f mode to {}", mode);
    kernels::set_faster_i2f_mode(mode);
}
/// Logs the requested mode at info level, then forwards it to
/// kernels::set_faster_i2f_mode.
///
/// @param mode Mode name; it is passed through unchanged and not validated here.
void set_faster_i2f_mode(std::string mode) {
    // Log before forwarding so the request is recorded even if the kernel-side setter rejects it.
    spdlog::info("Set fasteri2f mode to {}", mode);
    kernels::set_faster_i2f_mode(mode);
}
};
\ No newline at end of file
}; // namespace nunchaku::utils
from .diffusers_converter import to_diffusers
from .nunchaku_converter import convert_to_nunchaku_flux_lowrank_dict, to_nunchaku
from .utils import is_nunchaku_format
__all__ = ["to_diffusers", "to_nunchaku", "convert_to_nunchaku_flux_lowrank_dict", "is_nunchaku_format"]
......@@ -7,10 +7,10 @@ import torch
from safetensors.torch import save_file
from tqdm import tqdm
from ...utils import filter_state_dict, load_state_dict_in_safetensors
from .diffusers_converter import to_diffusers
from .packer import NunchakuWeightPacker
from .utils import is_nunchaku_format, pad
from ...utils import filter_state_dict, load_state_dict_in_safetensors
logger = logging.getLogger(__name__)
......
# Copy the packer from https://github.com/mit-han-lab/deepcompressor/
import torch
from .utils import pad
from ...utils import ceil_divide
from .utils import pad
class MmaWeightPackerBase:
......
from .text_encoders.t5_encoder import NunchakuT5EncoderModel
from .transformers import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model_and_transforms
__all__ = ["create_model_and_transforms", "OPENAI_DATASET_MEAN", "OPENAI_DATASET_STD"]
......@@ -14,8 +14,8 @@ try:
except ImportError:
from timm.layers import drop_path, to_2tuple, trunc_normal_
from .transformer import PatchDropout
from .rope import VisionRotaryEmbeddingFast
from .transformer import PatchDropout
if os.getenv("ENV_TYPE") == "deepspeed":
try:
......@@ -26,7 +26,7 @@ else:
from torch.utils.checkpoint import checkpoint
try:
import xformers
import xformers # noqa: F401
import xformers.ops as xops
XFORMERS_IS_AVAILBLE = True
......
......@@ -9,7 +9,7 @@ from typing import Optional, Tuple, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, convert_to_custom_text_state_dict, CustomCLIP, get_cast_dtype
from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype
from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model
from .transform import image_transform
from .utils import resize_clip_pos_embed, resize_eva_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed
......
......@@ -11,7 +11,7 @@ from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, AutoTokenizer, PretrainedConfig
except ImportError:
transformers = None
......
......@@ -16,9 +16,9 @@ try:
from .hf_model import HFTextEncoder
except ImportError:
HFTextEncoder = None
from .modified_resnet import ModifiedResNet
from .eva_vit_model import EVAVisionTransformer
from .transformer import LayerNorm, QuickGELU, VisionTransformer, TextTransformer
from .modified_resnet import ModifiedResNet
from .transformer import LayerNorm, QuickGELU, TextTransformer, VisionTransformer
try:
from apex.normalization import FusedLayerNorm
......
......@@ -26,4 +26,4 @@
"xattn": false,
"fusedLN": true
}
}
\ No newline at end of file
}
......@@ -4,8 +4,6 @@ import torch
from torch import nn
from torch.nn import functional as F
from .utils import freeze_batch_norm_2d
class Bottleneck(nn.Module):
expansion = 4
......
......@@ -2,7 +2,7 @@ import logging
import math
import os
from collections import OrderedDict
from typing import Callable, Optional, Sequence
from typing import Callable, Optional
import torch
from torch import nn
......@@ -11,7 +11,7 @@ from torch.nn import functional as F
try:
from timm.models.layers import trunc_normal_
except ImportError:
from timm.layers import trunc_normal_
from timm.layers import trunc_normal_ # noqa: F401
from .utils import to_2tuple
......
......@@ -3,7 +3,6 @@ import logging
import math
from itertools import repeat
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn as nn
......
# Adapted from https://github.com/ToTheBeginning/PuLID
import torch
import logging
from typing import Any, Dict, Optional, Union
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
import logging
logger = logging.getLogger(__name__)
......
......@@ -4,8 +4,8 @@
import torch
import torch.nn as nn
from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
from ..._C.ops import gemm_awq, gemv_awq
from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
__all__ = ["W4Linear"]
......
from .transformer_flux import NunchakuFluxTransformer2dModel
from .transformer_sana import NunchakuSanaTransformer2DModel
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel"]
......@@ -12,11 +12,12 @@ from packaging.version import Version
from safetensors.torch import load_file
from torch import nn
from .utils import NunchakuModelLoaderMixin, pad_tensor
from ..._C import QuantizedFluxModel, utils as cutils
from ..._C import QuantizedFluxModel
from ..._C import utils as cutils
from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
from ...lora.flux.utils import is_nunchaku_format
from ...utils import get_precision, load_state_dict_in_safetensors
from .utils import NunchakuModelLoaderMixin, pad_tensor
SVD_RANK = 32
......@@ -77,7 +78,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
self.id_embeddings = id_embeddings
self.id_weight = id_weight
self.pulid_ca_idx = 0
if self.id_embeddings is not None :
if self.id_embeddings is not None:
self.set_residual_callback()
original_dtype = hidden_states.dtype
......@@ -122,13 +123,12 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_single,
controlnet_block_samples,
controlnet_single_block_samples,
skip_first_layer
skip_first_layer,
)
if self.id_embeddings is not None :
if self.id_embeddings is not None:
self.reset_residual_callback()
hidden_states = hidden_states.to(original_dtype).to(original_device)
encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
......@@ -191,20 +191,25 @@ class NunchakuFluxTransformerBlocks(nn.Module):
encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)
return encoder_hidden_states, hidden_states
def set_residual_callback(self):
    """Build a residual callback from the current PuLID state and register it on ``self.m``.

    The values of ``id_embeddings``, ``pulid_ca``, ``pulid_ca_idx`` and
    ``id_weight`` are captured when this method runs. The callback keeps its
    own running index, starting from the captured ``pulid_ca_idx``;
    ``self.pulid_ca_idx`` itself is not modified by the callback.

    The callback is also stored on ``self.callback_holder`` so that a Python
    reference to it exists while it is registered.
    """
    embeddings = self.id_embeddings
    ca_modules = self.pulid_ca
    next_idx = self.pulid_ca_idx
    weight = self.id_weight

    def callback(hidden_states):
        nonlocal next_idx
        ca_module = ca_modules[next_idx]
        residual = weight * ca_module(embeddings, hidden_states.to("cuda"))
        # Advance only after the call succeeded, so the next call uses the next module.
        next_idx += 1
        return residual

    self.callback_holder = callback
    self.m.set_residual_callback(callback)
def reset_residual_callback(self):
    """Unregister the residual callback.

    Clears the reference held in ``self.callback_holder`` and passes ``None``
    to ``self.m.set_residual_callback``.
    """
    self.callback_holder = None
    self.m.set_residual_callback(None)
def __del__(self):
self.m.reset()
......@@ -477,10 +482,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
self._unquantized_part_loras = unquantized_part_loras
self._unquantized_part_sd = {
k: v for k, v in self._unquantized_part_sd.items()
if "pulid_ca" not in k
}
self._unquantized_part_sd = {k: v for k, v in self._unquantized_part_sd.items() if "pulid_ca" not in k}
self._update_unquantized_part_lora_params(1)
quantized_part_vectors = {}
......
......@@ -8,9 +8,10 @@ from safetensors.torch import load_file
from torch import nn
from torch.nn import functional as F
from .utils import NunchakuModelLoaderMixin
from ..._C import QuantizedSanaModel
from ..._C import utils as cutils
from ...utils import get_precision
from ..._C import QuantizedSanaModel, utils as cutils
from .utils import NunchakuModelLoaderMixin
SVD_RANK = 32
......@@ -130,9 +131,11 @@ class NunchakuSanaTransformerBlocks(nn.Module):
.to(original_dtype)
.to(original_device)
)
def __del__(self):
self.m.reset()
class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
@classmethod
@utils.validate_hf_hub_args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment