Commit a50bcc53 authored by Dongz's avatar Dongz Committed by Yang Yong(雍洋)
Browse files

add lint feature and minor fix (#7)

* [minor]: optimize dockerfile for fewer layer

* [feature]: add pre-commit lint, update readme for contribution guidance

* [minor]: fix run shell privileges

* [auto]: first lint without rule F, fix rule E

* [minor]: fix docker file error
parent 3b460075
......@@ -39,28 +39,15 @@ def init_weights(m):
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
elif isinstance(m, T5RelativeEmbedding):
nn.init.normal_(
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
)
nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
class GELU(nn.Module):
def forward(self, x):
return (
0.5
* x
* (
1.0
+ torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
)
)
)
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super(T5LayerNorm, self).__init__()
self.dim = dim
......@@ -75,7 +62,6 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
......@@ -128,7 +114,6 @@ class T5Attention(nn.Module):
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
super(T5FeedForward, self).__init__()
self.dim = dim
......@@ -149,7 +134,6 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
def __init__(
self,
dim,
......@@ -173,11 +157,7 @@ class T5SelfAttention(nn.Module):
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
......@@ -187,7 +167,6 @@ class T5SelfAttention(nn.Module):
class T5CrossAttention(nn.Module):
def __init__(
self,
dim,
......@@ -213,27 +192,17 @@ class T5CrossAttention(nn.Module):
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
def forward(
self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None
):
def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(
x
+ self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
)
x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
......@@ -248,9 +217,7 @@ class T5RelativeEmbedding(nn.Module):
device = self.embedding.weight.device
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
lq, device=device
).unsqueeze(1)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
rel_pos = self._relative_position_bucket(rel_pos)
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
......@@ -269,23 +236,13 @@ class T5RelativeEmbedding(nn.Module):
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = (
max_exact
+ (
torch.log(rel_pos.float() / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).long()
)
rel_pos_large = torch.min(
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
)
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long()
rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(
self,
vocab,
......@@ -308,23 +265,10 @@ class T5Encoder(nn.Module):
self.shared_pos = shared_pos
# layers
self.token_embedding = (
vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
if shared_pos
else None
)
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5SelfAttention(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
)
for _ in range(num_layers)
]
)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
# initialize weights
......@@ -342,7 +286,6 @@ class T5Encoder(nn.Module):
class T5Decoder(nn.Module):
def __init__(
self,
vocab,
......@@ -365,23 +308,10 @@ class T5Decoder(nn.Module):
self.shared_pos = shared_pos
# layers
self.token_embedding = (
vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
if shared_pos
else None
)
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5CrossAttention(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
)
for _ in range(num_layers)
]
)
self.blocks = nn.ModuleList([T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
# initialize weights
......@@ -408,7 +338,6 @@ class T5Decoder(nn.Module):
class T5Model(nn.Module):
def __init__(
self,
vocab_size,
......@@ -530,7 +459,6 @@ def umt5_xxl(**kwargs):
class T5EncoderModel:
def __init__(
self,
text_len,
......@@ -547,13 +475,7 @@ class T5EncoderModel:
self.tokenizer_path = tokenizer_path
# init model
model = (
umt5_xxl(
encoder_only=True, return_tokenizer=False, dtype=dtype, device=device
)
.eval()
.requires_grad_(False)
)
model = umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device).eval().requires_grad_(False)
logging.info(f"loading {checkpoint_path}")
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
self.model = model
......@@ -562,9 +484,7 @@ class T5EncoderModel:
else:
self.model.to(self.device)
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path, seq_len=text_len, clean="whitespace"
)
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
def to_cpu(self):
self.model = self.model.to("cpu")
......
......@@ -24,10 +24,7 @@ def whitespace_clean(text):
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans("", "", string.punctuation))
for part in text.split(keep_punctuation_exact_string)
)
text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans("", "", string.punctuation))
text = text.lower()
......@@ -36,7 +33,6 @@ def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, "whitespace", "lower", "canonicalize")
self.name = name
......
......@@ -123,11 +123,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
......@@ -204,9 +200,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
r"""
Sets the attention processor to use to compute attention.
......@@ -250,16 +244,12 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
raise ValueError(f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}")
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images/videos into latents.
......@@ -312,9 +302,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images/videos.
......@@ -386,7 +374,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
for i in range(0, x.shape[-2], overlap_size):
row = []
for j in range(0, x.shape[-1], overlap_size):
tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
......@@ -438,7 +426,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
for i in range(0, z.shape[-2], overlap_size):
row = []
for j in range(0, z.shape[-1], overlap_size):
tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
......@@ -463,7 +451,6 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return DecoderOutput(sample=dec)
def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
B, C, T, H, W = x.shape
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
......@@ -472,7 +459,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
# Split the video into tiles and encode them separately.
row = []
for i in range(0, T, overlap_size):
tile = x[:, :, i: i + self.tile_sample_min_tsize + 1, :, :]
tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
tile = self.spatial_tiled_encode(tile, return_moments=True)
else:
......@@ -487,7 +474,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, :t_limit + 1, :, :])
result_row.append(tile[:, :, : t_limit + 1, :, :])
moments = torch.cat(result_row, dim=2)
posterior = DiagonalGaussianDistribution(moments)
......@@ -507,7 +494,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
row = []
for i in range(0, T, overlap_size):
tile = z[:, :, i: i + self.tile_latent_min_tsize + 1, :, :]
tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
else:
......@@ -522,7 +509,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, :t_limit + 1, :, :])
result_row.append(tile[:, :, : t_limit + 1, :, :])
dec = torch.cat(result_row, dim=2)
if not return_dict:
......
......@@ -3,7 +3,7 @@ import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
class VideoEncoderKLCausal3DModel():
class VideoEncoderKLCausal3DModel:
def __init__(self, model_path, dtype, device):
self.model_path = model_path
self.dtype = dtype
......@@ -11,10 +11,10 @@ class VideoEncoderKLCausal3DModel():
self.load()
def load(self):
self.vae_path = os.path.join(self.model_path, 'hunyuan-video-t2v-720p/vae')
self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
config = AutoencoderKLCausal3D.load_config(self.vae_path)
self.model = AutoencoderKLCausal3D.from_config(config)
ckpt = torch.load(os.path.join(self.vae_path, 'pytorch_model.pt'), map_location='cpu', weights_only=True)
ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
self.model.load_state_dict(ckpt)
self.model = self.model.to(dtype=self.dtype, device=self.device)
self.model.requires_grad_(False)
......@@ -32,14 +32,13 @@ class VideoEncoderKLCausal3DModel():
latents = latents / self.model.config.scaling_factor
latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
self.model.enable_tiling()
image = self.model.decode(
latents, return_dict=False, generator=generator
)[0]
image = self.model.decode(latents, return_dict=False, generator=generator)[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().float()
if args.cpu_offload:
self.to_cpu()
return image
if __name__ == "__main__":
vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme0/yongyang/projects/hy/new/HunyuanVideo/ckpts", dtype=torch.float16, device=torch.device("cuda"))
......@@ -53,16 +53,15 @@ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_
idx_arr = idx_arr > torch.zeros_like(idx_arr)
for i in range(n_frame):
for j in range(n_frame):
if idx_arr[i,j]:
if idx_arr[i, j]:
mask[i, j] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
# mask[idx_arr] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
mask = mask.view(n_frame, -1, n_hw).transpose(1, 0).reshape(seq_len, -1).transpose(1,0)
mask = mask.view(n_frame, -1, n_hw).transpose(1, 0).reshape(seq_len, -1).transpose(1, 0)
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask.to(device)
class CausalConv3d(nn.Module):
"""
Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
......@@ -76,8 +75,8 @@ class CausalConv3d(nn.Module):
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
pad_mode='replicate',
**kwargs
pad_mode="replicate",
**kwargs,
):
super().__init__()
......@@ -238,9 +237,7 @@ class DownsampleCausal3D(nn.Module):
raise ValueError(f"unknown norm_type: {norm_type}")
if use_conv:
conv = CausalConv3d(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
)
conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
else:
raise NotImplementedError
......@@ -384,28 +381,18 @@ class ResnetBlockCausal3D(nn.Module):
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = (
self.upsample(input_tensor, scale=scale)
)
hidden_states = (
self.upsample(hidden_states, scale=scale)
)
input_tensor = self.upsample(input_tensor, scale=scale)
hidden_states = self.upsample(hidden_states, scale=scale)
elif self.downsample is not None:
input_tensor = (
self.downsample(input_tensor, scale=scale)
)
hidden_states = (
self.downsample(hidden_states, scale=scale)
)
input_tensor = self.downsample(input_tensor, scale=scale)
hidden_states = self.downsample(hidden_states, scale=scale)
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = (
self.time_emb_proj(temb, scale)[:, :, None, None]
)
temb = self.time_emb_proj(temb, scale)[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
......@@ -425,9 +412,7 @@ class ResnetBlockCausal3D(nn.Module):
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = (
self.conv_shortcut(input_tensor)
)
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
......@@ -464,9 +449,7 @@ def get_down_block3d(
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
attention_head_dim = num_attention_heads
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
......@@ -518,9 +501,7 @@ def get_up_block3d(
) -> nn.Module:
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
attention_head_dim = num_attention_heads
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
......@@ -588,9 +569,7 @@ class UNetMidBlockCausal3D(nn.Module):
attentions = []
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
logger.warn(f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}.")
attention_head_dim = in_channels
for _ in range(num_layers):
......@@ -637,9 +616,7 @@ class UNetMidBlockCausal3D(nn.Module):
if attn is not None:
B, C, T, H, W = hidden_states.shape
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
attention_mask = prepare_causal_attention_mask(
T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
)
attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
hidden_states = resnet(hidden_states, temb)
......@@ -770,9 +747,7 @@ class UpDecoderBlockCausal3D(nn.Module):
self.resolution_idx = resolution_idx
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> torch.FloatTensor:
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
......
......@@ -66,10 +66,7 @@ class EncoderCausal3D(nn.Module):
if time_compression_ratio == 4:
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(
i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
and not is_final_block
)
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
else:
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
......@@ -186,10 +183,7 @@ class DecoderCausal3D(nn.Module):
if time_compression_ratio == 4:
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
add_time_upsample = bool(
i >= len(block_out_channels) - 1 - num_time_upsample_layers
and not is_final_block
)
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
else:
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
......@@ -263,9 +257,7 @@ class DecoderCausal3D(nn.Module):
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
......@@ -306,9 +298,7 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
......@@ -333,11 +323,7 @@ class DiagonalGaussianDistribution(object):
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=reduce_dim,
)
......@@ -346,8 +332,7 @@ class DiagonalGaussianDistribution(object):
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar +
torch.pow(sample - self.mean, 2) / self.var,
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
......
......@@ -44,7 +44,6 @@ class CausalConv3d(nn.Conv3d):
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
......@@ -56,16 +55,10 @@ class RMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return (
F.normalize(x, dim=(1 if self.channel_first else -1))
* self.scale
* self.gamma
+ self.bias
)
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
......@@ -74,7 +67,6 @@ class Upsample(nn.Upsample):
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"none",
......@@ -101,16 +93,10 @@ class Resample(nn.Module):
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
......@@ -124,28 +110,17 @@ class Resample(nn.Module):
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] != "Rep"
):
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] == "Rep"
):
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat(
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],
dim=2,
......@@ -172,15 +147,12 @@ class Resample(nn.Module):
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
)
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
......@@ -210,7 +182,6 @@ class Resample(nn.Module):
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
......@@ -226,9 +197,7 @@ class ResidualBlock(nn.Module):
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
......@@ -240,9 +209,7 @@ class ResidualBlock(nn.Module):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
......@@ -278,13 +245,7 @@ class AttentionBlock(nn.Module):
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x)
# compute query, key, value
q, k, v = (
self.to_qkv(x)
.reshape(b * t, 1, c * 3, -1)
.permute(0, 1, 3, 2)
.contiguous()
.chunk(3, dim=-1)
)
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
......@@ -301,7 +262,6 @@ class AttentionBlock(nn.Module):
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
......@@ -400,9 +360,7 @@ class Encoder3d(nn.Module):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
......@@ -416,7 +374,6 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
......@@ -518,9 +475,7 @@ class Decoder3d(nn.Module):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
......@@ -542,7 +497,6 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
def __init__(
self,
dim=128,
......@@ -613,9 +567,7 @@ class WanVAE_(nn.Module):
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1
)
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
......@@ -625,9 +577,7 @@ class WanVAE_(nn.Module):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1
)
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
......@@ -700,7 +650,6 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
class WanVAE:
def __init__(
self,
z_dim=16,
......@@ -780,10 +729,7 @@ class WanVAE:
"""
videos: A list of videos each with shape [C, T, H, W].
"""
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
def decode_dist(self, zs, world_size, cur_rank, split_dim):
splited_total_len = zs.shape[split_dim]
......@@ -792,37 +738,37 @@ class WanVAE:
if cur_rank == 0:
if split_dim == 2:
zs = zs[:,:,:splited_chunk_len+2*padding_size,:].contiguous()
zs = zs[:, :, : splited_chunk_len + 2 * padding_size, :].contiguous()
elif split_dim == 3:
zs = zs[:,:,:,:splited_chunk_len+2*padding_size].contiguous()
elif cur_rank == world_size-1:
zs = zs[:, :, :, : splited_chunk_len + 2 * padding_size].contiguous()
elif cur_rank == world_size - 1:
if split_dim == 2:
zs = zs[:,:,-(splited_chunk_len+2*padding_size):,:].contiguous()
zs = zs[:, :, -(splited_chunk_len + 2 * padding_size) :, :].contiguous()
elif split_dim == 3:
zs = zs[:,:,:,-(splited_chunk_len+2*padding_size):].contiguous()
zs = zs[:, :, :, -(splited_chunk_len + 2 * padding_size) :].contiguous()
else:
if split_dim == 2:
zs = zs[:,:,cur_rank*splited_chunk_len-padding_size:(cur_rank+1)*splited_chunk_len+padding_size,:].contiguous()
zs = zs[:, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size, :].contiguous()
elif split_dim == 3:
zs = zs[:,:,:,cur_rank*splited_chunk_len-padding_size:(cur_rank+1)*splited_chunk_len+padding_size].contiguous()
zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous()
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
if cur_rank == 0:
if split_dim == 2:
images = images[:,:,:,:splited_chunk_len*8,:].contiguous()
images = images[:, :, :, : splited_chunk_len * 8, :].contiguous()
elif split_dim == 3:
images = images[:,:,:,:,:splited_chunk_len*8].contiguous()
elif cur_rank == world_size-1:
images = images[:, :, :, :, : splited_chunk_len * 8].contiguous()
elif cur_rank == world_size - 1:
if split_dim == 2:
images = images[:,:,:,-splited_chunk_len*8:,:].contiguous()
images = images[:, :, :, -splited_chunk_len * 8 :, :].contiguous()
elif split_dim == 3:
images = images[:,:,:,:,-splited_chunk_len*8:].contiguous()
images = images[:, :, :, :, -splited_chunk_len * 8 :].contiguous()
else:
if split_dim == 2:
images = images[:,:,:,8*padding_size:-8*padding_size,:].contiguous()
images = images[:, :, :, 8 * padding_size : -8 * padding_size, :].contiguous()
elif split_dim == 3:
images = images[:,:,:,:,8*padding_size:-8*padding_size].contiguous()
images = images[:, :, :, :, 8 * padding_size : -8 * padding_size].contiguous()
full_images = [torch.empty_like(images) for _ in range(world_size)]
dist.all_gather(full_images, images)
......@@ -833,7 +779,6 @@ class WanVAE:
return images
def decode(self, zs, generator, args):
if args.cpu_offload:
self.to_cuda()
......
......@@ -5,7 +5,7 @@ from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.autoencod
from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d import trt_vae_infer
class VideoEncoderKLCausal3DModel():
class VideoEncoderKLCausal3DModel:
def __init__(self, model_path, dtype, device):
self.model_path = model_path
self.dtype = dtype
......@@ -13,10 +13,10 @@ class VideoEncoderKLCausal3DModel():
self.load()
def load(self):
self.vae_path = os.path.join(self.model_path, 'hunyuan-video-t2v-720p/vae')
self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
config = AutoencoderKLCausal3D.load_config(self.vae_path)
self.model = AutoencoderKLCausal3D.from_config(config)
ckpt = torch.load(os.path.join(self.vae_path, 'pytorch_model.pt'), map_location='cpu', weights_only=True)
ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
self.model.load_state_dict(ckpt)
self.model = self.model.to(dtype=self.dtype, device=self.device)
self.model.requires_grad_(False)
......@@ -28,12 +28,11 @@ class VideoEncoderKLCausal3DModel():
latents = latents / self.model.config.scaling_factor
latents = latents.to(dtype=self.dtype, device=self.device)
self.model.enable_tiling()
image = self.model.decode(
latents, return_dict=False, generator=generator
)[0]
image = self.model.decode(latents, return_dict=False, generator=generator)[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().float()
return image
if __name__ == "__main__":
vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme1/yongyang/models/hy/ckpts", dtype=torch.float16, device=torch.device("cuda"))
......@@ -100,18 +100,18 @@ class HyVaeTrtModelInfer(nn.Module):
device = batch.device
dtype = batch.dtype
batch = batch.cpu().numpy()
def get_output_shape(shp):
b, c, t, h, w = shp
out = (b, 3, 4*(t-1)+1, h*8, w*8)
out = (b, 3, 4 * (t - 1) + 1, h * 8, w * 8)
return out
shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)}
self.alloc(shp_dict)
output = np.zeros(*self.output_spec())
# Process I/O and execute the network
common.memcpy_host_to_device(
self.inputs[0]["allocation"], np.ascontiguousarray(batch)
)
common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch))
self.context.execute_v2(self.allocations)
common.memcpy_device_to_host(output, self.outputs[0]["allocation"])
output = torch.from_numpy(output).to(device).type(dtype)
......@@ -122,18 +122,15 @@ class HyVaeTrtModelInfer(nn.Module):
logger.info("Start to do VAE onnx exporting.")
device = next(decoder.parameters())[0].device
example_inp = torch.rand(1, 16, 17, 32, 32).to(device).type(next(decoder.parameters())[0].dtype)
out_path = str(Path(str(model_dir))/"vae_decoder.onnx")
out_path = str(Path(str(model_dir)) / "vae_decoder.onnx")
torch.onnx.export(
decoder.eval().half(),
example_inp.half(),
out_path,
input_names=['inp'],
output_names=['out'],
input_names=["inp"],
output_names=["out"],
opset_version=14,
dynamic_axes={
"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"},
"out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}
}
dynamic_axes={"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}, "out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}},
)
# onnx_ori = onnx.load(out_path)
os.system(f"onnxsim {out_path} {out_path}")
......
......@@ -8,20 +8,20 @@ class BaseQuantizer(object):
self.sym = symmetric
self.granularity = granularity
self.kwargs = kwargs
if self.granularity == 'per_group':
self.group_size = self.kwargs['group_size']
self.calib_algo = self.kwargs.get('calib_algo', 'minmax')
if self.granularity == "per_group":
self.group_size = self.kwargs["group_size"]
self.calib_algo = self.kwargs.get("calib_algo", "minmax")
def get_tensor_range(self, tensor):
if self.calib_algo == 'minmax':
if self.calib_algo == "minmax":
return self.get_minmax_range(tensor)
elif self.calib_algo == 'mse':
elif self.calib_algo == "mse":
return self.get_mse_range(tensor)
else:
raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')
raise ValueError(f"Unsupported calibration algorithm: {self.calib_algo}")
def get_minmax_range(self, tensor):
if self.granularity == 'per_tensor':
if self.granularity == "per_tensor":
max_val = torch.max(tensor)
min_val = torch.min(tensor)
else:
......@@ -47,7 +47,7 @@ class BaseQuantizer(object):
return scales, zeros, qmax, qmin
def reshape_tensor(self, tensor, allow_padding=False):
if self.granularity == 'per_group':
if self.granularity == "per_group":
t = tensor.reshape(-1, self.group_size)
else:
t = tensor
......@@ -79,7 +79,7 @@ class BaseQuantizer(object):
tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor)
tensor = self.quant(tensor, scales, zeros, qmax, qmin)
tensor = self.restore_tensor(tensor, org_shape)
if self.sym == True:
if self.sym:
zeros = None
return tensor, scales, zeros
......@@ -87,9 +87,9 @@ class BaseQuantizer(object):
class IntegerQuantizer(BaseQuantizer):
def __init__(self, bit, symmetric, granularity, **kwargs):
super().__init__(bit, symmetric, granularity, **kwargs)
if 'int_range' in self.kwargs:
self.qmin = self.kwargs['int_range'][0]
self.qmax = self.kwargs['int_range'][1]
if "int_range" in self.kwargs:
self.qmin = self.kwargs["int_range"][0]
self.qmax = self.kwargs["int_range"][1]
else:
if self.sym:
self.qmin = -(2 ** (self.bit - 1))
......@@ -110,7 +110,14 @@ class IntegerQuantizer(BaseQuantizer):
tensor = (tensor - zeros) * scales
return tensor
def quant_dequant(self, tensor, scales, zeros, qmax, qmin,):
def quant_dequant(
self,
tensor,
scales,
zeros,
qmax,
qmin,
):
tensor = self.quant(tensor, scales, zeros, qmax, qmin)
tensor = self.dequant(tensor, scales, zeros)
return tensor
......@@ -119,19 +126,19 @@ class IntegerQuantizer(BaseQuantizer):
class FloatQuantizer(BaseQuantizer):
def __init__(self, bit, symmetric, granularity, **kwargs):
super().__init__(bit, symmetric, granularity, **kwargs)
assert self.bit in ['e4m3', 'e5m2'], f'Unsupported bit configuration: {self.bit}'
assert self.sym == True
assert self.bit in ["e4m3", "e5m2"], f"Unsupported bit configuration: {self.bit}"
assert self.sym
if self.bit == 'e4m3':
if self.bit == "e4m3":
self.e_bits = 4
self.m_bits = 3
self.fp_dtype = torch.float8_e4m3fn
elif self.bit == 'e5m2':
elif self.bit == "e5m2":
self.e_bits = 5
self.m_bits = 2
self.fp_dtype = torch.float8_e5m2
else:
raise ValueError(f'Unsupported bit configuration: {self.bit}')
raise ValueError(f"Unsupported bit configuration: {self.bit}")
finfo = torch.finfo(self.fp_dtype)
self.qmin, self.qmax = finfo.min, finfo.max
......@@ -141,13 +148,9 @@ class FloatQuantizer(BaseQuantizer):
def quant(self, tensor, scales, zeros, qmax, qmin):
scaled_tensor = tensor / scales + zeros
scaled_tensor = torch.clip(
scaled_tensor, self.qmin.cuda(), self.qmax.cuda()
)
scaled_tensor = torch.clip(scaled_tensor, self.qmin.cuda(), self.qmax.cuda())
org_dtype = scaled_tensor.dtype
q_tensor = float_quantize(
scaled_tensor.float(), self.e_bits, self.m_bits, rounding='nearest'
)
q_tensor = float_quantize(scaled_tensor.float(), self.e_bits, self.m_bits, rounding="nearest")
q_tensor.to(org_dtype)
return q_tensor
......@@ -161,9 +164,9 @@ class FloatQuantizer(BaseQuantizer):
return tensor
if __name__ == '__main__':
if __name__ == "__main__":
weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
quantizer = IntegerQuantizer(4, False, 'per_group', group_size=128)
quantizer = IntegerQuantizer(4, False, "per_group", group_size=128)
q_weight = quantizer.fake_quant_tensor(weight)
print(weight)
print(q_weight)
......@@ -175,7 +178,7 @@ if __name__ == '__main__':
print(f"zeros = {zeros}, {zeros.shape}")
weight = torch.randn(8192, 4096, dtype=torch.bfloat16).cuda()
quantizer = FloatQuantizer('e4m3', True, 'per_channel')
quantizer = FloatQuantizer("e4m3", True, "per_channel")
q_weight = quantizer.fake_quant_tensor(weight)
print(weight)
print(q_weight)
......
......@@ -11,13 +11,13 @@ class Register(dict):
def register(self, target, key=None):
if not callable(target):
raise Exception(f'Error: {target} must be callable!')
raise Exception(f"Error: {target} must be callable!")
if key is None:
key = target.__name__
if key in self._dict:
raise Exception(f'{key} already exists.')
raise Exception(f"{key} already exists.")
self[key] = target
return target
......
......@@ -66,12 +66,7 @@ def cache_video(
# preprocess
tensor = tensor.clamp(min(value_range), max(value_range))
tensor = torch.stack(
[
torchvision.utils.make_grid(
u, nrow=nrow, normalize=normalize, value_range=value_range
)
for u in tensor.unbind(2)
],
[torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor.unbind(2)],
dim=1,
).permute(1, 2, 3, 0)
tensor = (tensor * 255).type(torch.uint8).cpu()
......
[tool.ruff]
exclude = [".git", ".mypy_cache", ".ruff_cache", ".venv", "dist"]
target-version = "py311"
line-length = 200
indent-width = 4
lint.ignore =["F"]
[tool.ruff.format]
line-ending = "lf"
quote-style = "double"
indent-style = "space"
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment