Unverified Commit 85855252 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

fix: fix the typo in FluxModel.cpp as in #297 (#317)

* fix: fix a typo

* style: format the imports
parent ccd93d1e
from typing import Any, Dict, Optional, Union
import logging import logging
import os import os
from typing import Any, Dict, Optional, Union
import diffusers import diffusers
import torch import torch
...@@ -10,7 +9,7 @@ from diffusers.configuration_utils import register_to_config ...@@ -10,7 +9,7 @@ from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_outputs import Transformer2DModelOutput
from huggingface_hub import utils from huggingface_hub import utils
from packaging.version import Version from packaging.version import Version
from safetensors.torch import load_file, save_file from safetensors.torch import load_file
from torch import nn from torch import nn
from .utils import NunchakuModelLoaderMixin, pad_tensor from .utils import NunchakuModelLoaderMixin, pad_tensor
...@@ -180,9 +179,11 @@ class NunchakuFluxTransformerBlocks(nn.Module): ...@@ -180,9 +179,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device) encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)
return encoder_hidden_states, hidden_states return encoder_hidden_states, hidden_states
def __del__(self):
    """Finalizer: release the wrapped native module's resources.

    Calls ``reset()`` on ``self.m`` (presumably the underlying C++/native
    transformer-blocks object — confirm against the class ``__init__``) so
    its buffers are freed when this wrapper is garbage-collected.
    """
    self.m.reset()
## copied from diffusers 0.30.3 ## copied from diffusers 0.30.3
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even." assert dim % 2 == 0, "The dimension must be even."
......
...@@ -526,7 +526,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, ...@@ -526,7 +526,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
: Tensor{}; : Tensor{};
Tensor pool_qkv_context = pool.valid() Tensor pool_qkv_context = pool.valid()
? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE) ? pool.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
: Tensor{}; : Tensor{};
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv); // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment