"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f5c113e4395bc373ab540fc5a1f7490b7120c40f"
Unverified Commit c0964571 authored by Junsong Chen's avatar Junsong Chen Committed by GitHub
Browse files

[Sana 4K] (#10493)

add 4K support for Sana
parent b13cdbb2
...@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import is_accelerate_available ...@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [ ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth", "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth", "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth", "Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
...@@ -89,7 +90,10 @@ def main(args): ...@@ -89,7 +90,10 @@ def main(args):
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight") converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# scheduler # scheduler
flow_shift = 3.0 if args.image_size == 4096:
flow_shift = 6.0
else:
flow_shift = 3.0
# model config # model config
if args.model_type == "SanaMS_1600M_P1_D20": if args.model_type == "SanaMS_1600M_P1_D20":
...@@ -99,7 +103,7 @@ def main(args): ...@@ -99,7 +103,7 @@ def main(args):
else: else:
raise ValueError(f"{args.model_type} is not supported.") raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale. # Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0} interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
for depth in range(layer_num): for depth in range(layer_num):
# Transformer blocks. # Transformer blocks.
...@@ -272,9 +276,9 @@ if __name__ == "__main__": ...@@ -272,9 +276,9 @@ if __name__ == "__main__":
"--image_size", "--image_size",
default=1024, default=1024,
type=int, type=int,
choices=[512, 1024, 2048], choices=[512, 1024, 2048, 4096],
required=False, required=False,
help="Image size of pretrained model, 512, 1024 or 2048.", help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
) )
parser.add_argument( parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"] "--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
......
...@@ -63,6 +63,49 @@ if is_ftfy_available(): ...@@ -63,6 +63,49 @@ if is_ftfy_available():
import ftfy import ftfy
ASPECT_RATIO_4096_BIN = {
"0.25": [2048.0, 8192.0],
"0.26": [2048.0, 7936.0],
"0.27": [2048.0, 7680.0],
"0.28": [2048.0, 7424.0],
"0.32": [2304.0, 7168.0],
"0.33": [2304.0, 6912.0],
"0.35": [2304.0, 6656.0],
"0.4": [2560.0, 6400.0],
"0.42": [2560.0, 6144.0],
"0.48": [2816.0, 5888.0],
"0.5": [2816.0, 5632.0],
"0.52": [2816.0, 5376.0],
"0.57": [3072.0, 5376.0],
"0.6": [3072.0, 5120.0],
"0.68": [3328.0, 4864.0],
"0.72": [3328.0, 4608.0],
"0.78": [3584.0, 4608.0],
"0.82": [3584.0, 4352.0],
"0.88": [3840.0, 4352.0],
"0.94": [3840.0, 4096.0],
"1.0": [4096.0, 4096.0],
"1.07": [4096.0, 3840.0],
"1.13": [4352.0, 3840.0],
"1.21": [4352.0, 3584.0],
"1.29": [4608.0, 3584.0],
"1.38": [4608.0, 3328.0],
"1.46": [4864.0, 3328.0],
"1.67": [5120.0, 3072.0],
"1.75": [5376.0, 3072.0],
"2.0": [5632.0, 2816.0],
"2.09": [5888.0, 2816.0],
"2.4": [6144.0, 2560.0],
"2.5": [6400.0, 2560.0],
"2.89": [6656.0, 2304.0],
"3.0": [6912.0, 2304.0],
"3.11": [7168.0, 2304.0],
"3.62": [7424.0, 2048.0],
"3.75": [7680.0, 2048.0],
"3.88": [7936.0, 2048.0],
"4.0": [8192.0, 2048.0],
}
EXAMPLE_DOC_STRING = """ EXAMPLE_DOC_STRING = """
Examples: Examples:
```py ```py
...@@ -734,7 +777,9 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin): ...@@ -734,7 +777,9 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
if use_resolution_binning: if use_resolution_binning:
if self.transformer.config.sample_size == 64: if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_4096_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_2048_BIN aspect_ratio_bin = ASPECT_RATIO_2048_BIN
elif self.transformer.config.sample_size == 32: elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN aspect_ratio_bin = ASPECT_RATIO_1024_BIN
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment