### The code will be released in the future.
# Copyright (c) Alibaba, Inc. and its affiliates.
from facechain.utils import project_dir
neg_prompt = '(nsfw:2), paintings, sketches, (worst quality:2), (low quality:2), ' \
'lowres, normal quality, ((monochrome)), ((grayscale)), logo, word, character, bad hand, tattoo, (username, watermark, signature, time signature, timestamp, artist name, copyright name, copyright),'\
'low res, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, extra fingers, fewer fingers, strange fingers, bad hand, mole, ((extra legs)), ((extra hands))'
pos_prompt_with_cloth = 'raw photo, masterpiece, chinese, {}, solo, medium shot, high detail face, looking straight into the camera with shoulders parallel to the frame, photorealistic, best quality'
pos_prompt_with_style = '{}, upper_body, raw photo, masterpiece, solo, medium shot, high detail face, photorealistic, best quality'
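# Example usage (illustrative): the '{}' placeholder in the positive prompt templates is filled with a
# cloth or style description before generation, e.g. pos_prompt_with_cloth.format('wearing a business suit').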
base_models = [
{'name': 'leosamsMoonfilm_filmGrain20',
'model_id': 'ly261666/cv_portrait_model',
'revision': 'v2.0',
'sub_path': "film/film"},
{'name': 'MajicmixRealistic_v6',
'model_id': 'YorickHe/majicmixRealistic_v6',
'revision': 'v1.0.0',
'sub_path': "realistic"},
{'name': 'sdxl_1.0',
'model_id': 'AI-ModelScope/stable-diffusion-xl-base-1.0',
'revision': 'v1.0.0',
'sub_path': ""},
]
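# Each base model entry points at a ModelScope repository (model_id) plus the revision to download and the
# sub-directory inside the downloaded snapshot that holds the checkpoint.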
pose_models = [
{'name': '无姿态控制(No pose control)'},
{'name': 'pose-v1.1-with-depth'},
{'name': 'pose-v1.1'}
]
pose_examples = {
'man': [
[f'{project_dir}/poses/man/pose1.png'],
[f'{project_dir}/poses/man/pose2.png'],
[f'{project_dir}/poses/man/pose3.png'],
[f'{project_dir}/poses/man/pose4.png']
],
'woman': [
[f'{project_dir}/poses/woman/pose1.png'],
[f'{project_dir}/poses/woman/pose2.png'],
[f'{project_dir}/poses/woman/pose3.png'],
[f'{project_dir}/poses/woman/pose4.png'],
]
}
tts_speakers_map = {
'普通话(中国大陆)-Xiaoxiao-女': 'zh-CN-XiaoxiaoNeural',
'普通话(中国大陆)-Xiaoyi-女': 'zh-CN-XiaoyiNeural',
'普通话(中国大陆)-Yunjian-男': 'zh-CN-YunjianNeural',
'普通话(中国大陆)-Yunxi-男': 'zh-CN-YunxiNeural',
'普通话(中国大陆)-Yunxia-男': 'zh-CN-YunxiaNeural',
'普通话(中国大陆)-Yunyang-男': 'zh-CN-YunyangNeural',
'普通话(中国辽宁)-Xiaobei-女': 'zh-CN-liaoning-XiaobeiNeural',
'普通话(中国陕西)-Xiaoni-女': 'zh-CN-shaanxi-XiaoniNeural',
'普通话(中国台湾)-HsiaoChen-女': 'zh-TW-HsiaoChenNeural',
'普通话(中国台湾)-HsiaoYu-女': 'zh-TW-HsiaoYuNeural',
'普通话(中国台湾)-YunJhe-男': 'zh-TW-YunJheNeural',
'粤语(中国香港)-HiuMaan-女': 'zh-HK-HiuMaanNeural',
'粤语(中国香港)-HiuGaai-女': 'zh-HK-HiuGaaiNeural',
'粤语(中国香港)-WanLung-男': 'zh-HK-WanLungNeural',
'英语(美国)-Jenny-女': 'en-US-JennyNeural',
'英语(美国)-Guy-男': 'en-US-GuyNeural',
'英语(美国)-Ana-女': 'en-US-AnaNeural',
'英语(美国)-Aria-女': 'en-US-AriaNeural',
'英语(美国)-Christopher-男': 'en-US-ChristopherNeural',
'英语(美国)-Eric-男': 'en-US-EricNeural',
'英语(美国)-Michelle-女': 'en-US-MichelleNeural',
'英语(美国)-Roger-男': 'en-US-RogerNeural',
'英语(澳大利亚)-Natasha-女': 'en-AU-NatashaNeural',
'英语(澳大利亚)-William-男': 'en-AU-WilliamNeural',
'英语(加拿大)-Clara-女': 'en-CA-ClaraNeural',
'英语(加拿大)-Liam-男': 'en-CA-LiamNeural',
'英语(英国)-Libby-女': 'en-GB-LibbyNeural',
'英语(英国)-Maisie-女': 'en-GB-MaisieNeural',
'英语(英国)-Ryan-男': 'en-GB-RyanNeural',
'英语(英国)-Sonia-女': 'en-GB-SoniaNeural',
'英语(英国)-Thomas-男': 'en-GB-ThomasNeural',
'英语(香港)-Sam-男': 'en-HK-SamNeural',
'英语(香港)-Yan-女': 'en-HK-YanNeural',
'英语(爱尔兰)-Connor-男': 'en-IE-ConnorNeural',
'英语(爱尔兰)-Emily-女': 'en-IE-EmilyNeural',
'英语(印度)-Neerja-女': 'en-IN-NeerjaNeural',
'英语(印度)-Prabhat-男': 'en-IN-PrabhatNeural',
'英语(肯尼亚)-Asilia-女': 'en-KE-AsiliaNeural',
'英语(肯尼亚)-Chilemba-男': 'en-KE-ChilembaNeural',
'英语(尼日利亚)-Abeo-男': 'en-NG-AbeoNeural',
'英语(尼日利亚)-Ezinne-女': 'en-NG-EzinneNeural',
'英语(新西兰)-Mitchell-男': 'en-NZ-MitchellNeural',
'英语(菲律宾)-James-男': 'en-PH-JamesNeural',
'英语(菲律宾)-Rosa-女': 'en-PH-RosaNeural',
'英语(新加坡)-Luna-女': 'en-SG-LunaNeural',
'英语(新加坡)-Wayne-男': 'en-SG-WayneNeural',
'英语(坦桑尼亚)-Elimu-男': 'en-TZ-ElimuNeural',
'英语(坦桑尼亚)-Imani-女': 'en-TZ-ImaniNeural',
'英语(南非)-Leah-女': 'en-ZA-LeahNeural',
'英语(南非)-Luke-男': 'en-ZA-LukeNeural',
'韩语(韩国)-SunHi-女': 'ko-KR-SunHiNeural',
'韩语(韩国)-InJoon-男': 'ko-KR-InJoonNeural',
'泰语(泰国)-Premwadee-女': 'th-TH-PremwadeeNeural',
'泰语(泰国)-Niwat-男': 'th-TH-NiwatNeural',
'越南语(越南)-HoaiMy-女': 'vi-VN-HoaiMyNeural',
'越南语(越南)-NamMinh-男': 'vi-VN-NamMinhNeural',
'日语(日本)-Nanami-女': 'ja-JP-NanamiNeural',
'日语(日本)-Keita-男': 'ja-JP-KeitaNeural',
'法语(法国)-Denise-女': 'fr-FR-DeniseNeural',
'法语(法国)-Eloise-女': 'fr-FR-EloiseNeural',
'法语(法国)-Henri-男': 'fr-FR-HenriNeural',
'法语(比利时)-Charline-女': 'fr-BE-CharlineNeural',
'法语(比利时)-Gerard-男': 'fr-BE-GerardNeural',
'法语(加拿大)-Sylvie-女': 'fr-CA-SylvieNeural',
'法语(加拿大)-Antoine-男': 'fr-CA-AntoineNeural',
'法语(加拿大)-Jean-男': 'fr-CA-JeanNeural',
'法语(瑞士)-Ariane-女': 'fr-CH-ArianeNeural',
'法语(瑞士)-Fabrice-男': 'fr-CH-FabriceNeural',
'葡萄牙语(巴西)-Francisca-女': 'pt-BR-FranciscaNeural',
'葡萄牙语(巴西)-Antonio-男': 'pt-BR-AntonioNeural',
'葡萄牙语(葡萄牙)-Duarte-男': 'pt-PT-DuarteNeural',
'葡萄牙语(葡萄牙)-Raquel-女': 'pt-PT-RaquelNeural',
'意大利语(意大利)-Isabella-女': 'it-IT-IsabellaNeural',
'意大利语(意大利)-Diego-男': 'it-IT-DiegoNeural',
'意大利语(意大利)-Elsa-女': 'it-IT-ElsaNeural',
'荷兰语(荷兰)-Colette-女': 'nl-NL-ColetteNeural',
'荷兰语(荷兰)-Fenna-女': 'nl-NL-FennaNeural',
'荷兰语(荷兰)-Maarten-男': 'nl-NL-MaartenNeural',
'荷兰语(比利时)-Arnaud-男': 'nl-BE-ArnaudNeural',
'荷兰语(比利时)-Dena-女': 'nl-BE-DenaNeural',
'挪威语(挪威)-Pernille-女': 'nb-NO-PernilleNeural',
'挪威语(挪威)-Finn-男': 'nb-NO-FinnNeural',
'瑞典语(瑞典)-Sofie-女': 'sv-SE-SofieNeural',
'瑞典语(瑞典)-Mattias-男': 'sv-SE-MattiasNeural',
'希腊语(希腊)-Athina-女': 'el-GR-AthinaNeural',
'希腊语(希腊)-Nestoras-男': 'el-GR-NestorasNeural',
'德语(德国)-Katja-女': 'de-DE-KatjaNeural',
'德语(德国)-Amala-女': 'de-DE-AmalaNeural',
'德语(德国)-Conrad-男': 'de-DE-ConradNeural',
'德语(德国)-Killian-男': 'de-DE-KillianNeural',
'德语(奥地利)-Ingrid-女': 'de-AT-IngridNeural',
'德语(奥地利)-Jonas-男': 'de-AT-JonasNeural',
'德语(瑞士)-Jan-男': 'de-CH-JanNeural',
'德语(瑞士)-Leni-女': 'de-CH-LeniNeural',
'丹麦语(丹麦)-Christel-女': 'da-DK-ChristelNeural',
'丹麦语(丹麦)-Jeppe-男': 'da-DK-JeppeNeural',
'西班牙语(墨西哥)-Dalia-女': 'es-MX-DaliaNeural',
'西班牙语(墨西哥)-Jorge-男': 'es-MX-JorgeNeural',
'西班牙语(阿根廷)-Elena-女': 'es-AR-ElenaNeural',
'西班牙语(阿根廷)-Tomas-男': 'es-AR-TomasNeural',
'西班牙语(玻利维亚)-Marcelo-男': 'es-BO-MarceloNeural',
'西班牙语(玻利维亚)-Sofia-女': 'es-BO-SofiaNeural',
'西班牙语(哥伦比亚)-Gonzalo-男': 'es-CO-GonzaloNeural',
'西班牙语(哥伦比亚)-Salome-女': 'es-CO-SalomeNeural',
'西班牙语(哥斯达黎加)-Juan-男': 'es-CR-JuanNeural',
'西班牙语(哥斯达黎加)-Maria-女': 'es-CR-MariaNeural',
'西班牙语(古巴)-Belkys-女': 'es-CU-BelkysNeural',
'西班牙语(多米尼加共和国)-Emilio-男': 'es-DO-EmilioNeural',
'西班牙语(多米尼加共和国)-Ramona-女': 'es-DO-RamonaNeural',
'西班牙语(厄瓜多尔)-Andrea-女': 'es-EC-AndreaNeural',
'西班牙语(厄瓜多尔)-Luis-男': 'es-EC-LuisNeural',
'西班牙语(西班牙)-Alvaro-男': 'es-ES-AlvaroNeural',
'西班牙语(西班牙)-Elvira-女': 'es-ES-ElviraNeural',
'西班牙语(赤道几内亚)-Teresa-女': 'es-GQ-TeresaNeural',
'西班牙语(危地马拉)-Andres-男': 'es-GT-AndresNeural',
'西班牙语(危地马拉)-Marta-女': 'es-GT-MartaNeural',
'西班牙语(洪都拉斯)-Carlos-男': 'es-HN-CarlosNeural',
'西班牙语(洪都拉斯)-Karla-女': 'es-HN-KarlaNeural',
'西班牙语(尼加拉瓜)-Federico-男': 'es-NI-FedericoNeural',
'西班牙语(尼加拉瓜)-Yolanda-女': 'es-NI-YolandaNeural',
'西班牙语(巴拿马)-Margarita-女': 'es-PA-MargaritaNeural',
'西班牙语(巴拿马)-Roberto-男': 'es-PA-RobertoNeural',
'西班牙语(秘鲁)-Alex-男': 'es-PE-AlexNeural',
'西班牙语(秘鲁)-Camila-女': 'es-PE-CamilaNeural',
'西班牙语(波多黎各)-Karina-女': 'es-PR-KarinaNeural',
'西班牙语(波多黎各)-Victor-男': 'es-PR-VictorNeural',
'西班牙语(巴拉圭)-Mario-男': 'es-PY-MarioNeural',
'西班牙语(巴拉圭)-Tania-女': 'es-PY-TaniaNeural',
'西班牙语(萨尔瓦多)-Lorena-女': 'es-SV-LorenaNeural',
'西班牙语(萨尔瓦多)-Rodrigo-男': 'es-SV-RodrigoNeural',
'西班牙语(美国)-Alonso-男': 'es-US-AlonsoNeural',
'西班牙语(美国)-Paloma-女': 'es-US-PalomaNeural',
'西班牙语(乌拉圭)-Mateo-男': 'es-UY-MateoNeural',
'西班牙语(乌拉圭)-Valentina-女': 'es-UY-ValentinaNeural',
'西班牙语(委内瑞拉)-Paola-女': 'es-VE-PaolaNeural',
'西班牙语(委内瑞拉)-Sebastian-男': 'es-VE-SebastianNeural',
'俄语(俄罗斯)-Svetlana-女': 'ru-RU-SvetlanaNeural',
'俄语(俄罗斯)-Dmitry-男': 'ru-RU-DmitryNeural',
'阿姆哈拉语(埃塞俄比亚)-Ameha-男': 'am-ET-AmehaNeural',
'阿姆哈拉语(埃塞俄比亚)-Mekdes-女': 'am-ET-MekdesNeural',
'阿拉伯语(沙特阿拉伯)-Hamed-男': 'ar-SA-HamedNeural',
'阿拉伯语(沙特阿拉伯)-Zariyah-女': 'ar-SA-ZariyahNeural',
'阿拉伯语(阿拉伯联合酋长国)-Fatima-女': 'ar-AE-FatimaNeural',
'阿拉伯语(阿拉伯联合酋长国)-Hamdan-男': 'ar-AE-HamdanNeural',
'阿拉伯语(巴林)-Ali-男': 'ar-BH-AliNeural',
'阿拉伯语(巴林)-Laila-女': 'ar-BH-LailaNeural',
'阿拉伯语(阿尔及利亚)-Ismael-男': 'ar-DZ-IsmaelNeural',
'阿拉伯语(埃及)-Salma-女': 'ar-EG-SalmaNeural',
'阿拉伯语(埃及)-Shakir-男': 'ar-EG-ShakirNeural',
'阿拉伯语(伊拉克)-Bassel-男': 'ar-IQ-BasselNeural',
'阿拉伯语(伊拉克)-Rana-女': 'ar-IQ-RanaNeural',
'阿拉伯语(约旦)-Sana-女': 'ar-JO-SanaNeural',
'阿拉伯语(约旦)-Taim-男': 'ar-JO-TaimNeural',
'阿拉伯语(科威特)-Fahed-男': 'ar-KW-FahedNeural',
'阿拉伯语(科威特)-Noura-女': 'ar-KW-NouraNeural',
'阿拉伯语(黎巴嫩)-Layla-女': 'ar-LB-LaylaNeural',
'阿拉伯语(黎巴嫩)-Rami-男': 'ar-LB-RamiNeural',
'阿拉伯语(利比亚)-Iman-女': 'ar-LY-ImanNeural',
'阿拉伯语(利比亚)-Omar-男': 'ar-LY-OmarNeural',
'阿拉伯语(摩洛哥)-Jamal-男': 'ar-MA-JamalNeural',
'阿拉伯语(摩洛哥)-Mouna-女': 'ar-MA-MounaNeural',
'阿拉伯语(阿曼)-Abdullah-男': 'ar-OM-AbdullahNeural',
'阿拉伯语(阿曼)-Aysha-女': 'ar-OM-AyshaNeural',
'阿拉伯语(卡塔尔)-Amal-女': 'ar-QA-AmalNeural',
'阿拉伯语(卡塔尔)-Moaz-男': 'ar-QA-MoazNeural',
'阿拉伯语(叙利亚)-Amany-女': 'ar-SY-AmanyNeural',
'阿拉伯语(叙利亚)-Laith-男': 'ar-SY-LaithNeural',
'阿拉伯语(突尼斯)-Hedi-男': 'ar-TN-HediNeural',
'阿拉伯语(突尼斯)-Reem-女': 'ar-TN-ReemNeural',
'阿拉伯语(也门)-Maryam-女': 'ar-YE-MaryamNeural',
'阿拉伯语(也门)-Saleh-男': 'ar-YE-SalehNeural',
'阿拉伯语(南非)-Adri-女': 'af-ZA-AdriNeural',
'阿拉伯语(南非)-Willem-男': 'af-ZA-WillemNeural',
'阿塞拜疆语(阿塞拜疆)-Babek-男': 'az-AZ-BabekNeural',
'阿塞拜疆语(阿塞拜疆)-Banu-女': 'az-AZ-BanuNeural',
'保加利亚语(保加利亚)-Borislav-男': 'bg-BG-BorislavNeural',
'保加利亚语(保加利亚)-Kalina-女': 'bg-BG-KalinaNeural',
'孟加拉语(孟加拉国)-Nabanita-女': 'bn-BD-NabanitaNeural',
'孟加拉语(孟加拉国)-Pradeep-男': 'bn-BD-PradeepNeural',
'孟加拉语(印度)-Bashkar-男': 'bn-IN-BashkarNeural',
'孟加拉语(印度)-Tanishaa-女': 'bn-IN-TanishaaNeural',
'波斯尼亚语(波斯尼亚和黑塞哥维那)-Goran-男': 'bs-BA-GoranNeural',
'波斯尼亚语(波斯尼亚和黑塞哥维那)-Vesna-女': 'bs-BA-VesnaNeural',
'加泰罗尼亚语(西班牙)-Joana-女': 'ca-ES-JoanaNeural',
'加泰罗尼亚语(西班牙)-Enric-男': 'ca-ES-EnricNeural',
'捷克语(捷克共和国)-Antonin-男': 'cs-CZ-AntoninNeural',
'捷克语(捷克共和国)-Vlasta-女': 'cs-CZ-VlastaNeural',
'威尔士语(英国)-Aled-男': 'cy-GB-AledNeural',
'威尔士语(英国)-Nia-女': 'cy-GB-NiaNeural',
'印度尼西亚语(印度尼西亚)-Ardi-男': 'id-ID-ArdiNeural',
'印度尼西亚语(印度尼西亚)-Gadis-女': 'id-ID-GadisNeural',
'希伯来语(以色列)-Avri-男': 'he-IL-AvriNeural',
'希伯来语(以色列)-Hila-女': 'he-IL-HilaNeural',
'爱沙尼亚语(爱沙尼亚)-Anu-女': 'et-EE-AnuNeural',
'爱沙尼亚语(爱沙尼亚)-Kert-男': 'et-EE-KertNeural',
'波斯语(伊朗)-Dilara-女': 'fa-IR-DilaraNeural',
'波斯语(伊朗)-Farid-男': 'fa-IR-FaridNeural',
'芬兰语(芬兰)-Harri-男': 'fi-FI-HarriNeural',
'芬兰语(芬兰)-Noora-女': 'fi-FI-NooraNeural',
'爱尔兰语(爱尔兰)-Colm-男': 'ga-IE-ColmNeural',
'爱尔兰语(爱尔兰)-Orla-女': 'ga-IE-OrlaNeural',
'马来语(马来西亚)-Osman-男': 'ms-MY-OsmanNeural',
'马来语(马来西亚)-Yasmin-女': 'ms-MY-YasminNeural',
'加利西亚语(西班牙)-Roi-男': 'gl-ES-RoiNeural',
'加利西亚语(西班牙)-Sabela-女': 'gl-ES-SabelaNeural',
'古吉拉特语(印度)-Dhwani-女': 'gu-IN-DhwaniNeural',
'古吉拉特语(印度)-Niranjan-男': 'gu-IN-NiranjanNeural',
'印地语(印度)-Madhur-男': 'hi-IN-MadhurNeural',
'印地语(印度)-Swara-女': 'hi-IN-SwaraNeural',
'克罗地亚语(克罗地亚)-Gabrijela-女': 'hr-HR-GabrijelaNeural',
'克罗地亚语(克罗地亚)-Srecko-男': 'hr-HR-SreckoNeural',
'匈牙利语(匈牙利)-Noemi-女': 'hu-HU-NoemiNeural',
'匈牙利语(匈牙利)-Tamas-男': 'hu-HU-TamasNeural',
'冰岛语(冰岛)-Gudrun-女': 'is-IS-GudrunNeural',
'冰岛语(冰岛)-Gunnar-男': 'is-IS-GunnarNeural',
'爪哇语(印度尼西亚)-Dimas-男': 'jv-ID-DimasNeural',
'爪哇语(印度尼西亚)-Siti-女': 'jv-ID-SitiNeural',
'格鲁吉亚语(格鲁吉亚)-Eka-女': 'ka-GE-EkaNeural',
'格鲁吉亚语(格鲁吉亚)-Giorgi-男': 'ka-GE-GiorgiNeural',
'哈萨克语(哈萨克斯坦)-Aigul-女': 'kk-KZ-AigulNeural',
'哈萨克语(哈萨克斯坦)-Daulet-男': 'kk-KZ-DauletNeural',
'高棉语(柬埔寨)-Piseth-男': 'km-KH-PisethNeural',
'高棉语(柬埔寨)-Sreymom-女': 'km-KH-SreymomNeural',
'卡纳达语(印度)-Gagan-男': 'kn-IN-GaganNeural',
'卡纳达语(印度)-Sapna-女': 'kn-IN-SapnaNeural',
'老挝语(老挝)-Chanthavong-男': 'lo-LA-ChanthavongNeural',
'老挝语(老挝)-Keomany-女': 'lo-LA-KeomanyNeural',
'立陶宛语(立陶宛)-Leonas-男': 'lt-LT-LeonasNeural',
'立陶宛语(立陶宛)-Ona-女': 'lt-LT-OnaNeural',
'拉脱维亚语(拉脱维亚)-Everita-女': 'lv-LV-EveritaNeural',
'拉脱维亚语(拉脱维亚)-Nils-男': 'lv-LV-NilsNeural',
'马其顿语(北马其顿共和国)-Aleksandar-男': 'mk-MK-AleksandarNeural',
'马其顿语(北马其顿共和国)-Marija-女': 'mk-MK-MarijaNeural',
'马拉雅拉姆语(印度)-Midhun-男': 'ml-IN-MidhunNeural',
'马拉雅拉姆语(印度)-Sobhana-女': 'ml-IN-SobhanaNeural',
'蒙古语(蒙古)-Bataa-男': 'mn-MN-BataaNeural',
'蒙古语(蒙古)-Yesui-女': 'mn-MN-YesuiNeural',
'马拉地语(印度)-Aarohi-女': 'mr-IN-AarohiNeural',
'马拉地语(印度)-Manohar-男': 'mr-IN-ManoharNeural',
'马耳他语(马耳他)-Grace-女': 'mt-MT-GraceNeural',
'马耳他语(马耳他)-Joseph-男': 'mt-MT-JosephNeural',
'缅甸语(缅甸)-Nilar-女': 'my-MM-NilarNeural',
'缅甸语(缅甸)-Thiha-男': 'my-MM-ThihaNeural',
'尼泊尔语(尼泊尔)-Hemkala-女': 'ne-NP-HemkalaNeural',
'尼泊尔语(尼泊尔)-Sagar-男': 'ne-NP-SagarNeural',
'波兰语(波兰)-Marek-男': 'pl-PL-MarekNeural',
'波兰语(波兰)-Zofia-女': 'pl-PL-ZofiaNeural',
'普什图语(阿富汗)-Gul Nawaz-男': 'ps-AF-GulNawazNeural',
'普什图语(阿富汗)-Latifa-女': 'ps-AF-LatifaNeural',
'罗马尼亚语(罗马尼亚)-Alina-女': 'ro-RO-AlinaNeural',
'罗马尼亚语(罗马尼亚)-Emil-男': 'ro-RO-EmilNeural',
'僧伽罗语(斯里兰卡)-Sameera-男': 'si-LK-SameeraNeural',
'僧伽罗语(斯里兰卡)-Thilini-女': 'si-LK-ThiliniNeural',
'斯洛伐克语(斯洛伐克)-Lukas-男': 'sk-SK-LukasNeural',
'斯洛伐克语(斯洛伐克)-Viktoria-女': 'sk-SK-ViktoriaNeural',
'斯洛文尼亚语(斯洛文尼亚)-Petra-女': 'sl-SI-PetraNeural',
'斯洛文尼亚语(斯洛文尼亚)-Rok-男': 'sl-SI-RokNeural',
'索马里语(索马里)-Muuse-男': 'so-SO-MuuseNeural',
'索马里语(索马里)-Ubax-女': 'so-SO-UbaxNeural',
'阿尔巴尼亚语(阿尔巴尼亚)-Anila-女': 'sq-AL-AnilaNeural',
'阿尔巴尼亚语(阿尔巴尼亚)-Ilir-男': 'sq-AL-IlirNeural',
'塞尔维亚语(塞尔维亚)-Nicholas-男': 'sr-RS-NicholasNeural',
'塞尔维亚语(塞尔维亚)-Sophie-女': 'sr-RS-SophieNeural',
'巽他语(印度尼西亚)-Jajang-男': 'su-ID-JajangNeural',
'巽他语(印度尼西亚)-Tuti-女': 'su-ID-TutiNeural',
'斯瓦希里语(肯尼亚)-Rafiki-男': 'sw-KE-RafikiNeural',
'斯瓦希里语(肯尼亚)-Zuri-女': 'sw-KE-ZuriNeural',
'斯瓦希里语(坦桑尼亚)-Daudi-男': 'sw-TZ-DaudiNeural',
'斯瓦希里语(坦桑尼亚)-Rehema-女': 'sw-TZ-RehemaNeural',
'泰米尔语(印度)-Pallavi-女': 'ta-IN-PallaviNeural',
'泰米尔语(印度)-Valluvar-男': 'ta-IN-ValluvarNeural',
'泰米尔语(斯里兰卡)-Kumar-男': 'ta-LK-KumarNeural',
'泰米尔语(斯里兰卡)-Saranya-女': 'ta-LK-SaranyaNeural',
'泰米尔语(马来西亚)-Kani-女': 'ta-MY-KaniNeural',
'泰米尔语(马来西亚)-Surya-男': 'ta-MY-SuryaNeural',
'泰米尔语(新加坡)-Anbu-男': 'ta-SG-AnbuNeural',
'泰卢固语(印度)-Mohan-男': 'te-IN-MohanNeural',
'泰卢固语(印度)-Shruti-女': 'te-IN-ShrutiNeural',
'土耳其语(土耳其)-Ahmet-男': 'tr-TR-AhmetNeural',
'土耳其语(土耳其)-Emel-女': 'tr-TR-EmelNeural',
'乌克兰语(乌克兰)-Ostap-男': 'uk-UA-OstapNeural',
'乌克兰语(乌克兰)-Polina-女': 'uk-UA-PolinaNeural',
'乌尔都语(印度)-Gul-女': 'ur-IN-GulNeural',
'乌尔都语(印度)-Salman-男': 'ur-IN-SalmanNeural',
'乌尔都语(巴基斯坦)-Asad-男': 'ur-PK-AsadNeural',
'乌尔都语(巴基斯坦)-Uzma-女': 'ur-PK-UzmaNeural',
'乌兹别克语(乌兹别克斯坦)-Madina-女': 'uz-UZ-MadinaNeural',
'乌兹别克语(乌兹别克斯坦)-Sardor-男': 'uz-UZ-SardorNeural',
'祖鲁语(南非)-Thando-女': 'zu-ZA-ThandoNeural',
'祖鲁语(南非)-Themba-男': 'zu-ZA-ThembaNeural'
}
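# Example lookup (illustrative): keys are the human-readable speaker labels shown in the UI, values are the
# corresponding neural TTS voice identifiers, e.g.
#   tts_speakers_map['英语(美国)-Jenny-女']  # -> 'en-US-JennyNeural'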
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.hub.snapshot_download import snapshot_download
from PIL import Image

re_special = re.compile(r'([\\()])')
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)

# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
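# DeepDanbooruModel is a layer-by-layer PyTorch port of the DeepDanbooru ResNet tagger: explicit Conv2d layers
# with manual padding form bottleneck residual blocks, followed by global average pooling and a sigmoid over
# the tag vocabulary (9176 outputs). load_state_dict also restores the tag list stored under the 'tags' key.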
class DeepDanbooruModel(nn.Module):
def __init__(self):
super(DeepDanbooruModel, self).__init__()
self.tags = []
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
def forward(self, *inputs):
t_358, = inputs
t_359 = t_358.permute(*[0, 3, 1, 2])
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
# t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype))
t_361 = F.relu(t_360)
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
t_362 = self.n_MaxPool_0(t_361)
t_363 = self.n_Conv_1(t_362)
t_364 = self.n_Conv_2(t_362)
t_365 = F.relu(t_364)
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
t_366 = self.n_Conv_3(t_365_padded)
t_367 = F.relu(t_366)
t_368 = self.n_Conv_4(t_367)
t_369 = torch.add(t_368, t_363)
t_370 = F.relu(t_369)
t_371 = self.n_Conv_5(t_370)
t_372 = F.relu(t_371)
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
t_373 = self.n_Conv_6(t_372_padded)
t_374 = F.relu(t_373)
t_375 = self.n_Conv_7(t_374)
t_376 = torch.add(t_375, t_370)
t_377 = F.relu(t_376)
t_378 = self.n_Conv_8(t_377)
t_379 = F.relu(t_378)
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
t_380 = self.n_Conv_9(t_379_padded)
t_381 = F.relu(t_380)
t_382 = self.n_Conv_10(t_381)
t_383 = torch.add(t_382, t_377)
t_384 = F.relu(t_383)
t_385 = self.n_Conv_11(t_384)
t_386 = self.n_Conv_12(t_384)
t_387 = F.relu(t_386)
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
t_388 = self.n_Conv_13(t_387_padded)
t_389 = F.relu(t_388)
t_390 = self.n_Conv_14(t_389)
t_391 = torch.add(t_390, t_385)
t_392 = F.relu(t_391)
t_393 = self.n_Conv_15(t_392)
t_394 = F.relu(t_393)
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
t_395 = self.n_Conv_16(t_394_padded)
t_396 = F.relu(t_395)
t_397 = self.n_Conv_17(t_396)
t_398 = torch.add(t_397, t_392)
t_399 = F.relu(t_398)
t_400 = self.n_Conv_18(t_399)
t_401 = F.relu(t_400)
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
t_402 = self.n_Conv_19(t_401_padded)
t_403 = F.relu(t_402)
t_404 = self.n_Conv_20(t_403)
t_405 = torch.add(t_404, t_399)
t_406 = F.relu(t_405)
t_407 = self.n_Conv_21(t_406)
t_408 = F.relu(t_407)
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
t_409 = self.n_Conv_22(t_408_padded)
t_410 = F.relu(t_409)
t_411 = self.n_Conv_23(t_410)
t_412 = torch.add(t_411, t_406)
t_413 = F.relu(t_412)
t_414 = self.n_Conv_24(t_413)
t_415 = F.relu(t_414)
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
t_416 = self.n_Conv_25(t_415_padded)
t_417 = F.relu(t_416)
t_418 = self.n_Conv_26(t_417)
t_419 = torch.add(t_418, t_413)
t_420 = F.relu(t_419)
t_421 = self.n_Conv_27(t_420)
t_422 = F.relu(t_421)
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
t_423 = self.n_Conv_28(t_422_padded)
t_424 = F.relu(t_423)
t_425 = self.n_Conv_29(t_424)
t_426 = torch.add(t_425, t_420)
t_427 = F.relu(t_426)
t_428 = self.n_Conv_30(t_427)
t_429 = F.relu(t_428)
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
t_430 = self.n_Conv_31(t_429_padded)
t_431 = F.relu(t_430)
t_432 = self.n_Conv_32(t_431)
t_433 = torch.add(t_432, t_427)
t_434 = F.relu(t_433)
t_435 = self.n_Conv_33(t_434)
t_436 = F.relu(t_435)
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
t_437 = self.n_Conv_34(t_436_padded)
t_438 = F.relu(t_437)
t_439 = self.n_Conv_35(t_438)
t_440 = torch.add(t_439, t_434)
t_441 = F.relu(t_440)
t_442 = self.n_Conv_36(t_441)
t_443 = self.n_Conv_37(t_441)
t_444 = F.relu(t_443)
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
t_445 = self.n_Conv_38(t_444_padded)
t_446 = F.relu(t_445)
t_447 = self.n_Conv_39(t_446)
t_448 = torch.add(t_447, t_442)
t_449 = F.relu(t_448)
t_450 = self.n_Conv_40(t_449)
t_451 = F.relu(t_450)
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
t_452 = self.n_Conv_41(t_451_padded)
t_453 = F.relu(t_452)
t_454 = self.n_Conv_42(t_453)
t_455 = torch.add(t_454, t_449)
t_456 = F.relu(t_455)
t_457 = self.n_Conv_43(t_456)
t_458 = F.relu(t_457)
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
t_459 = self.n_Conv_44(t_458_padded)
t_460 = F.relu(t_459)
t_461 = self.n_Conv_45(t_460)
t_462 = torch.add(t_461, t_456)
t_463 = F.relu(t_462)
t_464 = self.n_Conv_46(t_463)
t_465 = F.relu(t_464)
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
t_466 = self.n_Conv_47(t_465_padded)
t_467 = F.relu(t_466)
t_468 = self.n_Conv_48(t_467)
t_469 = torch.add(t_468, t_463)
t_470 = F.relu(t_469)
t_471 = self.n_Conv_49(t_470)
t_472 = F.relu(t_471)
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
t_473 = self.n_Conv_50(t_472_padded)
t_474 = F.relu(t_473)
t_475 = self.n_Conv_51(t_474)
t_476 = torch.add(t_475, t_470)
t_477 = F.relu(t_476)
t_478 = self.n_Conv_52(t_477)
t_479 = F.relu(t_478)
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
t_480 = self.n_Conv_53(t_479_padded)
t_481 = F.relu(t_480)
t_482 = self.n_Conv_54(t_481)
t_483 = torch.add(t_482, t_477)
t_484 = F.relu(t_483)
t_485 = self.n_Conv_55(t_484)
t_486 = F.relu(t_485)
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
t_487 = self.n_Conv_56(t_486_padded)
t_488 = F.relu(t_487)
t_489 = self.n_Conv_57(t_488)
t_490 = torch.add(t_489, t_484)
t_491 = F.relu(t_490)
t_492 = self.n_Conv_58(t_491)
t_493 = F.relu(t_492)
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
t_494 = self.n_Conv_59(t_493_padded)
t_495 = F.relu(t_494)
t_496 = self.n_Conv_60(t_495)
t_497 = torch.add(t_496, t_491)
t_498 = F.relu(t_497)
t_499 = self.n_Conv_61(t_498)
t_500 = F.relu(t_499)
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
t_501 = self.n_Conv_62(t_500_padded)
t_502 = F.relu(t_501)
t_503 = self.n_Conv_63(t_502)
t_504 = torch.add(t_503, t_498)
t_505 = F.relu(t_504)
t_506 = self.n_Conv_64(t_505)
t_507 = F.relu(t_506)
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
t_508 = self.n_Conv_65(t_507_padded)
t_509 = F.relu(t_508)
t_510 = self.n_Conv_66(t_509)
t_511 = torch.add(t_510, t_505)
t_512 = F.relu(t_511)
t_513 = self.n_Conv_67(t_512)
t_514 = F.relu(t_513)
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
t_515 = self.n_Conv_68(t_514_padded)
t_516 = F.relu(t_515)
t_517 = self.n_Conv_69(t_516)
t_518 = torch.add(t_517, t_512)
t_519 = F.relu(t_518)
t_520 = self.n_Conv_70(t_519)
t_521 = F.relu(t_520)
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
t_522 = self.n_Conv_71(t_521_padded)
t_523 = F.relu(t_522)
t_524 = self.n_Conv_72(t_523)
t_525 = torch.add(t_524, t_519)
t_526 = F.relu(t_525)
t_527 = self.n_Conv_73(t_526)
t_528 = F.relu(t_527)
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
t_529 = self.n_Conv_74(t_528_padded)
t_530 = F.relu(t_529)
t_531 = self.n_Conv_75(t_530)
t_532 = torch.add(t_531, t_526)
t_533 = F.relu(t_532)
t_534 = self.n_Conv_76(t_533)
t_535 = F.relu(t_534)
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
t_536 = self.n_Conv_77(t_535_padded)
t_537 = F.relu(t_536)
t_538 = self.n_Conv_78(t_537)
t_539 = torch.add(t_538, t_533)
t_540 = F.relu(t_539)
t_541 = self.n_Conv_79(t_540)
t_542 = F.relu(t_541)
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
t_543 = self.n_Conv_80(t_542_padded)
t_544 = F.relu(t_543)
t_545 = self.n_Conv_81(t_544)
t_546 = torch.add(t_545, t_540)
t_547 = F.relu(t_546)
t_548 = self.n_Conv_82(t_547)
t_549 = F.relu(t_548)
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
t_550 = self.n_Conv_83(t_549_padded)
t_551 = F.relu(t_550)
t_552 = self.n_Conv_84(t_551)
t_553 = torch.add(t_552, t_547)
t_554 = F.relu(t_553)
t_555 = self.n_Conv_85(t_554)
t_556 = F.relu(t_555)
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
t_557 = self.n_Conv_86(t_556_padded)
t_558 = F.relu(t_557)
t_559 = self.n_Conv_87(t_558)
t_560 = torch.add(t_559, t_554)
t_561 = F.relu(t_560)
t_562 = self.n_Conv_88(t_561)
t_563 = F.relu(t_562)
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
t_564 = self.n_Conv_89(t_563_padded)
t_565 = F.relu(t_564)
t_566 = self.n_Conv_90(t_565)
t_567 = torch.add(t_566, t_561)
t_568 = F.relu(t_567)
t_569 = self.n_Conv_91(t_568)
t_570 = F.relu(t_569)
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
t_571 = self.n_Conv_92(t_570_padded)
t_572 = F.relu(t_571)
t_573 = self.n_Conv_93(t_572)
t_574 = torch.add(t_573, t_568)
t_575 = F.relu(t_574)
t_576 = self.n_Conv_94(t_575)
t_577 = F.relu(t_576)
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
t_578 = self.n_Conv_95(t_577_padded)
t_579 = F.relu(t_578)
t_580 = self.n_Conv_96(t_579)
t_581 = torch.add(t_580, t_575)
t_582 = F.relu(t_581)
t_583 = self.n_Conv_97(t_582)
t_584 = F.relu(t_583)
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
t_585 = self.n_Conv_98(t_584_padded)
t_586 = F.relu(t_585)
t_587 = self.n_Conv_99(t_586)
t_588 = self.n_Conv_100(t_582)
t_589 = torch.add(t_587, t_588)
t_590 = F.relu(t_589)
t_591 = self.n_Conv_101(t_590)
t_592 = F.relu(t_591)
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
t_593 = self.n_Conv_102(t_592_padded)
t_594 = F.relu(t_593)
t_595 = self.n_Conv_103(t_594)
t_596 = torch.add(t_595, t_590)
t_597 = F.relu(t_596)
t_598 = self.n_Conv_104(t_597)
t_599 = F.relu(t_598)
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
t_600 = self.n_Conv_105(t_599_padded)
t_601 = F.relu(t_600)
t_602 = self.n_Conv_106(t_601)
t_603 = torch.add(t_602, t_597)
t_604 = F.relu(t_603)
t_605 = self.n_Conv_107(t_604)
t_606 = F.relu(t_605)
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
t_607 = self.n_Conv_108(t_606_padded)
t_608 = F.relu(t_607)
t_609 = self.n_Conv_109(t_608)
t_610 = torch.add(t_609, t_604)
t_611 = F.relu(t_610)
t_612 = self.n_Conv_110(t_611)
t_613 = F.relu(t_612)
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
t_614 = self.n_Conv_111(t_613_padded)
t_615 = F.relu(t_614)
t_616 = self.n_Conv_112(t_615)
t_617 = torch.add(t_616, t_611)
t_618 = F.relu(t_617)
t_619 = self.n_Conv_113(t_618)
t_620 = F.relu(t_619)
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
t_621 = self.n_Conv_114(t_620_padded)
t_622 = F.relu(t_621)
t_623 = self.n_Conv_115(t_622)
t_624 = torch.add(t_623, t_618)
t_625 = F.relu(t_624)
t_626 = self.n_Conv_116(t_625)
t_627 = F.relu(t_626)
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
t_628 = self.n_Conv_117(t_627_padded)
t_629 = F.relu(t_628)
t_630 = self.n_Conv_118(t_629)
t_631 = torch.add(t_630, t_625)
t_632 = F.relu(t_631)
t_633 = self.n_Conv_119(t_632)
t_634 = F.relu(t_633)
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
t_635 = self.n_Conv_120(t_634_padded)
t_636 = F.relu(t_635)
t_637 = self.n_Conv_121(t_636)
t_638 = torch.add(t_637, t_632)
t_639 = F.relu(t_638)
t_640 = self.n_Conv_122(t_639)
t_641 = F.relu(t_640)
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
t_642 = self.n_Conv_123(t_641_padded)
t_643 = F.relu(t_642)
t_644 = self.n_Conv_124(t_643)
t_645 = torch.add(t_644, t_639)
t_646 = F.relu(t_645)
t_647 = self.n_Conv_125(t_646)
t_648 = F.relu(t_647)
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
t_649 = self.n_Conv_126(t_648_padded)
t_650 = F.relu(t_649)
t_651 = self.n_Conv_127(t_650)
t_652 = torch.add(t_651, t_646)
t_653 = F.relu(t_652)
t_654 = self.n_Conv_128(t_653)
t_655 = F.relu(t_654)
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
t_656 = self.n_Conv_129(t_655_padded)
t_657 = F.relu(t_656)
t_658 = self.n_Conv_130(t_657)
t_659 = torch.add(t_658, t_653)
t_660 = F.relu(t_659)
t_661 = self.n_Conv_131(t_660)
t_662 = F.relu(t_661)
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
t_663 = self.n_Conv_132(t_662_padded)
t_664 = F.relu(t_663)
t_665 = self.n_Conv_133(t_664)
t_666 = torch.add(t_665, t_660)
t_667 = F.relu(t_666)
t_668 = self.n_Conv_134(t_667)
t_669 = F.relu(t_668)
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
t_670 = self.n_Conv_135(t_669_padded)
t_671 = F.relu(t_670)
t_672 = self.n_Conv_136(t_671)
t_673 = torch.add(t_672, t_667)
t_674 = F.relu(t_673)
t_675 = self.n_Conv_137(t_674)
t_676 = F.relu(t_675)
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
t_677 = self.n_Conv_138(t_676_padded)
t_678 = F.relu(t_677)
t_679 = self.n_Conv_139(t_678)
t_680 = torch.add(t_679, t_674)
t_681 = F.relu(t_680)
t_682 = self.n_Conv_140(t_681)
t_683 = F.relu(t_682)
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
t_684 = self.n_Conv_141(t_683_padded)
t_685 = F.relu(t_684)
t_686 = self.n_Conv_142(t_685)
t_687 = torch.add(t_686, t_681)
t_688 = F.relu(t_687)
t_689 = self.n_Conv_143(t_688)
t_690 = F.relu(t_689)
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
t_691 = self.n_Conv_144(t_690_padded)
t_692 = F.relu(t_691)
t_693 = self.n_Conv_145(t_692)
t_694 = torch.add(t_693, t_688)
t_695 = F.relu(t_694)
t_696 = self.n_Conv_146(t_695)
t_697 = F.relu(t_696)
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
t_698 = self.n_Conv_147(t_697_padded)
t_699 = F.relu(t_698)
t_700 = self.n_Conv_148(t_699)
t_701 = torch.add(t_700, t_695)
t_702 = F.relu(t_701)
t_703 = self.n_Conv_149(t_702)
t_704 = F.relu(t_703)
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
t_705 = self.n_Conv_150(t_704_padded)
t_706 = F.relu(t_705)
t_707 = self.n_Conv_151(t_706)
t_708 = torch.add(t_707, t_702)
t_709 = F.relu(t_708)
t_710 = self.n_Conv_152(t_709)
t_711 = F.relu(t_710)
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
t_712 = self.n_Conv_153(t_711_padded)
t_713 = F.relu(t_712)
t_714 = self.n_Conv_154(t_713)
t_715 = torch.add(t_714, t_709)
t_716 = F.relu(t_715)
t_717 = self.n_Conv_155(t_716)
t_718 = F.relu(t_717)
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
t_719 = self.n_Conv_156(t_718_padded)
t_720 = F.relu(t_719)
t_721 = self.n_Conv_157(t_720)
t_722 = torch.add(t_721, t_716)
t_723 = F.relu(t_722)
t_724 = self.n_Conv_158(t_723)
t_725 = self.n_Conv_159(t_723)
t_726 = F.relu(t_725)
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
t_727 = self.n_Conv_160(t_726_padded)
t_728 = F.relu(t_727)
t_729 = self.n_Conv_161(t_728)
t_730 = torch.add(t_729, t_724)
t_731 = F.relu(t_730)
t_732 = self.n_Conv_162(t_731)
t_733 = F.relu(t_732)
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
t_734 = self.n_Conv_163(t_733_padded)
t_735 = F.relu(t_734)
t_736 = self.n_Conv_164(t_735)
t_737 = torch.add(t_736, t_731)
t_738 = F.relu(t_737)
t_739 = self.n_Conv_165(t_738)
t_740 = F.relu(t_739)
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
t_741 = self.n_Conv_166(t_740_padded)
t_742 = F.relu(t_741)
t_743 = self.n_Conv_167(t_742)
t_744 = torch.add(t_743, t_738)
t_745 = F.relu(t_744)
t_746 = self.n_Conv_168(t_745)
t_747 = self.n_Conv_169(t_745)
t_748 = F.relu(t_747)
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
t_749 = self.n_Conv_170(t_748_padded)
t_750 = F.relu(t_749)
t_751 = self.n_Conv_171(t_750)
t_752 = torch.add(t_751, t_746)
t_753 = F.relu(t_752)
t_754 = self.n_Conv_172(t_753)
t_755 = F.relu(t_754)
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
t_756 = self.n_Conv_173(t_755_padded)
t_757 = F.relu(t_756)
t_758 = self.n_Conv_174(t_757)
t_759 = torch.add(t_758, t_753)
t_760 = F.relu(t_759)
t_761 = self.n_Conv_175(t_760)
t_762 = F.relu(t_761)
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
t_763 = self.n_Conv_176(t_762_padded)
t_764 = F.relu(t_763)
t_765 = self.n_Conv_177(t_764)
t_766 = torch.add(t_765, t_760)
t_767 = F.relu(t_766)
t_768 = self.n_Conv_178(t_767)
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
t_770 = torch.squeeze(t_769, 3)
t_770 = torch.squeeze(t_770, 2)
t_771 = torch.sigmoid(t_770)
return t_771
def load_state_dict(self, state_dict, **kwargs):
self.tags = state_dict.get('tags', [])
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
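# resize_image scales the image to cover the target size on one axis, pastes it centered on a new RGB canvas,
# and fills any leftover bands by stretching the outermost rows/columns of the resized image.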
def resize_image(im, width, height):
ratio = width / height
src_ratio = im.width / im.height
src_w = width if ratio < src_ratio else im.width * height // im.height
src_h = height if ratio >= src_ratio else im.height * width // im.width
resized = im.resize((src_w, src_h), resample=LANCZOS)
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
box=(0, fill_height + src_h))
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
box=(fill_width + src_w, 0))
return res
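# DeepDanbooru is a thin wrapper around DeepDanbooruModel: it downloads the pretrained tagger weights
# (model-resnet_custom_v3.pt) from the ModelScope repo ly261666/cv_portrait_model, runs the model in fp16,
# and tag() thresholds per-tag probabilities at 0.5, drops 'rating:' tags, and returns a comma-separated,
# alphabetically sorted tag string.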
class DeepDanbooru:
def __init__(self):
self.model = DeepDanbooruModel()
foundation_model_id = 'ly261666/cv_portrait_model'
snapshot_path = snapshot_download(foundation_model_id, revision='v4.0')
pretrain_model_path = os.path.join(snapshot_path, 'model-resnet_custom_v3.pt')
self.model.load_state_dict(torch.load(pretrain_model_path, map_location="cpu"))
self.model.eval()
self.model.to(torch.float16)
def start(self):
self.model.cuda()
def stop(self):
self.model.cpu()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def tag(self, pil_image):
threshold = 0.5
use_spaces = False
use_escape = True
alpha_sort = True
include_ranks = False
pic = resize_image(pil_image.convert("RGB"), 512, 512)
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with torch.no_grad(), torch.autocast("cuda"):
x = torch.from_numpy(a).cuda()
y = self.model(x)[0].detach().cpu().numpy()
probability_dict = {}
for tag, probability in zip(self.model.tags, y):
if probability < threshold:
continue
if tag.startswith("rating:"):
continue
probability_dict[tag] = probability
if alpha_sort:
tags = sorted(probability_dict)
else:
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
res = []
for tag in [x for x in tags]:
probability = probability_dict[tag]
tag_outformat = tag
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if use_escape:
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
if include_ranks:
tag_outformat = f"({tag_outformat}:{probability:.3f})"
res.append(tag_outformat)
return ", ".join(res)
'''
model = DeepDanbooru()
impath = 'lyf'
imlist = os.listdir(impath)
result_list = []
for im in imlist:
if im[-4:]=='.png':
print(im)
img = Image.open(os.path.join(impath, im))
result = model.tag(img)
print(result)
result_list.append(result)
model.stop()
'''
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import math
import os
import shutil
import cv2
import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from PIL import Image
from tqdm import tqdm
from .deepbooru import DeepDanbooru
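# crop_and_resize cuts a square region around the detected face bbox (the face spans roughly 0.35/1.15 of the
# crop side, centered slightly above the middle), pads with white where the square falls outside the image,
# and resizes the crop to 512x512.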
def crop_and_resize(im, bbox):
h, w, _ = im.shape
thre = 0.35/1.15
maxf = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
cx = (bbox[2] + bbox[0]) / 2
cy = (bbox[3] + bbox[1]) / 2
lenp = int(maxf / thre)
yc = 0.5/1.15
xc = 0.5
xmin = int(cx - xc * lenp)
xmax = xmin + lenp
ymin = int(cy - yc * lenp)
ymax = ymin + lenp
x1 = 0
x2 = lenp
y1 = 0
y2 = lenp
if xmin < 0:
x1 = -xmin
xmin = 0
if xmax > w:
x2 = w - (xmax - lenp)
xmax = w
if ymin < 0:
y1 = -ymin
ymin = 0
if ymax > h:
y2 = h - (ymax - lenp)
ymax = h
imc = (np.ones((lenp, lenp, 3)) * 255).astype(np.uint8)
imc[y1:y2, x1:x2, :] = im[ymin:ymax, xmin:xmax, :]
imr = cv2.resize(imc, (512, 512))
return imr
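# pad_to_square pads the image with a constant white border to a square of side 1.5 * max(h, w),
# keeping the original content centered.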
def pad_to_square(im):
h, w, _ = im.shape
ns = int(max(h, w) * 1.5)
im = cv2.copyMakeBorder(im, int((ns - h) / 2), (ns - h) - int((ns - h) / 2), int((ns - w) / 2),
(ns - w) - int((ns - w) / 2), cv2.BORDER_CONSTANT, 255)
return im
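# post_process_naive builds the caption prefix from the averaged face-attribute scores: it picks the dominant
# gender and age bucket and maps them to a fixed trigger phrase (e.g. 'a handsome man'); the per-image
# DeepDanbooru tags are intentionally dropped for LoRA training (see the commented-out filter below).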
def post_process_naive(result_list, score_gender, score_age):
# determine trigger word
gender = np.argmax(score_gender)
age = np.argmax(score_age)
if age < 2:
if gender == 0:
tag_a_g = ['a boy', 'children']
else:
tag_a_g = ['a girl', 'children']
elif age > 4:
if gender == 0:
tag_a_g = ['a mature man']
else:
tag_a_g = ['a mature woman']
else:
if gender == 0:
tag_a_g = ['a handsome man']
else:
tag_a_g = ['a beautiful woman']
num_images = len(result_list)
cnt_girl = 0
cnt_boy = 0
result_list_new = []
for result in result_list:
result_new = []
result_new.extend(tag_a_g)
## don't include other infos for lora training
#for tag in result:
# if tag == '1girl' or tag == '1boy':
# continue
# if tag[-4:] == '_man':
# continue
# if tag[-6:] == '_woman':
# continue
# if tag[-5:] == '_male':
# continue
# elif tag[-7:] == '_female':
# continue
# elif (
# tag == 'ears' or tag == 'head' or tag == 'face' or tag == 'lips' or tag == 'mouth' or tag == '3d' or tag == 'asian' or tag == 'teeth'):
# continue
# elif ('eye' in tag and not 'eyewear' in tag):
# continue
# elif ('nose' in tag or 'body' in tag):
# continue
# elif tag[-5:] == '_lips':
# continue
# else:
# result_new.append(tag)
# # import pdb;pdb.set_trace()
## result_new.append('slim body')
result_list_new.append(result_new)
return result_list_new
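# transformation_from_points estimates a 2D similarity transform (scale, rotation, translation) that maps
# points1 onto points2 via SVD, in the style of an orthogonal Procrustes alignment, and returns it as a
# 3x3 matrix.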
def transformation_from_points(points1, points2):
points1 = points1.astype(np.float64)
points2 = points2.astype(np.float64)
c1 = np.mean(points1, axis=0)
c2 = np.mean(points2, axis=0)
points1 -= c1
points2 -= c2
s1 = np.std(points1)
s2 = np.std(points2)
if s1 < 1.0e-4:
s1 = 1.0e-4
points1 /= s1
points2 /= s2
U, S, Vt = np.linalg.svd(points1.T * points2)
R = (U * Vt).T
return np.vstack([np.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)), np.matrix([0., 0., 1.])])
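# rotate aligns the face upright: it fits a similarity transform from the 5 detected landmarks to a canonical
# mean-face template, extracts the in-plane rotation angle, pads the image to a square, and rotates it about
# the center by that angle.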
def rotate(im, keypoints):
h, w, _ = im.shape
points_array = np.zeros((5, 2))
dst_mean_face_size = 160
dst_mean_face = np.asarray([0.31074522411511746, 0.2798131190011913,
0.6892073313037804, 0.2797830232679366,
0.49997367716346774, 0.5099309118810921,
0.35811903020866753, 0.7233174007629063,
0.6418878095835022, 0.7232890570786875])
dst_mean_face = np.reshape(dst_mean_face, (5, 2)) * dst_mean_face_size
for k in range(5):
points_array[k, 0] = keypoints[2 * k]
points_array[k, 1] = keypoints[2 * k + 1]
pts1 = np.float64(np.matrix([[point[0], point[1]] for point in points_array]))
pts2 = np.float64(np.matrix([[point[0], point[1]] for point in dst_mean_face]))
trans_mat = transformation_from_points(pts1, pts2)
if trans_mat[1, 1] > 1.0e-4:
angle = math.atan(trans_mat[1, 0] / trans_mat[1, 1])
else:
angle = math.atan(trans_mat[0, 1] / trans_mat[0, 2])
im = pad_to_square(im)
ns = int(1.5 * max(h, w))
M = cv2.getRotationMatrix2D((ns / 2, ns / 2), angle=-angle / np.pi * 180, scale=1.0)
im = cv2.warpAffine(im, M=M, dsize=(ns, ns))
return im
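# get_mask_head combines the 'Hair' and 'Face' masks from the human-parsing result into one head mask:
# it dilates the mask, intersects it with the 'Human' mask, keeps the largest connected contour, and returns
# a float mask in [0, 1] with shape (512, 512, 1).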
def get_mask_head(result):
masks = result['masks']
scores = result['scores']
labels = result['labels']
mask_hair = np.zeros((512, 512))
mask_face = np.zeros((512, 512))
mask_human = np.zeros((512, 512))
for i in range(len(labels)):
if scores[i] > 0.8:
if labels[i] == 'Face':
if np.sum(masks[i]) > np.sum(mask_face):
mask_face = masks[i]
elif labels[i] == 'Human':
if np.sum(masks[i]) > np.sum(mask_human):
mask_human = masks[i]
elif labels[i] == 'Hair':
if np.sum(masks[i]) > np.sum(mask_hair):
mask_hair = masks[i]
mask_head = np.clip(mask_hair + mask_face, 0, 1)
ksize = max(int(np.sqrt(np.sum(mask_face)) / 20), 1)
kernel = np.ones((ksize, ksize))
mask_head = cv2.dilate(mask_head, kernel, iterations=1) * mask_human
_, mask_head = cv2.threshold((mask_head * 255).astype(np.uint8), 127, 255, cv2.THRESH_BINARY)
contours, hierarchy = cv2.findContours(mask_head, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
area = []
for j in range(len(contours)):
area.append(cv2.contourArea(contours[j]))
max_idx = np.argmax(area)
mask_head = np.zeros((512, 512)).astype(np.uint8)
cv2.fillPoly(mask_head, [contours[max_idx]], 255)
mask_head = mask_head.astype(np.float32) / 255
mask_head = np.clip(mask_head + mask_face, 0, 1)
mask_head = np.expand_dims(mask_head, 2)
return mask_head
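# Blipv2 is the data-labeling pipeline: for every raw portrait it detects and aligns the largest face, crops
# it to 512x512, applies skin retouching, masks out everything but the head, filters out images with
# low-confidence landmarks, tags the result with DeepDanbooru, and finally writes
# <input_dir>_labeled/metadata.jsonl for LoRA training.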
class Blipv2():
def __init__(self):
self.model = DeepDanbooru()
self.skin_retouching = pipeline('skin-retouching-torch', model='damo/cv_unet_skin_retouching_torch', model_revision='v1.0.1')
# ToDo: face detection
self.face_detection = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd', model_revision='v1.1')
# self.mog_face_detection_func = pipeline(Tasks.face_detection, 'damo/cv_resnet101_face-detection_cvpr22papermogface')
self.segmentation_pipeline = pipeline(Tasks.image_segmentation,
'damo/cv_resnet101_image-multiple-human-parsing', model_revision='v1.0.1')
self.fair_face_attribute_func = pipeline(Tasks.face_attribute_recognition,
'damo/cv_resnet34_face-attribute-recognition_fairface', model_revision='v2.0.2')
self.facial_landmark_confidence_func = pipeline(Tasks.face_2d_keypoints,
'damo/cv_manual_facial-landmark-confidence_flcm', model_revision='v2.5')
def __call__(self, imdir):
self.model.start()
savedir = str(imdir) + '_labeled'
shutil.rmtree(savedir, ignore_errors=True)
os.makedirs(savedir, exist_ok=True)
imlist = os.listdir(imdir)
result_list = []
imgs_list = []
cnt = 0
tmp_path = os.path.join(savedir, 'tmp.png')
for imname in imlist:
try:
# if 1:
if imname.startswith('.'):
continue
img_path = os.path.join(imdir, imname)
im = cv2.imread(img_path)
h, w, _ = im.shape
max_size = max(w, h)
ratio = 1024 / max_size
new_w = round(w * ratio)
new_h = round(h * ratio)
imt = cv2.resize(im, (new_w, new_h))
cv2.imwrite(tmp_path, imt)
result_det = self.face_detection(tmp_path)
bboxes = result_det['boxes']
if len(bboxes) > 1:
areas = []
for i in range(len(bboxes)):
bbox = bboxes[i]
areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))
areas = np.array(areas)
areas_new = np.sort(areas)[::-1]
idxs = np.argsort(areas)[::-1]
if areas_new[0] < 4 * areas_new[1]:
                        print('Multiple faces of comparable size detected, skipping image {}.'.format(imname))
continue
else:
keypoints = result_det['keypoints'][idxs[0]]
elif len(bboxes) == 0:
                    print('No face detected, skipping image {}.'.format(imname))
continue
else:
keypoints = result_det['keypoints'][0]
im = rotate(im, keypoints)
ns = im.shape[0]
imt = cv2.resize(im, (1024, 1024))
cv2.imwrite(tmp_path, imt)
result_det = self.face_detection(tmp_path)
bboxes = result_det['boxes']
if len(bboxes) > 1:
areas = []
for i in range(len(bboxes)):
bbox = bboxes[i]
areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))
areas = np.array(areas)
areas_new = np.sort(areas)[::-1]
idxs = np.argsort(areas)[::-1]
if areas_new[0] < 4 * areas_new[1]:
                        print('Multiple faces of comparable size detected after rotation, skipping image {}.'.format(imname))
continue
else:
bbox = bboxes[idxs[0]]
elif len(bboxes) == 0:
                    print('No face detected after rotation, skipping image {}.'.format(imname))
continue
else:
bbox = bboxes[0]
for idx in range(4):
bbox[idx] = bbox[idx] * ns / 1024
imr = crop_and_resize(im, bbox)
cv2.imwrite(tmp_path, imr)
result = self.skin_retouching(tmp_path)
if (result is None or (result[OutputKeys.OUTPUT_IMG] is None)):
                    print('Skin retouching failed, skipping this image.')
continue
cv2.imwrite(tmp_path, result[OutputKeys.OUTPUT_IMG])
result = self.segmentation_pipeline(tmp_path)
mask_head = get_mask_head(result)
im = cv2.imread(tmp_path)
im = im * mask_head + 255 * (1 - mask_head)
# print(im.shape)
raw_result = self.facial_landmark_confidence_func(im)
if raw_result is None:
print('landmark quality fail...')
continue
print(imname, raw_result['scores'][0])
if float(raw_result['scores'][0]) < (1 - 0.145):
print('landmark quality fail...')
continue
cv2.imwrite(os.path.join(savedir, '{}.png'.format(cnt)), im)
imgs_list.append('{}.png'.format(cnt))
img = Image.open(os.path.join(savedir, '{}.png'.format(cnt)))
result = self.model.tag(img)
print(result)
attribute_result = self.fair_face_attribute_func(tmp_path)
if cnt == 0:
score_gender = np.array(attribute_result['scores'][0])
score_age = np.array(attribute_result['scores'][1])
else:
score_gender += np.array(attribute_result['scores'][0])
score_age += np.array(attribute_result['scores'][1])
result_list.append(result.split(', '))
cnt += 1
except Exception as e:
print('Caught exception while processing image ' + imname)
print(f'Error: {e}')
print(result_list)
if len(result_list) == 0:
print('Error: result is empty.')
exit()
# return os.path.join(savedir, "metadata.jsonl")
result_list = post_process_naive(result_list, score_gender, score_age)
self.model.stop()
try:
os.remove(tmp_path)
except OSError as e:
print(f"Failed to remove path {tmp_path}: {e}")
out_json_name = os.path.join(savedir, "metadata.jsonl")
fo = open(out_json_name, 'w')
for i in range(len(result_list)):
generated_text = ", ".join(result_list[i])
print(imgs_list[i], generated_text)
info_dict = {"file_name": imgs_list[i], "text": "<fcsks>, " + generated_text}
fo.write(json.dumps(info_dict) + '\n')
fo.close()
return out_json_name
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import cv2
import numpy as np
import torch
from PIL import Image
from controlnet_aux import OpenposeDetector
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, \
UniPCMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, StableDiffusionXLPipeline
from facechain.utils import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from torch import multiprocessing
from transformers import pipeline as tpipeline
from facechain.data_process.preprocessing import Blipv2
from facechain.merge_lora import merge_lora
from safetensors.torch import load_file, save_file
def _data_process_fn_process(input_img_dir):
Blipv2()(input_img_dir)
def data_process_fn(input_img_dir, use_data_process):
## TODO add face quality filter
if use_data_process:
## TODO
_process = multiprocessing.Process(target=_data_process_fn_process, args=(input_img_dir,))
_process.start()
_process.join()
return os.path.join(str(input_img_dir) + '_labeled', "metadata.jsonl")
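# Batched text-to-image helper: runs the diffusers pipeline in batches of 5 until roughly
# num_images outputs are produced (a remainder smaller than the batch size is dropped).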
def txt2img(pipe, pos_prompt, neg_prompt, num_images=10, height=512, width=512, num_inference_steps=40, guidance_scale=7):
batch_size = 5
images_out = []
for i in range(int(num_images / batch_size)):
images_style = pipe(prompt=pos_prompt, height=height, width=width, guidance_scale=guidance_scale, negative_prompt=neg_prompt,
num_inference_steps=num_inference_steps, num_images_per_prompt=batch_size).images
images_out.extend(images_style)
return images_out
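# Letterbox a PIL image onto a fixed_width x fixed_height canvas: scale down to fit, then
# zero-pad left/right (or on top) to reach the target size.
# Example: a 1024x768 input is scaled to 512x384 and padded with 128 black rows on top.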
def img_pad(pil_file, fixed_height=512, fixed_width=512):
w, h = pil_file.size
if h / float(fixed_height) >= w / float(fixed_width):
factor = h / float(fixed_height)
new_w = int(w / factor)
pil_file.thumbnail(size=(new_w, fixed_height))
pad_w = int((fixed_width - new_w) / 2)
pad_w1 = (fixed_width - new_w) - pad_w
array_file = np.array(pil_file)
array_file = np.pad(array_file, ((0, 0), (pad_w, pad_w1), (0, 0)), 'constant')
else:
factor = w / float(fixed_width)
new_h = int(h / factor)
pil_file.thumbnail(size=(fixed_width, new_h))
pad_h = fixed_height - new_h
pad_h1 = 0
array_file = np.array(pil_file)
array_file = np.pad(array_file, ((pad_h, pad_h1), (0, 0), (0, 0)), 'constant')
output_file = Image.fromarray(array_file)
return output_file
def preprocess_pose(origin_img) -> Image:
img = Image.open(origin_img)
img = img_pad(img)
model_dir = snapshot_download('damo/face_chain_control_model',revision='v1.0.1')
openpose = OpenposeDetector.from_pretrained(os.path.join(model_dir, 'model_controlnet/ControlNet'))
result = openpose(img, include_hand=True, output_type='np')
# resize the pose map back to the padded image size
w, h = img.size
result = cv2.resize(result, (w, h))
return result
def txt2img_pose(pipe, pose_im, pos_prompt, neg_prompt, num_images=10, height=512, width=512):
batch_size = 2
images_out = []
for i in range(int(num_images / batch_size)):
images_style = pipe(prompt=pos_prompt, image=pose_im, height=height, width=width, guidance_scale=7, negative_prompt=neg_prompt,
num_inference_steps=40, num_images_per_prompt=batch_size).images
images_out.extend(images_style)
return images_out
def txt2img_multi(pipe, images, pos_prompt, neg_prompt, num_images=10, height=512, width=512):
batch_size = 2
images_out = []
for i in range(int(num_images / batch_size)):
images_style = pipe(pos_prompt, images, height=height, width=width, guidance_scale=7, negative_prompt=neg_prompt, controlnet_conditioning_scale=[1.0, 0.5],
num_inference_steps=40, num_images_per_prompt=batch_size).images
images_out.extend(images_style)
return images_out
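# From a human-parsing result, build a 3-channel binary mask of the body region excluding
# face and hair (largest instance per label with score > 0.8); used below to mask the
# pose/depth reference image.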
def get_mask(result):
masks = result['masks']
scores = result['scores']
labels = result['labels']
h, w = masks[0].shape
mask_hair = np.zeros((h, w))
mask_face = np.zeros((h, w))
mask_human = np.zeros((h, w))
for i in range(len(labels)):
if scores[i] > 0.8:
if labels[i] == 'Face':
if np.sum(masks[i]) > np.sum(mask_face):
mask_face = masks[i]
elif labels[i] == 'Human':
if np.sum(masks[i]) > np.sum(mask_human):
mask_human = masks[i]
elif labels[i] == 'Hair':
if np.sum(masks[i]) > np.sum(mask_hair):
mask_hair = masks[i]
mask_rst = np.clip(mask_human - mask_hair - mask_face, 0, 1)
mask_rst = np.expand_dims(mask_rst, 2)
mask_rst = np.concatenate([mask_rst, mask_rst, mask_rst], axis=2)
return mask_rst
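# Core text-to-image inference without pose control. Flow: fall back to a default style LoRA
# when none is given; build an SD or SDXL pipeline (optionally with LCM-LoRA for 8-step
# sampling); merge the style and human LoRAs with their multipliers; derive the trigger
# prompt from metadata.jsonl (majority gender/age class plus frequent hair/face/mouth/skin/
# smile tags); then generate 10 images.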
def main_diffusion_inference(pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
use_lcm=False,
multiplier_style=0.25,
multiplier_human=0.85):
if style_model_path is None:
model_dir = snapshot_download('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', revision='v1.0.0')
style_model_path = os.path.join(model_dir, 'zjz_mj_jiyi_small_addtxt_fromleo.safetensors')
lora_style_path = style_model_path
lora_human_path = lora_model_path
print ('lora_human_path: ', lora_human_path)
if 'xl-base' in base_model_path:
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_path, safety_checker=None, torch_dtype=torch.float16)
if use_lcm:
try:
from diffusers import LCMScheduler
except ImportError:
raise ImportError('diffusers version is too old, please update diffusers to >=0.22')
lcm_model_path = snapshot_download('AI-ModelScope/lcm-lora-sdxl')
pipe.load_lora_weights(lcm_model_path)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
num_inference_steps = 8
guidance_scale = 2
else:
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)
num_inference_steps = 40
guidance_scale = 7
print('base_model_path', base_model_path)
print('lora_human_path', lora_human_path)
print('lora_style_path', lora_style_path)
if not os.path.isfile(lora_human_path):
lora_human_path = os.path.join(lora_human_path, 'pytorch_lora_weights.bin')
lora_human_state_dict = torch.load(lora_human_path, map_location='cpu')
if lora_style_path.endswith('safetensors'):
lora_style_state_dict = load_file(lora_style_path)
else:
lora_style_state_dict = torch.load(lora_style_path, map_location='cpu')
weighted_lora_human_state_dict = {}
for key in lora_human_state_dict:
weighted_lora_human_state_dict[key] = lora_human_state_dict[key] * multiplier_human
weighted_lora_style_state_dict = {}
for key in lora_style_state_dict:
weighted_lora_style_state_dict[key] = lora_style_state_dict[key] * multiplier_style
print('start lora merging')
pipe.load_lora_weights(weighted_lora_style_state_dict)
print('merge style lora done')
pipe.load_lora_weights(weighted_lora_human_state_dict)
print('lora merging done')
else:
pipe = StableDiffusionPipeline.from_pretrained(base_model_path, safety_checker=None, torch_dtype=torch.float32)
if use_lcm:
try:
from diffusers import LCMScheduler
except ImportError:
raise ImportError('diffusers version is too old, please update diffusers to >=0.22')
lcm_model_path = snapshot_download('eavesy/lcm-lora-sdv1-5')
pipe.load_lora_weights(lcm_model_path)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
num_inference_steps = 8
guidance_scale = 2
else:
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)
num_inference_steps = 40
guidance_scale = 7
pipe = merge_lora(pipe, lora_style_path, multiplier_style, from_safetensor=True, device='cuda')
pipe = merge_lora(pipe, lora_human_path, multiplier_human, from_safetensor=lora_human_path.endswith('safetensors'), device='cuda')
print(pipe.scheduler)
#pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
print(f'multiplier_style:{multiplier_style}, multiplier_human:{multiplier_human}')
train_dir = str(input_img_dir) + '_labeled'
add_prompt_style = []
f = open(os.path.join(train_dir, 'metadata.jsonl'), 'r')
tags_all = []
cnt = 0
cnts_trigger = np.zeros(6)
for line in f:
cnt += 1
data = json.loads(line)['text'].split(', ')
tags_all.extend(data)
if data[1] == 'a boy':
cnts_trigger[0] += 1
elif data[1] == 'a girl':
cnts_trigger[1] += 1
elif data[1] == 'a handsome man':
cnts_trigger[2] += 1
elif data[1] == 'a beautiful woman':
cnts_trigger[3] += 1
elif data[1] == 'a mature man':
cnts_trigger[4] += 1
elif data[1] == 'a mature woman':
cnts_trigger[5] += 1
else:
print('Error.')
f.close()
attr_idx = np.argmax(cnts_trigger)
trigger_styles = ['a boy, children, ', 'a girl, children, ', 'a handsome man, ', 'a beautiful woman, ',
'a mature man, ', 'a mature woman, ']
trigger_style = '(<fcsks>:10), ' + trigger_styles[attr_idx]
if attr_idx == 2 or attr_idx == 4:
neg_prompt += ', children'
for tag in tags_all:
if tags_all.count(tag) > 0.5 * cnt:
if ('hair' in tag or 'face' in tag or 'mouth' in tag or 'skin' in tag or 'smile' in tag):
if not tag in add_prompt_style:
add_prompt_style.append(tag)
if len(add_prompt_style) > 0:
add_prompt_style = ", ".join(add_prompt_style) + ', '
else:
add_prompt_style = ''
pipe = pipe.to("cuda")
if 'xl-base' in base_model_path:
images_style = txt2img(pipe, trigger_style + add_prompt_style + pos_prompt, neg_prompt, num_images=10, height=768, width=768, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)
else:
images_style = txt2img(pipe, trigger_style + add_prompt_style + pos_prompt, neg_prompt, num_images=10, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)
return images_style
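# Variant of the inference above with an OpenPose ControlNet: the user-supplied pose image
# is padded and converted to an OpenPose map that conditions generation.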
def main_diffusion_inference_pose(pose_model_path, pose_image,
pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
use_lcm=False,
multiplier_style=0.25,
multiplier_human=0.85):
if style_model_path is None:
model_dir = snapshot_download('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', revision='v1.0.0')
style_model_path = os.path.join(model_dir, 'zjz_mj_jiyi_small_addtxt_fromleo.safetensors')
controlnet = ControlNetModel.from_pretrained(pose_model_path, torch_dtype=torch.float32)
pipe = StableDiffusionControlNetPipeline.from_pretrained(base_model_path, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float32)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pose_im = Image.open(pose_image)
pose_im = img_pad(pose_im)
model_dir = snapshot_download('damo/face_chain_control_model',revision='v1.0.1')
openpose = OpenposeDetector.from_pretrained(os.path.join(model_dir, 'model_controlnet/ControlNet'))
pose_im = openpose(pose_im, include_hand=True)
lora_style_path = style_model_path
lora_human_path = lora_model_path
pipe = merge_lora(pipe, lora_style_path, multiplier_style, from_safetensor=True)
pipe = merge_lora(pipe, lora_human_path, multiplier_human, from_safetensor=False)
print(f'multiplier_style:{multiplier_style}, multiplier_human:{multiplier_human}')
train_dir = str(input_img_dir) + '_labeled'
add_prompt_style = []
f = open(os.path.join(train_dir, 'metadata.jsonl'), 'r')
tags_all = []
cnt = 0
cnts_trigger = np.zeros(6)
for line in f:
cnt += 1
data = json.loads(line)['text'].split(', ')
tags_all.extend(data)
if data[1] == 'a boy':
cnts_trigger[0] += 1
elif data[1] == 'a girl':
cnts_trigger[1] += 1
elif data[1] == 'a handsome man':
cnts_trigger[2] += 1
elif data[1] == 'a beautiful woman':
cnts_trigger[3] += 1
elif data[1] == 'a mature man':
cnts_trigger[4] += 1
elif data[1] == 'a mature woman':
cnts_trigger[5] += 1
else:
print('Error.')
f.close()
attr_idx = np.argmax(cnts_trigger)
trigger_styles = ['a boy, children, ', 'a girl, children, ', 'a handsome man, ', 'a beautiful woman, ',
'a mature man, ', 'a mature woman, ']
trigger_style = '(<fcsks>:10), ' + trigger_styles[attr_idx]
if attr_idx == 2 or attr_idx == 4:
neg_prompt += ', children'
for tag in tags_all:
if tags_all.count(tag) > 0.5 * cnt:
if ('hair' in tag or 'face' in tag or 'mouth' in tag or 'skin' in tag or 'smile' in tag):
if not tag in add_prompt_style:
add_prompt_style.append(tag)
if len(add_prompt_style) > 0:
add_prompt_style = ", ".join(add_prompt_style) + ', '
else:
add_prompt_style = ''
# trigger_style = trigger_style + 'with <input_id> face, '
# pos_prompt = 'Generate a standard ID photo of a chinese {}, solo, wearing high-class business/working suit, beautiful smooth face, with high-class/simple pure color background, looking straight into the camera with shoulders parallel to the frame, smile, high detail face, best quality, photorealistic'.format(gender)
pipe = pipe.to("cuda")
# print(trigger_style + add_prompt_style + pos_prompt)
images_style = txt2img_pose(pipe, pose_im, trigger_style + add_prompt_style + pos_prompt, neg_prompt, num_images=10)
return images_style
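# Variant with two ControlNets (OpenPose + depth): the pose reference is segmented so that
# depth is estimated only on the body region (face and hair removed), and both control
# images are fed to the pipeline with conditioning scales [1.0, 0.5].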
def main_diffusion_inference_multi(pose_model_path, pose_image,
pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
use_lcm=False,
multiplier_style=0.25,
multiplier_human=0.85):
if style_model_path is None:
model_dir = snapshot_download('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', revision='v1.0.0')
style_model_path = os.path.join(model_dir, 'zjz_mj_jiyi_small_addtxt_fromleo.safetensors')
model_dir = snapshot_download('damo/face_chain_control_model', revision='v1.0.1')
controlnet = [
ControlNetModel.from_pretrained(pose_model_path, torch_dtype=torch.float32),
ControlNetModel.from_pretrained(os.path.join(model_dir, 'model_controlnet/control_v11p_sd15_depth'), torch_dtype=torch.float32)
]
pipe = StableDiffusionControlNetPipeline.from_pretrained(base_model_path, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float32)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pose_image = Image.open(pose_image)
pose_image = img_pad(pose_image)
openpose = OpenposeDetector.from_pretrained(os.path.join(model_dir, 'model_controlnet/ControlNet'))
pose_im = openpose(pose_image, include_hand=True)
segmentation_pipeline = pipeline(Tasks.image_segmentation,
'damo/cv_resnet101_image-multiple-human-parsing')
result = segmentation_pipeline(pose_image)
mask_rst = get_mask(result)
pose_image = np.array(pose_image)
pose_image = (pose_image * mask_rst).astype(np.uint8)
pose_image = Image.fromarray(pose_image)
depth_estimator = tpipeline('depth-estimation', os.path.join(model_dir, 'model_controlnet/dpt-large'))
depth_im = depth_estimator(pose_image)['depth']
depth_im = np.array(depth_im)
depth_im = depth_im[:, :, None]
depth_im = np.concatenate([depth_im, depth_im, depth_im], axis=2)
depth_im = Image.fromarray(depth_im)
control_im = [pose_im, depth_im]
lora_style_path = style_model_path
lora_human_path = lora_model_path
pipe = merge_lora(pipe, lora_style_path, multiplier_style, from_safetensor=True)
pipe = merge_lora(pipe, lora_human_path, multiplier_human, from_safetensor=False)
print(f'multiplier_style:{multiplier_style}, multiplier_human:{multiplier_human}')
train_dir = str(input_img_dir) + '_labeled'
add_prompt_style = []
f = open(os.path.join(train_dir, 'metadata.jsonl'), 'r')
tags_all = []
cnt = 0
cnts_trigger = np.zeros(6)
for line in f:
cnt += 1
data = json.loads(line)['text'].split(', ')
tags_all.extend(data)
if data[1] == 'a boy':
cnts_trigger[0] += 1
elif data[1] == 'a girl':
cnts_trigger[1] += 1
elif data[1] == 'a handsome man':
cnts_trigger[2] += 1
elif data[1] == 'a beautiful woman':
cnts_trigger[3] += 1
elif data[1] == 'a mature man':
cnts_trigger[4] += 1
elif data[1] == 'a mature woman':
cnts_trigger[5] += 1
else:
print('Error.')
f.close()
attr_idx = np.argmax(cnts_trigger)
trigger_styles = ['a boy, children, ', 'a girl, children, ', 'a handsome man, ', 'a beautiful woman, ',
'a mature man, ', 'a mature woman, ']
trigger_style = '(<fcsks>:10), ' + trigger_styles[attr_idx]
if attr_idx == 2 or attr_idx == 4:
neg_prompt += ', children'
for tag in tags_all:
if tags_all.count(tag) > 0.5 * cnt:
if ('hair' in tag or 'face' in tag or 'mouth' in tag or 'skin' in tag or 'smile' in tag):
if not tag in add_prompt_style:
add_prompt_style.append(tag)
if len(add_prompt_style) > 0:
add_prompt_style = ", ".join(add_prompt_style) + ', '
else:
add_prompt_style = ''
# trigger_style = trigger_style + 'with <input_id> face, '
# pos_prompt = 'Generate a standard ID photo of a chinese {}, solo, wearing high-class business/working suit, beautiful smooth face, with high-class/simple pure color background, looking straight into the camera with shoulders parallel to the frame, smile, high detail face, best quality, photorealistic'.format(gender)
pipe = pipe.to("cuda")
# print(trigger_style + add_prompt_style + pos_prompt)
images_style = txt2img_multi(pipe, control_im, trigger_style + add_prompt_style + pos_prompt, neg_prompt, num_images=10)
return images_style
def stylization_fn(use_stylization, rank_results):
if use_stylization:
## TODO
pass
else:
return rank_results
def main_model_inference(pose_model_path, pose_image, use_depth_control, pos_prompt, neg_prompt, style_model_path, multiplier_style, multiplier_human, use_main_model,
input_img_dir=None, base_model_path=None, lora_model_path=None,
use_lcm=False):
if use_main_model:
multiplier_style_kwargs = {'multiplier_style': multiplier_style} if multiplier_style is not None else {}
multiplier_human_kwargs = {'multiplier_human': multiplier_human} if multiplier_human is not None else {}
if pose_image is None:
return main_diffusion_inference(pos_prompt, neg_prompt, input_img_dir, base_model_path,
style_model_path, lora_model_path, use_lcm,
**multiplier_style_kwargs, **multiplier_human_kwargs)
else:
pose_image = compress_image(pose_image, 1024 * 1024)
if use_depth_control:
return main_diffusion_inference_multi(pose_model_path, pose_image, pos_prompt,
neg_prompt, input_img_dir, base_model_path, style_model_path,
lora_model_path, use_lcm,
**multiplier_style_kwargs, **multiplier_human_kwargs)
else:
return main_diffusion_inference_pose(pose_model_path, pose_image, pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
use_lcm,
**multiplier_style_kwargs, **multiplier_human_kwargs)
def select_high_quality_face(input_img_dir):
input_img_dir = str(input_img_dir) + '_labeled'
quality_score_list = []
abs_img_path_list = []
## TODO
face_quality_func = pipeline(Tasks.face_quality_assessment, 'damo/cv_manual_face-quality-assessment_fqa', model_revision='v2.0')
for img_name in os.listdir(input_img_dir):
if img_name.endswith('jsonl') or img_name.startswith('.ipynb') or img_name.endswith('.safetensors'):
continue
if img_name.endswith('jpg') or img_name.endswith('png'):
abs_img_name = os.path.join(input_img_dir, img_name)
face_quality_score = face_quality_func(abs_img_name)[OutputKeys.SCORES]
if face_quality_score is None:
quality_score_list.append(0)
else:
quality_score_list.append(face_quality_score[0])
abs_img_path_list.append(abs_img_name)
sort_idx = np.argsort(quality_score_list)[::-1]
print('Selected face: ' + abs_img_path_list[sort_idx[0]])
return Image.open(abs_img_path_list[sort_idx[0]])
def face_swap_fn(use_face_swap, gen_results, template_face):
if use_face_swap:
## TODO
out_img_list = []
image_face_fusion = pipeline('face_fusion_torch',
model='damo/cv_unet_face_fusion_torch', model_revision='v1.0.3')
for img in gen_results:
result = image_face_fusion(dict(template=img, user=template_face))[OutputKeys.OUTPUT_IMG]
out_img_list.append(result)
return out_img_list
else:
ret_results = []
for img in gen_results:
ret_results.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
return ret_results
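# Rank the face-swapped results: keep images with exactly one detected face of reasonable
# size (120-300 px), compute face-embedding similarity to the selected reference face,
# and return the top num_gen_images most similar images.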
def post_process_fn(use_post_process, swap_results_ori, selected_face, num_gen_images):
if use_post_process:
sim_list = []
## TODO
face_recognition_func = pipeline(Tasks.face_recognition, 'damo/cv_ir_face-recognition-ood_rts', model_revision='v2.5')
face_det_func = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd', model_revision='v1.1')
swap_results = []
for img in swap_results_ori:
result_det = face_det_func(img)
bboxes = result_det['boxes']
if len(bboxes) == 1:
bbox = bboxes[0]
lenface = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
if 120 < lenface < 300:
swap_results.append(img)
select_face_emb = face_recognition_func(selected_face)[OutputKeys.IMG_EMBEDDING][0]
for img in swap_results:
emb = face_recognition_func(img)[OutputKeys.IMG_EMBEDDING]
if emb is None or select_face_emb is None:
sim_list.append(0)
else:
sim = np.dot(emb, select_face_emb)
sim_list.append(sim.item())
sort_idx = np.argsort(sim_list)[::-1]
return np.array(swap_results)[sort_idx[:min(int(num_gen_images), len(swap_results))]]
else:
return np.array(swap_results_ori)
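# End-to-end portrait generation wrapper: main diffusion inference (optionally pose/depth
# controlled), reference-face selection from the labeled training images, face fusion,
# similarity-based ranking, and optional cartoon stylization or super-resolution.
# Sketch of assumed usage (argument values are placeholders):
#   gen = GenPortrait(pose_model_path, pose_image, use_depth_control, pos_prompt, neg_prompt,
#                     style_model_path, multiplier_style, multiplier_human)
#   images = gen(input_img_dir, num_gen_images=6, base_model_path='<model_id>',
#                lora_model_path='<lora_dir>', sub_path='<sub_path>', revision='<revision>')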
class GenPortrait:
def __init__(self, pose_model_path, pose_image, use_depth_control, pos_prompt, neg_prompt, style_model_path, multiplier_style, multiplier_human,
use_main_model=True, use_face_swap=True,
use_post_process=True, use_stylization=True):
self.use_main_model = use_main_model
self.use_face_swap = use_face_swap
self.use_post_process = use_post_process
self.use_stylization = use_stylization
self.multiplier_style = multiplier_style
self.multiplier_human = multiplier_human
self.style_model_path = style_model_path
self.pos_prompt = pos_prompt
self.neg_prompt = neg_prompt
self.pose_model_path = pose_model_path
self.pose_image = pose_image
self.use_depth_control = use_depth_control
def __call__(self, input_img_dir, num_gen_images=6, base_model_path=None,
lora_model_path=None, sub_path=None, revision=None, sr_img_size=None, portrait_stylization_idx=None, use_lcm_idx=None):
base_model_path = snapshot_download(base_model_path, revision=revision)
if sub_path is not None and len(sub_path) > 0:
base_model_path = os.path.join(base_model_path, sub_path)
use_lcm = False
if (use_lcm_idx is not None) and (int(use_lcm_idx) == 1):
use_lcm = True
# main_model_inference PIL
gen_results = main_model_inference(self.pose_model_path, self.pose_image, self.use_depth_control,
self.pos_prompt, self.neg_prompt,
self.style_model_path, self.multiplier_style, self.multiplier_human,
self.use_main_model, input_img_dir=input_img_dir,
lora_model_path=lora_model_path, base_model_path=base_model_path,
use_lcm=use_lcm)
# select_high_quality_face PIL
selected_face = select_high_quality_face(input_img_dir)
# face_swap cv2
swap_results = face_swap_fn(self.use_face_swap, gen_results, selected_face)
# post_process
rank_results = post_process_fn(self.use_post_process, swap_results, selected_face,
num_gen_images=num_gen_images)
# stylization
final_gen_results = stylization_fn(self.use_stylization, rank_results)
sr_pipe = pipeline(Tasks.image_super_resolution, model='damo/cv_rrdb_image-super-resolution')
if portrait_stylization_idx is not None:
out_results = []
if int(portrait_stylization_idx) == 0:
img_cartoon = pipeline(Tasks.image_portrait_stylization, model='damo/cv_unet_person-image-cartoon_compound-models')
if int(portrait_stylization_idx) == 1:
img_cartoon = pipeline(Tasks.image_portrait_stylization, model='damo/cv_unet_person-image-cartoon-3d_compound-models')
for i in range(len(final_gen_results)):
img = final_gen_results[i]
img = Image.fromarray(img[:,:,::-1])
result = img_cartoon(img)[OutputKeys.OUTPUT_IMG]
cv2.imwrite('tmp.png', result)
result_img = cv2.imread('tmp.png')
out_results.append(result_img)
os.remove('tmp.png')
final_gen_results = out_results
if portrait_stylization_idx is None:
if 'xl-base' in base_model_path:
if int(sr_img_size) != 1:
out_results = []
for i in range(len(final_gen_results)):
img = final_gen_results[i]
if int(sr_img_size) == 0:
out_img = cv2.resize(
img, (512, 512),
interpolation=cv2.INTER_AREA)
else:
img = Image.fromarray(img[:,:,::-1])
out_img = sr_pipe(img)['output_img']
if int(sr_img_size) == 2:
new_h = 1024
new_w = 1024
else:
new_h = 2048
new_w = 2048
out_img = cv2.resize(
out_img, (new_w, new_h),
interpolation=cv2.INTER_AREA)
out_results.append(out_img)
final_gen_results = out_results
else:
if int(sr_img_size) != 0:
out_results = []
for i in range(len(final_gen_results)):
img = final_gen_results[i]
img = Image.fromarray(img[:,:,::-1])
out_img = sr_pipe(img)['output_img']
ratio = 1
if int(sr_img_size) == 1:
ratio = 0.375
elif int(sr_img_size) == 2:
ratio = 0.5
if ratio < 1:
out_img = cv2.resize(
out_img, (0, 0),
fx=ratio,
fy=ratio,
interpolation=cv2.INTER_AREA)
out_results.append(out_img)
final_gen_results = out_results
return final_gen_results
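# Re-encode an image as JPEG, lowering the quality in steps of 5 until the encoded size
# fits under target_size bytes; writes the result next to the input with a .jpg extension.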
def compress_image(input_path, target_size):
output_path = change_extension_to_jpg(input_path)
image = cv2.imread(input_path)
quality = 95
try:
# Reduce JPEG quality stepwise until the encoded size fits the target, with a sane lower bound.
while quality > 5 and cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, quality])[1].size > target_size:
quality -= 5
except Exception as e:
print(f'Failed to compress image {input_path}: {e}')
compressed_image = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
with open(output_path, 'wb') as f:
f.write(compressed_image)
return output_path
def change_extension_to_jpg(image_path):
base_name = os.path.basename(image_path)
new_base_name = os.path.splitext(base_name)[0] + ".jpg"
directory = os.path.dirname(image_path)
new_image_path = os.path.join(directory, new_base_name)
return new_image_path
# Copyright (c) Alibaba, Inc. and its affiliates.
# Modified from the original implementation at https://github.com/modelscope/facechain/pull/104.
import json
import os
import sys
import cv2
import numpy as np
import torch
from PIL import Image
from skimage import transform
from controlnet_aux import OpenposeDetector
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, \
StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
from facechain.utils import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from torch import multiprocessing
from transformers import pipeline as tpipeline
from facechain.data_process.preprocessing import Blipv2
from facechain.merge_lora import merge_lora
def _data_process_fn_process(input_img_dir):
Blipv2()(input_img_dir)
def concatenate_images(images):
heights = [img.shape[0] for img in images]
max_width = sum([img.shape[1] for img in images])
concatenated_image = np.zeros((max(heights), max_width, 3), dtype=np.uint8)
x_offset = 0
for img in images:
concatenated_image[0:img.shape[0], x_offset:x_offset + img.shape[1], :] = img
x_offset += img.shape[1]
return concatenated_image
def data_process_fn(input_img_dir, use_data_process):
## TODO add face quality filter
if use_data_process:
## TODO
_process = multiprocessing.Process(target=_data_process_fn_process, args=(input_img_dir,))
_process.start()
_process.join()
return os.path.join(str(input_img_dir) + '_labeled', "metadata.jsonl")
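# Detect faces, pick the largest one, expand its box by crop_ratio around the face,
# and return the clipped integer box together with a 5x2 array of facial keypoints.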
def call_face_crop(det_pipeline, image, crop_ratio):
det_result = det_pipeline(image)
bboxes = det_result['boxes']
keypoints = det_result['keypoints']
area = 0
idx = 0
for i in range(len(bboxes)):
bbox = bboxes[i]
area_tmp = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
if area_tmp > area:
area = area_tmp
idx = i
bbox = bboxes[idx]
keypoint = keypoints[idx]
points_array = np.zeros((5, 2))
for k in range(5):
points_array[k, 0] = keypoint[2 * k]
points_array[k, 1] = keypoint[2 * k + 1]
w, h = image.size
face_w = bbox[2] - bbox[0]
face_h = bbox[3] - bbox[1]
bbox[0] = np.clip(np.array(bbox[0], np.int32) - face_w * (crop_ratio - 1) / 2, 0, w - 1)
bbox[1] = np.clip(np.array(bbox[1], np.int32) - face_h * (crop_ratio - 1) / 2, 0, h - 1)
bbox[2] = np.clip(np.array(bbox[2], np.int32) + face_w * (crop_ratio - 1) / 2, 0, w - 1)
bbox[3] = np.clip(np.array(bbox[3], np.int32) + face_h * (crop_ratio - 1) / 2, 0, h - 1)
bbox = np.array(bbox, np.int32)
return bbox, points_array
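# Paste the source face onto the target image: estimate a similarity transform between the
# two 5-point landmark sets, warp the cropped source face (and its mask) into the target
# frame, and composite. Returns the composited image and a mask that is 1 where the
# original target pixels were kept.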
def crop_and_paste(Source_image, Source_image_mask, Target_image, Source_Five_Point, Target_Five_Point, Source_box, use_warp=True):
if use_warp:
Source_Five_Point = np.reshape(Source_Five_Point, [5, 2]) - np.array(Source_box[:2])
Target_Five_Point = np.reshape(Target_Five_Point, [5, 2])
Crop_Source_image = Source_image.crop(np.int32(Source_box))
Crop_Source_image_mask = Source_image_mask.crop(np.int32(Source_box))
Source_Five_Point, Target_Five_Point = np.array(Source_Five_Point), np.array(Target_Five_Point)
tform = transform.SimilarityTransform()
tform.estimate(Source_Five_Point, Target_Five_Point)
M = tform.params[0:2, :]
warped = cv2.warpAffine(np.array(Crop_Source_image), M, np.shape(Target_image)[:2][::-1], borderValue=0.0)
warped_mask = cv2.warpAffine(np.array(Crop_Source_image_mask), M, np.shape(Target_image)[:2][::-1], borderValue=0.0)
mask = np.float32(warped_mask == 0)
output = mask * np.float32(Target_image) + (1 - mask) * np.float32(warped)
else:
mask = np.float32(np.array(Source_image_mask) == 0)
output = mask * np.float32(Target_image) + (1 - mask) * np.float32(Source_image)
return output, mask
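# Build a soft face mask from human parsing: optional dilation (ksize / ksize1), with the
# undilated mask kept above eye height and the dilated mask used below it, optional neck
# inclusion, and an optional Gaussian-smoothed whole-human mask as a second return value.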
def segment(segmentation_pipeline, img, ksize=0, eyeh=0, ksize1=0, include_neck=False, warp_mask=None, return_human=False):
if True:
result = segmentation_pipeline(img)
masks = result['masks']
scores = result['scores']
labels = result['labels']
if len(masks) == 0:
return
h, w = masks[0].shape
mask_face = np.zeros((h, w))
mask_hair = np.zeros((h, w))
mask_neck = np.zeros((h, w))
mask_cloth = np.zeros((h, w))
mask_human = np.zeros((h, w))
for i in range(len(labels)):
if scores[i] > 0.8:
if labels[i] == 'Torso-skin':
mask_neck += masks[i]
elif labels[i] == 'Face':
mask_face += masks[i]
elif labels[i] == 'Human':
mask_human += masks[i]
elif labels[i] == 'Hair':
mask_hair += masks[i]
elif labels[i] == 'UpperClothes' or labels[i] == 'Coat':
mask_cloth += masks[i]
mask_face = np.clip(mask_face, 0, 1)
mask_hair = np.clip(mask_hair, 0, 1)
mask_neck = np.clip(mask_neck, 0, 1)
mask_cloth = np.clip(mask_cloth, 0, 1)
mask_human = np.clip(mask_human, 0, 1)
if np.sum(mask_face) > 0:
soft_mask = np.clip(mask_face, 0, 1)
if ksize1 > 0:
kernel_size1 = int(np.sqrt(np.sum(soft_mask)) * ksize1)
kernel1 = np.ones((kernel_size1, kernel_size1))
soft_mask = cv2.dilate(soft_mask, kernel1, iterations=1)
if ksize > 0:
kernel_size = int(np.sqrt(np.sum(soft_mask)) * ksize)
kernel = np.ones((kernel_size, kernel_size))
soft_mask_dilate = cv2.dilate(soft_mask, kernel, iterations=1)
if warp_mask is not None:
soft_mask_dilate = soft_mask_dilate * (np.clip(soft_mask + warp_mask[:, :, 0], 0, 1))
if eyeh > 0:
soft_mask = np.concatenate((soft_mask[:eyeh], soft_mask_dilate[eyeh:]), axis=0)
else:
soft_mask = soft_mask_dilate
else:
if ksize1 > 0:
kernel_size1 = int(np.sqrt(np.sum(soft_mask)) * ksize1)
kernel1 = np.ones((kernel_size1, kernel_size1))
soft_mask = cv2.dilate(mask_face, kernel1, iterations=1)
else:
soft_mask = mask_face
if include_neck:
soft_mask = np.clip(soft_mask + mask_neck, 0, 1)
if return_human:
mask_human = cv2.GaussianBlur(mask_human, (21, 21), 0) * mask_human
return soft_mask, mask_human
else:
return soft_mask
def crop_bottom(pil_file, width):
if width == 512:
height = 768
else:
height = 1152
w, h = pil_file.size
factor = w / width
new_h = int(h / factor)
pil_file = pil_file.resize((width, new_h))
crop_h = min(int(new_h / 32) * 32, height)
array_file = np.array(pil_file)
array_file = array_file[:crop_h, :, :]
output_file = Image.fromarray(array_file)
return output_file
def img2img_multicontrol(img, control_image, controlnet_conditioning_scale, pipe, mask, pos_prompt, neg_prompt,
strength, num=1, use_ori=False):
image_mask = Image.fromarray(np.uint8(mask * 255))
image_human = []
for i in range(num):
image_human.append(pipe(image=img, mask_image=image_mask, control_image=control_image, prompt=pos_prompt,
negative_prompt=neg_prompt, guidance_scale=7, strength=strength, num_inference_steps=40,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_images_per_prompt=1).images[0])
if use_ori:
image_human[i] = Image.fromarray((np.array(image_human[i]) * mask[:,:,None] + np.array(img) * (1 - mask[:,:,None])).astype(np.uint8))
return image_human
def get_mask(result):
masks = result['masks']
scores = result['scores']
labels = result['labels']
h, w = masks[0].shape
mask_hair = np.zeros((h, w))
mask_face = np.zeros((h, w))
mask_human = np.zeros((h, w))
for i in range(len(labels)):
if scores[i] > 0.8:
if labels[i] == 'Face':
if np.sum(masks[i]) > np.sum(mask_face):
mask_face = masks[i]
elif labels[i] == 'Human':
if np.sum(masks[i]) > np.sum(mask_human):
mask_human = masks[i]
elif labels[i] == 'Hair':
if np.sum(masks[i]) > np.sum(mask_hair):
mask_hair = masks[i]
mask_rst = np.clip(mask_human - mask_hair - mask_face, 0, 1)
mask_rst = np.expand_dims(mask_rst, 2)
mask_rst = np.concatenate([mask_rst, mask_rst, mask_rst], axis=2)
return mask_rst
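# Template-inpainting inference used by GenPortrait_inpaint below: generate an
# OpenPose-controlled face image, swap in the user's best face, paste it into the template
# via landmark alignment, then run two rounds of multi-ControlNet (OpenPose + Canny)
# inpainting with edge re-injection outside the human region to keep the background crisp.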
def main_diffusion_inference_inpaint(inpaint_image, strength, output_img_size, pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
multiplier_style=0.05,
multiplier_human=1.0):
if style_model_path is None:
model_dir = snapshot_download('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', revision='v1.0.0')
style_model_path = os.path.join(model_dir, 'zjz_mj_jiyi_small_addtxt_fromleo.safetensors')
segmentation_pipeline = pipeline(Tasks.image_segmentation, 'damo/cv_resnet101_image-multiple-human-parsing')
det_pipeline = pipeline(Tasks.face_detection, 'damo/cv_ddsar_face-detection_iclr23-damofd')
model_dir = snapshot_download('damo/face_chain_control_model', revision='v1.0.1')
model_dir1 = snapshot_download('ly261666/cv_wanx_style_model',revision='v1.0.3')
if output_img_size == 512:
dtype = torch.float32
else:
dtype = torch.float16
train_dir = str(input_img_dir) + '_labeled'
add_prompt_style = []
f = open(os.path.join(train_dir, 'metadata.jsonl'), 'r')
tags_all = []
cnt = 0
cnts_trigger = np.zeros(6)
is_old = False
for line in f:
cnt += 1
data = json.loads(line)['text'].split(', ')
tags_all.extend(data)
if data[1] == 'a boy':
cnts_trigger[0] += 1
elif data[1] == 'a girl':
cnts_trigger[1] += 1
elif data[1] == 'a handsome man':
cnts_trigger[2] += 1
elif data[1] == 'a beautiful woman':
cnts_trigger[3] += 1
elif data[1] == 'a mature man':
cnts_trigger[4] += 1
is_old = True
elif data[1] == 'a mature woman':
cnts_trigger[5] += 1
is_old = True
else:
print('Error.')
f.close()
attr_idx = np.argmax(cnts_trigger)
trigger_styles = ['a boy, children, ', 'a girl, children, ', 'a handsome man, ', 'a beautiful woman, ',
'a mature man, ', 'a mature woman, ']
trigger_style = '(<fcsks>:10), ' + trigger_styles[attr_idx]
if attr_idx == 2 or attr_idx == 4:
neg_prompt += ', children'
for tag in tags_all:
if tags_all.count(tag) > 0.5 * cnt:
if ('glasses' in tag or 'smile' in tag):
if not tag in add_prompt_style:
add_prompt_style.append(tag)
if len(add_prompt_style) > 0:
add_prompt_style = ", ".join(add_prompt_style) + ', '
else:
add_prompt_style = ''
if isinstance(inpaint_image, str):
inpaint_im = Image.open(inpaint_image)
else:
inpaint_im = inpaint_image
inpaint_im = crop_bottom(inpaint_im, output_img_size)
# return [inpaint_im, inpaint_im, inpaint_im]
openpose = OpenposeDetector.from_pretrained(os.path.join(model_dir, "model_controlnet/ControlNet"))
controlnet = ControlNetModel.from_pretrained(os.path.join(model_dir, "model_controlnet/control_v11p_sd15_openpose"), torch_dtype=dtype)
openpose_image = openpose(np.array(inpaint_im, np.uint8), include_hand=True, include_face=False)
w, h = inpaint_im.size
pipe = StableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=controlnet, torch_dtype=dtype, safety_checker=None)
lora_style_path = style_model_path
lora_human_path = lora_model_path
pipe = merge_lora(pipe, lora_style_path, multiplier_style, from_safetensor=True, device='cuda')
pipe = merge_lora(pipe, lora_human_path, multiplier_human, from_safetensor=False, device='cuda')
pipe = pipe.to("cuda")
image_faces = []
for i in range(1):
image_face = pipe(prompt=trigger_style + add_prompt_style + pos_prompt, image=openpose_image, height=h, width=w,
guidance_scale=7, negative_prompt=neg_prompt,
num_inference_steps=40, num_images_per_prompt=1).images[0]
image_faces.append(image_face)
selected_face = select_high_quality_face(input_img_dir)
swap_results = face_swap_fn(True, image_faces, selected_face)
controlnet = [
ControlNetModel.from_pretrained(os.path.join(model_dir, "model_controlnet/control_v11p_sd15_openpose"), torch_dtype=dtype),
ControlNetModel.from_pretrained(os.path.join(model_dir1, "contronet-canny"), torch_dtype=dtype)
]
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(base_model_path, controlnet=controlnet,
torch_dtype=dtype, safety_checker=None)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = merge_lora(pipe, style_model_path, multiplier_style, from_safetensor=True)
pipe = merge_lora(pipe, lora_model_path, multiplier_human, from_safetensor=False)
pipe = pipe.to("cuda")
images_human = []
images_auto = []
inpaint_bbox, inpaint_keypoints = call_face_crop(det_pipeline, inpaint_im, 1.1)
eye_height = int((inpaint_keypoints[0, 1] + inpaint_keypoints[1, 1]) / 2)
canny_image = cv2.Canny(np.array(inpaint_im, np.uint8), 100, 200)[:, :, None]
mask = segment(segmentation_pipeline, inpaint_im, ksize=0.05, eyeh=eye_height)
canny_image = (canny_image * (1.0 - mask[:, :, None])).astype(np.uint8)
canny_image = Image.fromarray(np.concatenate([canny_image, canny_image, canny_image], axis=2))
# canny_image.save('canny.png')
for i in range(1):
image_face = swap_results[i]
image_face = Image.fromarray(image_face[:, :, ::-1])
face_bbox, face_keypoints = call_face_crop(det_pipeline, image_face, 1.5)
face_mask = segment(segmentation_pipeline, image_face)
face_mask = np.expand_dims((face_mask * 255).astype(np.uint8), axis=2)
face_mask = np.concatenate([face_mask, face_mask, face_mask], axis=2)
face_mask = Image.fromarray(face_mask)
replaced_input_image, warp_mask = crop_and_paste(image_face, face_mask, inpaint_im, face_keypoints,
inpaint_keypoints, face_bbox)
warp_mask = 1.0 - warp_mask
# cv2.imwrite('tmp_{}.png'.format(i), replaced_input_image[:, :, ::-1])
openpose_image = openpose(np.array(replaced_input_image * warp_mask, np.uint8), include_hand=True,
include_body=False, include_face=True)
# openpose_image.save('openpose_{}.png'.format(i))
read_control = [openpose_image, canny_image]
inpaint_mask, human_mask = segment(segmentation_pipeline, inpaint_im, ksize=0.1, ksize1=0.06, eyeh=eye_height, include_neck=False,
warp_mask=warp_mask, return_human=True)
inpaint_with_mask = ((1.0 - inpaint_mask[:,:,None]) * np.array(inpaint_im))[:,:,::-1]
# cv2.imwrite('inpaint_with_mask_{}.png'.format(i), inpaint_with_mask)
print('Finished segmenting images.')
images_human.extend(img2img_multicontrol(inpaint_im, read_control, [1.0, 0.2], pipe, inpaint_mask,
trigger_style + add_prompt_style + pos_prompt, neg_prompt,
strength=strength))
images_auto.extend(img2img_multicontrol(inpaint_im, read_control, [1.0, 0.2], pipe, np.zeros_like(inpaint_mask),
trigger_style + add_prompt_style + pos_prompt, neg_prompt,
strength=0.025))
edge_add = np.array(inpaint_im).astype(np.int16) - np.array(images_auto[i]).astype(np.int16)
edge_add = edge_add * (1 - human_mask[:,:,None])
images_human[i] = Image.fromarray((np.clip(np.array(images_human[i]).astype(np.int16) + edge_add.astype(np.int16), 0, 255)).astype(np.uint8))
images_rst = []
for i in range(len(images_human)):
im = images_human[i]
canny_image = cv2.Canny(np.array(im, np.uint8), 100, 200)[:, :, None]
canny_image = Image.fromarray(np.concatenate([canny_image, canny_image, canny_image], axis=2))
openpose_image = openpose(np.array(im, np.uint8), include_hand=True, include_face=True)
read_control = [openpose_image, canny_image]
inpaint_mask, human_mask = segment(segmentation_pipeline, images_human[i], ksize=0.02, return_human=True)
print('Finished segmenting images.')
image_rst = img2img_multicontrol(im, read_control, [0.8, 0.8], pipe, inpaint_mask,
trigger_style + add_prompt_style + pos_prompt, neg_prompt, strength=0.1,
num=1)[0]
image_auto = img2img_multicontrol(im, read_control, [0.8, 0.8], pipe, np.zeros_like(inpaint_mask),
trigger_style + add_prompt_style + pos_prompt, neg_prompt, strength=0.025,
num=1)[0]
edge_add = np.array(im).astype(np.int16) - np.array(image_auto).astype(np.int16)
edge_add = edge_add * (1 - human_mask[:,:,None])
image_rst = Image.fromarray((np.clip(np.array(image_rst).astype(np.int16) + edge_add.astype(np.int16), 0, 255)).astype(np.uint8))
images_rst.append(image_rst)
for i in range(1):
images_rst[i].save('inference_{}.png'.format(i))
return images_rst
def main_diffusion_inference_inpaint_multi(inpaint_images, strength, output_img_size, pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
multiplier_style=0.05,
multiplier_human=1.0):
if style_model_path is None:
model_dir = snapshot_download('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', revision='v1.0.0')
style_model_path = os.path.join(model_dir, 'zjz_mj_jiyi_small_addtxt_fromleo.safetensors')
segmentation_pipeline = pipeline(Tasks.image_segmentation, 'damo/cv_resnet101_image-multiple-human-parsing')
det_pipeline = pipeline(Tasks.face_detection, 'damo/cv_ddsar_face-detection_iclr23-damofd')
model_dir = snapshot_download('damo/face_chain_control_model', revision='v1.0.1')
model_dir1 = snapshot_download('ly261666/cv_wanx_style_model',revision='v1.0.3')
if output_img_size == 512:
dtype = torch.float32
else:
dtype = torch.float16
train_dir = str(input_img_dir) + '_labeled'
add_prompt_style = []
f = open(os.path.join(train_dir, 'metadata.jsonl'), 'r')
tags_all = []
cnt = 0
cnts_trigger = np.zeros(6)
is_old = False
for line in f:
cnt += 1
data = json.loads(line)['text'].split(', ')
tags_all.extend(data)
if data[1] == 'a boy':
cnts_trigger[0] += 1
elif data[1] == 'a girl':
cnts_trigger[1] += 1
elif data[1] == 'a handsome man':
cnts_trigger[2] += 1
elif data[1] == 'a beautiful woman':
cnts_trigger[3] += 1
elif data[1] == 'a mature man':
cnts_trigger[4] += 1
is_old = True
elif data[1] == 'a mature woman':
cnts_trigger[5] += 1
is_old = True
else:
print('Error.')
f.close()
attr_idx = np.argmax(cnts_trigger)
trigger_styles = ['a boy, children, ', 'a girl, children, ', 'a handsome man, ', 'a beautiful woman, ',
'a mature man, ', 'a mature woman, ']
trigger_style = '(<fcsks>:10), ' + trigger_styles[attr_idx]
if attr_idx == 2 or attr_idx == 4:
neg_prompt += ', children'
for tag in tags_all:
if tags_all.count(tag) > 0.5 * cnt:
if ('glasses' in tag or 'smile' in tag):
if not tag in add_prompt_style:
add_prompt_style.append(tag)
if len(add_prompt_style) > 0:
add_prompt_style = ", ".join(add_prompt_style) + ', '
else:
add_prompt_style = ''
openpose = OpenposeDetector.from_pretrained(os.path.join(model_dir, "model_controlnet/ControlNet"))
controlnet = ControlNetModel.from_pretrained(os.path.join(model_dir, "model_controlnet/control_v11p_sd15_openpose"), torch_dtype=dtype)
pipe = StableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=controlnet, torch_dtype=dtype, safety_checker=None)
lora_style_path = style_model_path
lora_human_path = lora_model_path
pipe = merge_lora(pipe, lora_style_path, multiplier_style, from_safetensor=True)
pipe = merge_lora(pipe, lora_human_path, multiplier_human, from_safetensor=False)
pipe = pipe.to("cuda")
image_faces = []
for i in range(1):
inpaint_im = inpaint_images[i]
inpaint_im = crop_bottom(inpaint_im, output_img_size)
openpose_image = openpose(np.array(inpaint_im, np.uint8), include_hand=True, include_face=False)
w, h = inpaint_im.size
image_face = pipe(prompt=trigger_style + add_prompt_style + pos_prompt, image=openpose_image, height=h, width=w,
guidance_scale=7, negative_prompt=neg_prompt,
num_inference_steps=40, num_images_per_prompt=1).images[0]
image_faces.append(image_face)
selected_face = select_high_quality_face(input_img_dir)
swap_results = face_swap_fn(True, image_faces, selected_face)
controlnet = [
ControlNetModel.from_pretrained(os.path.join(model_dir, "model_controlnet/control_v11p_sd15_openpose"), torch_dtype=dtype),
ControlNetModel.from_pretrained(os.path.join(model_dir1, "contronet-canny"), torch_dtype=dtype)
]
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(base_model_path, controlnet=controlnet,
torch_dtype=dtype, safety_checker=None)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = merge_lora(pipe, style_model_path, multiplier_style, from_safetensor=True)
pipe = merge_lora(pipe, lora_model_path, multiplier_human, from_safetensor=False)
pipe = pipe.to("cuda")
images_human = []
images_auto = []
for i in range(1):
inpaint_im = inpaint_images[i]
inpaint_bbox, inpaint_keypoints = call_face_crop(det_pipeline, inpaint_im, 1.1)
eye_height = int((inpaint_keypoints[0, 1] + inpaint_keypoints[1, 1]) / 2)
canny_image = cv2.Canny(np.array(inpaint_im, np.uint8), 100, 200)[:, :, None]
mask = segment(segmentation_pipeline, inpaint_im, ksize=0.05, eyeh=eye_height)
canny_image = (canny_image * (1.0 - mask[:, :, None])).astype(np.uint8)
canny_image = Image.fromarray(np.concatenate([canny_image, canny_image, canny_image], axis=2))
image_face = swap_results[i]
image_face = Image.fromarray(image_face[:, :, ::-1])
face_bbox, face_keypoints = call_face_crop(det_pipeline, image_face, 1.5)
face_mask = segment(segmentation_pipeline, image_face)
face_mask = np.expand_dims((face_mask * 255).astype(np.uint8), axis=2)
face_mask = np.concatenate([face_mask, face_mask, face_mask], axis=2)
face_mask = Image.fromarray(face_mask)
replaced_input_image, warp_mask = crop_and_paste(image_face, face_mask, inpaint_im, face_keypoints,
inpaint_keypoints, face_bbox)
warp_mask = 1.0 - warp_mask
# cv2.imwrite('tmp_{}.png'.format(i), replaced_input_image[:, :, ::-1])
openpose_image = openpose(np.array(replaced_input_image * warp_mask, np.uint8), include_hand=True,
include_body=False, include_face=True)
# openpose_image.save('openpose_{}.png'.format(i))
read_control = [openpose_image, canny_image]
inpaint_mask, human_mask = segment(segmentation_pipeline, inpaint_im, ksize=0.1, ksize1=0.06, eyeh=eye_height, include_neck=False,
warp_mask=warp_mask, return_human=True)
inpaint_with_mask = ((1.0 - inpaint_mask[:,:,None]) * np.array(inpaint_im))[:,:,::-1]
# cv2.imwrite('inpaint_with_mask_{}.png'.format(i), inpaint_with_mask)
print('Finished segmenting images.')
images_human.extend(img2img_multicontrol(inpaint_im, read_control, [1.0, 0.2], pipe, inpaint_mask,
trigger_style + add_prompt_style + pos_prompt, neg_prompt,
strength=strength))
images_auto.extend(img2img_multicontrol(inpaint_im, read_control, [1.0, 0.2], pipe, np.zeros_like(inpaint_mask),
trigger_style + add_prompt_style + pos_prompt, neg_prompt,
strength=0.025))
edge_add = np.array(inpaint_im).astype(np.int16) - np.array(images_auto[i]).astype(np.int16)
edge_add = edge_add * (1 - human_mask[:,:,None])
images_human[i] = Image.fromarray((np.clip(np.array(images_human[i]).astype(np.int16) + edge_add.astype(np.int16), 0, 255)).astype(np.uint8))
images_rst = []
for i in range(len(images_human)):
im = images_human[i]
canny_image = cv2.Canny(np.array(im, np.uint8), 100, 200)[:, :, None]
canny_image = Image.fromarray(np.concatenate([canny_image, canny_image, canny_image], axis=2))
openpose_image = openpose(np.array(im, np.uint8), include_hand=True, include_face=True)
read_control = [openpose_image, canny_image]
inpaint_mask, human_mask = segment(segmentation_pipeline, images_human[i], ksize=0.02, return_human=True)
print('Finished segmenting images.')
image_rst = img2img_multicontrol(im, read_control, [0.8, 0.8], pipe, np.zeros_like(inpaint_mask),
trigger_style + add_prompt_style + pos_prompt, neg_prompt, strength=0.1,
num=1)[0]
image_auto = img2img_multicontrol(im, read_control, [0.8, 0.8], pipe, np.zeros_like(inpaint_mask),
trigger_style + add_prompt_style + pos_prompt, neg_prompt, strength=0.025,
num=1)[0]
edge_add = np.array(im).astype(np.int16) - np.array(image_auto).astype(np.int16)
edge_add = edge_add * (1 - human_mask[:,:,None])
image_rst = Image.fromarray((np.clip(np.array(image_rst).astype(np.int16) + edge_add.astype(np.int16), 0, 255)).astype(np.uint8))
images_rst.append(image_rst)
for i in range(1):
images_rst[i].save('inference_{}.png'.format(i))
return images_rst
def stylization_fn(use_stylization, rank_results):
if use_stylization:
## TODO
pass
else:
return rank_results
def main_model_inference(inpaint_image, strength, output_img_size,
pos_prompt, neg_prompt, style_model_path, multiplier_style, multiplier_human, use_main_model,
input_img_dir=None, base_model_path=None, lora_model_path=None):
if use_main_model:
multiplier_style_kwargs = {'multiplier_style': multiplier_style} if multiplier_style is not None else {}
multiplier_human_kwargs = {'multiplier_human': multiplier_human} if multiplier_human is not None else {}
return main_diffusion_inference_inpaint(inpaint_image, strength, output_img_size, pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
**multiplier_style_kwargs, **multiplier_human_kwargs)
def main_model_inference_multi(inpaint_image, strength, output_img_size,
pos_prompt, neg_prompt, style_model_path, multiplier_style, multiplier_human, use_main_model,
input_img_dir=None, base_model_path=None, lora_model_path=None):
if use_main_model:
multiplier_style_kwargs = {'multiplier_style': multiplier_style} if multiplier_style is not None else {}
multiplier_human_kwargs = {'multiplier_human': multiplier_human} if multiplier_human is not None else {}
return main_diffusion_inference_inpaint_multi(inpaint_image, strength, output_img_size, pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
**multiplier_style_kwargs, **multiplier_human_kwargs)
def select_high_quality_face(input_img_dir):
input_img_dir = str(input_img_dir) + '_labeled'
quality_score_list = []
abs_img_path_list = []
## TODO
face_quality_func = pipeline(Tasks.face_quality_assessment, 'damo/cv_manual_face-quality-assessment_fqa',
model_revision='v2.0')
for img_name in os.listdir(input_img_dir):
if img_name.endswith('jsonl') or img_name.startswith('.ipynb') or img_name.endswith('.safetensors'):
continue
if img_name.endswith('jpg') or img_name.endswith('png'):
abs_img_name = os.path.join(input_img_dir, img_name)
face_quality_score = face_quality_func(abs_img_name)[OutputKeys.SCORES]
if face_quality_score is None:
quality_score_list.append(0)
else:
quality_score_list.append(face_quality_score[0])
abs_img_path_list.append(abs_img_name)
sort_idx = np.argsort(quality_score_list)[::-1]
print('Selected face: ' + abs_img_path_list[sort_idx[0]])
return Image.open(abs_img_path_list[sort_idx[0]])
def face_swap_fn(use_face_swap, gen_results, template_face):
if use_face_swap:
## TODO
out_img_list = []
image_face_fusion = pipeline('face_fusion_torch',
model='damo/cv_unet_face_fusion_torch', model_revision='v1.0.5')
segmentation_pipeline = pipeline(Tasks.image_segmentation, 'damo/cv_resnet101_image-multiple-human-parsing')
for img in gen_results:
result = image_face_fusion(dict(template=img, user=template_face))[OutputKeys.OUTPUT_IMG]
face_mask = segment(segmentation_pipeline, img, ksize=0.1)
result = (result * face_mask[:,:,None] + np.array(img)[:,:,::-1] * (1 - face_mask[:,:,None])).astype(np.uint8)
out_img_list.append(result)
return out_img_list
else:
ret_results = []
for img in gen_results:
ret_results.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
return ret_results
def post_process_fn(use_post_process, swap_results_ori, selected_face, num_gen_images):
if use_post_process:
sim_list = []
## TODO
face_recognition_func = pipeline(Tasks.face_recognition, 'damo/cv_ir_face-recognition-ood_rts',
model_revision='v2.5')
face_det_func = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd',
model_revision='v1.1')
swap_results = swap_results_ori
select_face_emb = face_recognition_func(selected_face)[OutputKeys.IMG_EMBEDDING][0]
for img in swap_results:
emb = face_recognition_func(img)[OutputKeys.IMG_EMBEDDING]
if emb is None or select_face_emb is None:
sim_list.append(0)
else:
sim = np.dot(emb, select_face_emb)
sim_list.append(sim.item())
sort_idx = np.argsort(sim_list)[::-1]
return np.array(swap_results)[sort_idx[:min(int(num_gen_images), len(swap_results))]]
else:
return np.array(swap_results_ori)
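# Inpainting wrapper for templates containing one or two faces: for each provided LoRA, the
# corresponding face region is cropped to a square around the detected box, regenerated with
# the template-inpainting pipeline above, face-swapped, and blended back into the full
# template with an eroded, Gaussian-feathered mask.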
class GenPortrait_inpaint:
def __init__(self, inpaint_img, strength, num_faces,
pos_prompt, neg_prompt, style_model_path, multiplier_style, multiplier_human,
use_main_model=True, use_face_swap=True,
use_post_process=True, use_stylization=True):
self.use_main_model = use_main_model
self.use_face_swap = use_face_swap
self.use_post_process = use_post_process
self.use_stylization = use_stylization
self.multiplier_style = multiplier_style
self.multiplier_human = multiplier_human
self.style_model_path = style_model_path
self.pos_prompt = pos_prompt
self.neg_prompt = neg_prompt
self.inpaint_img = inpaint_img
self.strength = strength
self.num_faces = num_faces
def __call__(self, input_img_dir1=None, input_img_dir2=None, base_model_path=None,
lora_model_path1=None, lora_model_path2=None, sub_path=None, revision=None):
base_model_path = snapshot_download(base_model_path, revision=revision)
if sub_path is not None and len(sub_path) > 0:
base_model_path = os.path.join(base_model_path, sub_path)
face_detection = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd',
model_revision='v1.1')
result_det = face_detection(self.inpaint_img)
bboxes = result_det['boxes']
assert len(bboxes) == self.num_faces, 'Number of detected faces ({}) does not match num_faces ({}).'.format(len(bboxes), self.num_faces)
bboxes = np.array(bboxes).astype(np.int16)
lefts = []
for bbox in bboxes:
lefts.append(bbox[0])
idxs = np.argsort(lefts)
if lora_model_path1 is not None:
face_box = bboxes[idxs[0]]
inpaint_img_large = cv2.imread(self.inpaint_img)
mask_large = np.ones_like(inpaint_img_large)
mask_large1 = np.zeros_like(inpaint_img_large)
h,w,_ = inpaint_img_large.shape
for i in range(len(bboxes)):
if i != idxs[0]:
bbox = bboxes[i]
inpaint_img_large[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 0
mask_large[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 0
face_ratio = 0.45
cropl = int(max(face_box[3] - face_box[1], face_box[2] - face_box[0]) / face_ratio / 2)
cx = int((face_box[2] + face_box[0])/2)
cy = int((face_box[1] + face_box[3])/2)
cropup = min(cy, cropl)
cropbo = min(h-cy, cropl)
crople = min(cx, cropl)
cropri = min(w-cx, cropl)
inpaint_img = np.pad(inpaint_img_large[cy-cropup:cy+cropbo, cx-crople:cx+cropri], ((cropl-cropup, cropl-cropbo), (cropl-crople, cropl-cropri), (0, 0)), 'constant')
inpaint_img = cv2.resize(inpaint_img, (512, 512))
inpaint_img = Image.fromarray(inpaint_img[:,:,::-1])
mask_large1[cy-cropup:cy+cropbo, cx-crople:cx+cropri] = 1
mask_large = mask_large * mask_large1
gen_results = main_model_inference(inpaint_img, self.strength, 512,
self.pos_prompt, self.neg_prompt,
self.style_model_path, self.multiplier_style, self.multiplier_human,
self.use_main_model, input_img_dir=input_img_dir1,
lora_model_path=lora_model_path1, base_model_path=base_model_path)
# select_high_quality_face PIL
selected_face = select_high_quality_face(input_img_dir1)
# face_swap cv2
swap_results = face_swap_fn(self.use_face_swap, gen_results, selected_face)
# stylization
final_gen_results = swap_results
print(len(final_gen_results))
final_gen_results_new = []
inpaint_img_large = cv2.imread(self.inpaint_img)
ksize = int(10 * cropl / 256)
for i in range(len(final_gen_results)):
print('Start cropping.')
rst_gen = cv2.resize(final_gen_results[i], (cropl * 2, cropl * 2))
rst_crop = rst_gen[cropl-cropup:cropl+cropbo, cropl-crople:cropl+cropri]
print(rst_crop.shape)
inpaint_img_rst = np.zeros_like(inpaint_img_large)
print('Start pasting.')
inpaint_img_rst[cy-cropup:cy+cropbo, cx-crople:cx+cropri] = rst_crop
print('Finish pasting.')
print(inpaint_img_rst.shape, mask_large.shape, inpaint_img_large.shape)
mask_large = mask_large.astype(np.float32)
kernel = np.ones((ksize * 2, ksize * 2))
mask_large1 = cv2.erode(mask_large, kernel, iterations=1)
mask_large1 = cv2.GaussianBlur(mask_large1, (int(ksize * 1.8) * 2 + 1, int(ksize * 1.8) * 2 + 1), 0)
mask_large1[face_box[1]:face_box[3], face_box[0]:face_box[2]] = 1
mask_large = mask_large * mask_large1
final_inpaint_rst = (inpaint_img_rst.astype(np.float32) * mask_large.astype(np.float32) + inpaint_img_large.astype(np.float32) * (1.0 - mask_large.astype(np.float32))).astype(np.uint8)
print('Finish masking.')
final_gen_results_new.append(final_inpaint_rst)
print('Finish generating.')
else:
inpaint_img_large = cv2.imread(self.inpaint_img)
inpaint_img_le = cv2.imread(self.inpaint_img)
final_gen_results_new = [inpaint_img_le, inpaint_img_le, inpaint_img_le]
for i in range(1):
cv2.imwrite('tmp_inpaint_left_{}.png'.format(i), final_gen_results_new[i])
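        # When a second LoRA is provided, repeat the crop-inpaint-paste procedure for the second face.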
        if lora_model_path2 is not None and self.num_faces == 2:
face_box = bboxes[idxs[1]]
mask_large = np.ones_like(inpaint_img_large)
mask_large1 = np.zeros_like(inpaint_img_large)
h,w,_ = inpaint_img_large.shape
for i in range(len(bboxes)):
if i != idxs[1]:
bbox = bboxes[i]
inpaint_img_large[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 0
mask_large[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 0
face_ratio = 0.45
cropl = int(max(face_box[3] - face_box[1], face_box[2] - face_box[0]) / face_ratio / 2)
cx = int((face_box[2] + face_box[0])/2)
cy = int((face_box[1] + face_box[3])/2)
cropup = min(cy, cropl)
cropbo = min(h-cy, cropl)
crople = min(cx, cropl)
cropri = min(w-cx, cropl)
mask_large1[cy-cropup:cy+cropbo, cx-crople:cx+cropri] = 1
mask_large = mask_large * mask_large1
inpaint_imgs = []
for i in range(1):
inpaint_img_large = final_gen_results_new[i] * mask_large
inpaint_img = np.pad(inpaint_img_large[cy-cropup:cy+cropbo, cx-crople:cx+cropri], ((cropl-cropup, cropl-cropbo), (cropl-crople, cropl-cropri), (0, 0)), 'constant')
inpaint_img = cv2.resize(inpaint_img, (512, 512))
inpaint_img = Image.fromarray(inpaint_img[:,:,::-1])
inpaint_imgs.append(inpaint_img)
gen_results = main_model_inference_multi(inpaint_imgs, self.strength, 512,
self.pos_prompt, self.neg_prompt,
self.style_model_path, self.multiplier_style, self.multiplier_human,
self.use_main_model, input_img_dir=input_img_dir2,
lora_model_path=lora_model_path2, base_model_path=base_model_path)
# select_high_quality_face PIL
selected_face = select_high_quality_face(input_img_dir2)
# face_swap cv2
swap_results = face_swap_fn(self.use_face_swap, gen_results, selected_face)
# stylization
final_gen_results = swap_results
print(len(final_gen_results))
final_gen_results_final = []
inpaint_img_large = cv2.imread(self.inpaint_img)
ksize = int(10 * cropl / 256)
for i in range(len(final_gen_results)):
print('Start cropping.')
rst_gen = cv2.resize(final_gen_results[i], (cropl * 2, cropl * 2))
rst_crop = rst_gen[cropl-cropup:cropl+cropbo, cropl-crople:cropl+cropri]
print(rst_crop.shape)
inpaint_img_rst = np.zeros_like(inpaint_img_large)
print('Start pasting.')
inpaint_img_rst[cy-cropup:cy+cropbo, cx-crople:cx+cropri] = rst_crop
                print('Finish pasting.')
print(inpaint_img_rst.shape, mask_large.shape, inpaint_img_large.shape)
mask_large = mask_large.astype(np.float32)
kernel = np.ones((ksize * 2, ksize * 2))
mask_large1 = cv2.erode(mask_large, kernel, iterations=1)
mask_large1 = cv2.GaussianBlur(mask_large1, (int(ksize * 1.8) * 2 + 1, int(ksize * 1.8) * 2 + 1), 0)
mask_large1[face_box[1]:face_box[3], face_box[0]:face_box[2]] = 1
mask_large = mask_large * mask_large1
final_inpaint_rst = (inpaint_img_rst.astype(np.float32) * mask_large.astype(np.float32) + final_gen_results_new[i].astype(np.float32) * (1.0 - mask_large.astype(np.float32))).astype(np.uint8)
print('Finish masking.')
final_gen_results_final.append(final_inpaint_rst)
print('Finish generating.')
else:
final_gen_results_final = final_gen_results_new
outputs = final_gen_results_final
outputs_RGB = []
for out_tmp in outputs:
outputs_RGB.append(cv2.cvtColor(out_tmp, cv2.COLOR_BGR2RGB))
image_path = './lora_result.png'
if len(outputs) > 0:
result = concatenate_images(outputs)
cv2.imwrite(image_path, result)
return final_gen_results_final
def compress_image(input_path, target_size):
    output_path = change_extension_to_jpg(input_path)
    image = cv2.imread(input_path)
    quality = 95
    # Lower the JPEG quality step by step until the encoded size fits within target_size.
    while quality > 5 and cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, quality])[1].size > target_size:
        quality -= 5
    compressed_image = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
    with open(output_path, 'wb') as f:
        f.write(compressed_image)
    return output_path
def change_extension_to_jpg(image_path):
base_name = os.path.basename(image_path)
new_base_name = os.path.splitext(base_name)[0] + ".jpg"
directory = os.path.dirname(image_path)
new_image_path = os.path.join(directory, new_base_name)
return new_image_path
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from typing import Any
import tempfile
from modelscope.pipelines import pipeline
from facechain.constants import tts_speakers_map
from facechain.utils import join_worker_data_dir
try:
import edge_tts
except ImportError:
print("警告:未找到edge_tts模块,语音合成功能将无法使用。您可以通过`pip install edge-tts`安装它。\n Warning: The edge_tts module is not found, so the speech synthesis function will not be available. You can install it by 'pip install edge-tts'.")
class SadTalker():
    def __init__(self, uuid):
        if not uuid:
            if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio':
                # __init__ must not return a value, so signal the missing login with an exception instead.
                raise RuntimeError("请登陆后使用! (Please login first)")
            else:
                uuid = 'qw'
        # self.save_dir = os.path.join('/tmp', uuid, 'sythesized_video')  # deprecated
        # self.save_dir = os.path.join('.', uuid, 'sythesized_video')  # deprecated
        self.save_dir = join_worker_data_dir(uuid, 'sythesized_video')
def __call__(self, *args, **kwargs) -> Any:
# two required arguments
source_image = kwargs.get("source_image") or args[0]
driven_audio = kwargs.get('driven_audio') or args[1]
# other optional arguments
kwargs = {
'preprocess' : kwargs.get('preprocess') or args[2],
'still_mode' : kwargs.get('still_mode') or args[3],
'use_enhancer' : kwargs.get('use_enhancer') or args[4],
'batch_size' : kwargs.get('batch_size') or args[5],
'size' : kwargs.get('size') or args[6],
'pose_style' : kwargs.get('pose_style') or args[7],
'exp_scale' : kwargs.get('exp_scale') or args[8],
'result_dir': self.save_dir
}
inference = pipeline('talking-head', model='wwd123/sadtalker', model_revision='v1.0.0')
print("initialized sadtalker pipeline")
video_path = inference(source_image, driven_audio=driven_audio, **kwargs)
return video_path
async def text_to_speech_edge(text, speaker):
voice = tts_speakers_map[speaker]
communicate = edge_tts.Communicate(text, voice)
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
tmp_path = tmp_file.name
await communicate.save(tmp_path)
return tmp_path
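# Usage sketch (assumes an asyncio event loop is available; speaker must be a key of tts_speakers_map):
#   import asyncio
#   mp3_path = asyncio.run(text_to_speech_edge('你好', list(tts_speakers_map.keys())[0]))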
# Copyright (c) Alibaba, Inc. and its affiliates.
# Modified from the original implementation at https://github.com/modelscope/facechain/pull/104.
import json
import os
import sys
import cv2
import numpy as np
import torch
from PIL import Image
from skimage import transform
# from controlnet_aux import OpenposeDetector
from dwpose import DWposeDetector
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, \
StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
from facechain.utils import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from torch import multiprocessing
from transformers import pipeline as tpipeline
from facechain.data_process.preprocessing import Blipv2
from facechain.merge_lora import merge_lora
def _data_process_fn_process(input_img_dir):
Blipv2()(input_img_dir)
def concatenate_images(images):
    heights = [img.shape[0] for img in images]
    total_width = sum([img.shape[1] for img in images])
    concatenated_image = np.zeros((max(heights), total_width, 3), dtype=np.uint8)
    x_offset = 0
    for img in images:
        concatenated_image[0:img.shape[0], x_offset:x_offset + img.shape[1], :] = img
        x_offset += img.shape[1]
    return concatenated_image
def data_process_fn(input_img_dir, use_data_process):
## TODO add face quality filter
if use_data_process:
## TODO
_process = multiprocessing.Process(target=_data_process_fn_process, args=(input_img_dir,))
_process.start()
_process.join()
return os.path.join(str(input_img_dir) + '_labeled', "metadata.jsonl")
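# Detect faces, keep the largest one, expand its bounding box by crop_ratio, and return it together with the 5 facial keypoints.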
def call_face_crop(det_pipeline, image, crop_ratio):
det_result = det_pipeline(image)
bboxes = det_result['boxes']
keypoints = det_result['keypoints']
area = 0
idx = 0
for i in range(len(bboxes)):
bbox = bboxes[i]
area_tmp = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
if area_tmp > area:
area = area_tmp
idx = i
bbox = bboxes[idx]
keypoint = keypoints[idx]
points_array = np.zeros((5, 2))
for k in range(5):
points_array[k, 0] = keypoint[2 * k]
points_array[k, 1] = keypoint[2 * k + 1]
w, h = image.size
face_w = bbox[2] - bbox[0]
face_h = bbox[3] - bbox[1]
bbox[0] = np.clip(np.array(bbox[0], np.int32) - face_w * (crop_ratio - 1) / 2, 0, w - 1)
bbox[1] = np.clip(np.array(bbox[1], np.int32) - face_h * (crop_ratio - 1) / 2, 0, h - 1)
bbox[2] = np.clip(np.array(bbox[2], np.int32) + face_w * (crop_ratio - 1) / 2, 0, w - 1)
bbox[3] = np.clip(np.array(bbox[3], np.int32) + face_h * (crop_ratio - 1) / 2, 0, h - 1)
bbox = np.array(bbox, np.int32)
return bbox, points_array
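# Align the source face to the target image with a similarity transform estimated from the 5 landmarks, then composite it in.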
def crop_and_paste(Source_image, Source_image_mask, Target_image, Source_Five_Point, Target_Five_Point, Source_box, use_warp=True):
if use_warp:
Source_Five_Point = np.reshape(Source_Five_Point, [5, 2]) - np.array(Source_box[:2])
Target_Five_Point = np.reshape(Target_Five_Point, [5, 2])
Crop_Source_image = Source_image.crop(np.int32(Source_box))
Crop_Source_image_mask = Source_image_mask.crop(np.int32(Source_box))
Source_Five_Point, Target_Five_Point = np.array(Source_Five_Point), np.array(Target_Five_Point)
tform = transform.SimilarityTransform()
tform.estimate(Source_Five_Point, Target_Five_Point)
M = tform.params[0:2, :]
warped = cv2.warpAffine(np.array(Crop_Source_image), M, np.shape(Target_image)[:2][::-1], borderValue=0.0)
warped_mask = cv2.warpAffine(np.array(Crop_Source_image_mask), M, np.shape(Target_image)[:2][::-1], borderValue=0.0)
mask = np.float32(warped_mask == 0)
output = mask * np.float32(Target_image) + (1 - mask) * np.float32(warped)
else:
mask = np.float32(np.array(Source_image_mask) == 0)
output = mask * np.float32(Target_image) + (1 - mask) * np.float32(Source_image)
return output, mask
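# Human-parsing helper: returns a face / cloth / hand / body mask depending on the flags, optionally dilated or eroded by ksize.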
def segment(segmentation_pipeline, img, ksize=0, return_human=False, return_cloth=False, return_hand=False):
if True:
result = segmentation_pipeline(img)
masks = result['masks']
scores = result['scores']
labels = result['labels']
if len(masks) == 0:
return
h, w = masks[0].shape
mask_face = np.zeros((h, w))
mask_hair = np.zeros((h, w))
mask_neck = np.zeros((h, w))
mask_cloth = np.zeros((h, w))
mask_human = np.zeros((h, w))
mask_hands = np.zeros((h, w))
for i in range(len(labels)):
if scores[i] > 0.8:
if labels[i] == 'Torso-skin':
mask_neck += masks[i]
elif labels[i] == 'Face':
mask_face += masks[i]
elif labels[i] == 'Human':
if np.sum(masks[i]) > np.sum(mask_human):
mask_human = masks[i]
elif labels[i] == 'Hair':
mask_hair += masks[i]
elif labels[i] == 'UpperClothes' or labels[i] == 'Coat' or labels[i] == 'Dress' or labels[i] == 'Pants' or labels[i] == 'Skirt':
mask_cloth += masks[i]
elif labels[i] == 'Left-arm' or labels[i] == 'Right-arm':
mask_hands += masks[i]
mask_face = np.clip(mask_face * mask_human, 0, 1)
mask_hair = np.clip(mask_hair * mask_human, 0, 1)
mask_neck = np.clip(mask_neck * mask_human, 0, 1)
mask_cloth = np.clip(mask_cloth * mask_human, 0, 1)
mask_human = np.clip(mask_human, 0, 1)
mask_hands = np.clip(mask_hands * mask_human, 0, 1)
if return_cloth:
if ksize > 0:
kernel = np.ones((ksize, ksize))
soft_mask = cv2.erode(mask_cloth, kernel, iterations=1)
return soft_mask
else:
return mask_cloth
if return_hand:
return mask_hands
if return_human:
mask_head = np.clip(mask_face + mask_hair + mask_neck, 0, 1)
kernel = np.ones((ksize, ksize))
dilated_head = cv2.dilate(mask_head, kernel, iterations=1)
mask_human = np.clip(mask_human - dilated_head + mask_cloth, 0, 1)
return mask_human
if np.sum(mask_face) > 0:
soft_mask = np.clip(mask_face, 0, 1)
if ksize > 0:
# kernel_size = int(np.sqrt(np.sum(soft_mask)) * ksize)
kernel = np.ones((ksize, ksize))
soft_mask = cv2.dilate(soft_mask, kernel, iterations=1)
else:
soft_mask = mask_face
return soft_mask
def crop_bottom(pil_file, width):
if width == 512:
height = 768
else:
height = 1152
w, h = pil_file.size
factor = w / width
new_h = int(h / factor)
pil_file = pil_file.resize((width, new_h))
crop_h = min(int(new_h / 32) * 32, height)
array_file = np.array(pil_file)
array_file = array_file[:crop_h, :, :]
output_file = Image.fromarray(array_file)
return output_file
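# Run the multi-ControlNet inpainting pipeline `num` times on the masked region; with use_ori, unmasked pixels are copied back from the input image.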
def img2img_multicontrol(img, control_image, controlnet_conditioning_scale, pipe, mask, pos_prompt, neg_prompt,
strength, num=1, use_ori=False):
image_mask = Image.fromarray(np.uint8(mask * 255))
image_human = []
for i in range(num):
image_human.append(pipe(image=img, mask_image=image_mask, control_image=control_image, prompt=pos_prompt,
negative_prompt=neg_prompt, guidance_scale=7, strength=strength, num_inference_steps=40,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_images_per_prompt=1).images[0])
if use_ori:
image_human[i] = Image.fromarray((np.array(image_human[i]) * mask[:,:,None] + np.array(img) * (1 - mask[:,:,None])).astype(np.uint8))
return image_human
def main_diffusion_inference_tryon(inpaint_image, strength, output_img_size, pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
multiplier_style=0.05,
multiplier_human=1.0):
if style_model_path is None:
model_dir = snapshot_download('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', revision='v1.0.0')
style_model_path = os.path.join(model_dir, 'zjz_mj_jiyi_small_addtxt_fromleo.safetensors')
segmentation_pipeline = pipeline(Tasks.image_segmentation, 'damo/cv_resnet101_image-multiple-human-parsing')
det_pipeline = pipeline(Tasks.face_detection, 'damo/cv_ddsar_face-detection_iclr23-damofd')
model_dir = snapshot_download('damo/face_chain_control_model', revision='v1.0.1')
model_dir0 = snapshot_download('damo/face_chain_control_model', revision='v1.0.2')
model_dir1 = snapshot_download('ly261666/cv_wanx_style_model',revision='v1.0.3')
if output_img_size == 512:
dtype = torch.float32
else:
dtype = torch.float16
train_dir = str(input_img_dir) + '_labeled'
add_prompt_style = []
f = open(os.path.join(train_dir, 'metadata.jsonl'), 'r')
tags_all = []
cnt = 0
cnts_trigger = np.zeros(6)
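    # Count the gender/age trigger phrases in the generated captions; the most frequent one becomes the trigger prompt.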
for line in f:
cnt += 1
data = json.loads(line)['text'].split(', ')
tags_all.extend(data)
if data[1] == 'a boy':
cnts_trigger[0] += 1
elif data[1] == 'a girl':
cnts_trigger[1] += 1
elif data[1] == 'a handsome man':
cnts_trigger[2] += 1
elif data[1] == 'a beautiful woman':
cnts_trigger[3] += 1
elif data[1] == 'a mature man':
cnts_trigger[4] += 1
elif data[1] == 'a mature woman':
cnts_trigger[5] += 1
else:
            print('Unrecognized caption label:', data[1])
f.close()
attr_idx = np.argmax(cnts_trigger)
trigger_styles = ['a boy, children, ', 'a girl, children, ', 'a handsome man, ', 'a beautiful woman, ',
'a mature man, ', 'a mature woman, ']
trigger_style = '(<fcsks>:10), ' + trigger_styles[attr_idx]
if attr_idx == 2 or attr_idx == 4:
neg_prompt += ', children'
neg_prompt += ', blurry, blurry background'
for tag in tags_all:
if tags_all.count(tag) > 0.5 * cnt:
if ('glasses' in tag or 'smile' in tag or 'hair' in tag):
if not tag in add_prompt_style:
add_prompt_style.append(tag)
if len(add_prompt_style) > 0:
add_prompt_style = ", ".join(add_prompt_style) + ', '
else:
add_prompt_style = ''
print(add_prompt_style)
if isinstance(inpaint_image, str):
inpaint_im = Image.open(inpaint_image)
else:
inpaint_im = inpaint_image
inpaint_im = crop_bottom(inpaint_im, output_img_size)
w, h = inpaint_im.size
dwprocessor = DWposeDetector(os.path.join(model_dir0, 'dwpose_models'))
openpose_image, handbox = dwprocessor(np.array(inpaint_im, np.uint8), include_body=True, include_hand=True, include_face=False, return_handbox=True)
openpose_image = Image.fromarray(openpose_image)
openpose_image.save('openpose.png')
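    # Three ControlNets guide the inpainting: OpenPose for body/hand pose, a depth map restricted to the hand regions, and Canny edges of the retained human area.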
controlnet = [
ControlNetModel.from_pretrained(os.path.join(model_dir, "model_controlnet/control_v11p_sd15_openpose"), torch_dtype=dtype),
ControlNetModel.from_pretrained(os.path.join(model_dir, 'model_controlnet/control_v11p_sd15_depth'),
torch_dtype=dtype),
ControlNetModel.from_pretrained(os.path.join(model_dir1, "contronet-canny"), torch_dtype=dtype)
]
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(base_model_path, controlnet=controlnet,
torch_dtype=dtype, safety_checker=None)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = merge_lora(pipe, style_model_path, multiplier_style, from_safetensor=True)
pipe = merge_lora(pipe, lora_model_path, multiplier_human, from_safetensor=False)
pipe = pipe.to("cuda")
images_human = []
mask = segment(segmentation_pipeline, inpaint_im, return_hand=True)
mask1 = segment(segmentation_pipeline, inpaint_im, ksize=5, return_human=True)
canny_image = cv2.Canny(np.array(inpaint_im, np.uint8), 80, 200)[:, :, None]
canny_image = (canny_image * mask1[:, :, None]).astype(np.uint8)
canny_image = Image.fromarray(np.concatenate([canny_image, canny_image, canny_image], axis=2))
canny_image.save('canny.png')
depth_estimator = tpipeline('depth-estimation', os.path.join(model_dir, 'model_controlnet/dpt-large'))
depth_im = np.zeros((h, w))
for hbox in handbox:
depth_input = Image.fromarray(np.array(inpaint_im)[hbox[1]:hbox[3], hbox[0]:hbox[2]])
depth_rst = depth_estimator(depth_input)['depth']
depth_rst = np.array(depth_rst)
depth_im[hbox[1]:hbox[3], hbox[0]:hbox[2]] = depth_rst
depth_im = depth_im[:, :, None]
depth_im = np.concatenate([depth_im, depth_im, depth_im], axis=2)
depth_im = (depth_im * mask[:, :, None]).astype(np.uint8)
depth_im = Image.fromarray(depth_im)
depth_im.save('depth.png')
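    # Inpaint only outside the garment: the cloth mask protects the original clothes, which are softly pasted back after generation.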
for i in range(1):
read_control = [openpose_image, depth_im, canny_image]
cloth_mask_warp = segment(segmentation_pipeline, inpaint_im, return_cloth=True, ksize=5)
cloth_mask = segment(segmentation_pipeline, inpaint_im, return_cloth=True, ksize=15)
inpaint_with_mask = (cloth_mask_warp[:,:,None] * np.array(inpaint_im))[:,:,::-1]
inpaint_mask = 1.0 - cloth_mask
cv2.imwrite('inpaint_with_mask_{}.png'.format(i), inpaint_with_mask)
        print('Finish segmenting images.')
images_human.extend(img2img_multicontrol(inpaint_im, read_control, [1.0, 0.2, 0.4], pipe, inpaint_mask,
trigger_style + add_prompt_style + pos_prompt, neg_prompt,
strength=strength))
for i in range(1):
soft_cloth_mask_warp = cv2.GaussianBlur(cloth_mask_warp, (5, 5), 0, 0)
image_human = (np.array(images_human[i]) * (1.0 - soft_cloth_mask_warp[:,:,None]) + np.array(inpaint_im) * soft_cloth_mask_warp[:,:,None]).astype(np.uint8)
images_human[i] = Image.fromarray(image_human)
images_human[i].save('inference_{}.png'.format(i))
return images_human
def stylization_fn(use_stylization, rank_results):
if use_stylization:
## TODO
pass
else:
return rank_results
def main_model_inference(inpaint_image, strength, output_img_size,
pos_prompt, neg_prompt, style_model_path, multiplier_style, multiplier_human, use_main_model,
input_img_dir=None, base_model_path=None, lora_model_path=None):
if use_main_model:
multiplier_style_kwargs = {'multiplier_style': multiplier_style} if multiplier_style is not None else {}
multiplier_human_kwargs = {'multiplier_human': multiplier_human} if multiplier_human is not None else {}
return main_diffusion_inference_tryon(inpaint_image, strength, output_img_size, pos_prompt, neg_prompt,
input_img_dir, base_model_path, style_model_path, lora_model_path,
**multiplier_style_kwargs, **multiplier_human_kwargs)
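# Rank the labeled input photos with a face-quality model and return the highest-scoring one as the face-swap template.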
def select_high_quality_face(input_img_dir):
input_img_dir = str(input_img_dir) + '_labeled'
quality_score_list = []
abs_img_path_list = []
## TODO
face_quality_func = pipeline(Tasks.face_quality_assessment, 'damo/cv_manual_face-quality-assessment_fqa',
model_revision='v2.0')
for img_name in os.listdir(input_img_dir):
if img_name.endswith('jsonl') or img_name.startswith('.ipynb') or img_name.startswith('.safetensors'):
continue
if img_name.endswith('jpg') or img_name.endswith('png'):
abs_img_name = os.path.join(input_img_dir, img_name)
face_quality_score = face_quality_func(abs_img_name)[OutputKeys.SCORES]
if face_quality_score is None:
quality_score_list.append(0)
else:
quality_score_list.append(face_quality_score[0])
abs_img_path_list.append(abs_img_name)
sort_idx = np.argsort(quality_score_list)[::-1]
print('Selected face: ' + abs_img_path_list[sort_idx[0]])
return Image.open(abs_img_path_list[sort_idx[0]])
def face_swap_fn(use_face_swap, gen_results, template_face):
if use_face_swap:
## TODO
out_img_list = []
image_face_fusion = pipeline('face_fusion_torch',
model='damo/cv_unet_face_fusion_torch', model_revision='v1.0.5')
segmentation_pipeline = pipeline(Tasks.image_segmentation, 'damo/cv_resnet101_image-multiple-human-parsing')
for img in gen_results:
result = image_face_fusion(dict(template=img, user=template_face))[OutputKeys.OUTPUT_IMG]
# face_mask = segment(segmentation_pipeline, img, ksize=10)
# result = (result * face_mask[:,:,None] + np.array(img)[:,:,::-1] * (1 - face_mask[:,:,None])).astype(np.uint8)
out_img_list.append(result)
return out_img_list
else:
ret_results = []
for img in gen_results:
ret_results.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
return ret_results
def post_process_fn(use_post_process, swap_results_ori, selected_face, num_gen_images):
if use_post_process:
sim_list = []
## TODO
face_recognition_func = pipeline(Tasks.face_recognition, 'damo/cv_ir_face-recognition-ood_rts',
model_revision='v2.5')
face_det_func = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd',
model_revision='v1.1')
swap_results = swap_results_ori
select_face_emb = face_recognition_func(selected_face)[OutputKeys.IMG_EMBEDDING][0]
for img in swap_results:
emb = face_recognition_func(img)[OutputKeys.IMG_EMBEDDING]
if emb is None or select_face_emb is None:
sim_list.append(0)
else:
sim = np.dot(emb, select_face_emb)
sim_list.append(sim.item())
sort_idx = np.argsort(sim_list)[::-1]
return np.array(swap_results)[sort_idx[:min(int(num_gen_images), len(swap_results))]]
else:
return np.array(swap_results_ori)
class GenPortrait_tryon:
def __init__(self, inpaint_img, strength,
pos_prompt, neg_prompt, style_model_path, multiplier_style, multiplier_human,
use_main_model=True, use_face_swap=True,
use_post_process=True, use_stylization=True):
self.use_main_model = use_main_model
self.use_face_swap = use_face_swap
self.use_post_process = use_post_process
self.use_stylization = use_stylization
self.multiplier_style = multiplier_style
self.multiplier_human = multiplier_human
self.style_model_path = style_model_path
self.pos_prompt = pos_prompt
self.neg_prompt = neg_prompt
self.inpaint_img = inpaint_img
self.strength = strength
def __call__(self, input_img_dir=None, base_model_path=None,
lora_model_path=None, sub_path=None, revision=None):
base_model_path = snapshot_download(base_model_path, revision=revision)
if sub_path is not None and len(sub_path) > 0:
base_model_path = os.path.join(base_model_path, sub_path)
gen_results = main_model_inference(self.inpaint_img, self.strength, 768,
self.pos_prompt, self.neg_prompt,
self.style_model_path, self.multiplier_style, self.multiplier_human,
self.use_main_model, input_img_dir=input_img_dir,
lora_model_path=lora_model_path, base_model_path=base_model_path)
# select_high_quality_face PIL
selected_face = select_high_quality_face(input_img_dir)
# face_swap cv2
swap_results = face_swap_fn(self.use_face_swap, gen_results, selected_face)
# pose_process
final_gen_results_final = swap_results
outputs = final_gen_results_final
outputs_RGB = []
for out_tmp in outputs:
outputs_RGB.append(cv2.cvtColor(out_tmp, cv2.COLOR_BGR2RGB))
image_path = './lora_result.png'
if len(outputs) > 0:
result = concatenate_images(outputs)
cv2.imwrite(image_path, result)
return final_gen_results_final
def compress_image(input_path, target_size):
    output_path = change_extension_to_jpg(input_path)
    image = cv2.imread(input_path)
    quality = 95
    # Lower the JPEG quality step by step until the encoded size fits within target_size.
    while quality > 5 and cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, quality])[1].size > target_size:
        quality -= 5
    compressed_image = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
    with open(output_path, 'wb') as f:
        f.write(compressed_image)
    return output_path
def change_extension_to_jpg(image_path):
base_name = os.path.basename(image_path)
new_base_name = os.path.splitext(base_name)[0] + ".jpg"
directory = os.path.dirname(image_path)
new_image_path = os.path.join(directory, new_base_name)
return new_image_path
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import os
import re
from collections import defaultdict
from safetensors.torch import load_file
from modelscope.utils.import_utils import is_swift_available
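# Fuse LoRA weights (safetensors file, diffusers checkpoint, or swift tuner) into the pipeline's UNet / text encoder, scaled by `multiplier`.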
def merge_lora(pipeline, lora_path, multiplier, from_safetensor=False, device='cpu', dtype=torch.float32):
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
    print('----------')
    print('Lora Path: ', lora_path)
if from_safetensor:
state_dict = load_file(lora_path, device=device)
elif os.path.exists(os.path.join(lora_path, 'swift')):
if not is_swift_available():
raise ValueError(
'Please install swift by `pip install ms-swift` to use efficient_tuners.'
)
from swift import Swift
pipeline.unet = Swift.from_pretrained(pipeline.unet, os.path.join(lora_path, 'swift'))
return pipeline
else:
if os.path.exists(os.path.join(lora_path, 'pytorch_lora_weights.bin')):
checkpoint = torch.load(os.path.join(lora_path, 'pytorch_lora_weights.bin'), map_location=torch.device(device))
elif os.path.exists(os.path.join(lora_path, 'pytorch_lora_weights.safetensors')):
            checkpoint = load_file(os.path.join(lora_path, 'pytorch_lora_weights.safetensors'), device=device)
new_dict = dict()
for idx, key in enumerate(checkpoint):
new_key = re.sub(r'\.processor\.', '_', key)
new_key = re.sub(r'mid_block\.', 'mid_block_', new_key)
new_key = re.sub('_lora.up.', '.lora_up.', new_key)
new_key = re.sub('_lora.down.', '.lora_down.', new_key)
new_key = re.sub(r'\.(\d+)\.', '_\\1_', new_key)
new_key = re.sub('to_out', 'to_out_0', new_key)
new_key = 'lora_unet_' + new_key
new_dict[new_key] = checkpoint[key]
state_dict = new_dict
updates = defaultdict(dict)
for key, value in state_dict.items():
layer, elem = key.split('.', 1)
updates[layer][elem] = value
for layer, elems in updates.items():
if "text" in layer:
layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
curr_layer = pipeline.text_encoder
else:
layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
curr_layer = pipeline.unet
temp_name = layer_infos.pop(0)
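        # Walk down the module tree along the underscore-separated layer name, re-joining name pieces whenever an attribute lookup fails.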
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(layer_infos) == 0:
print('Error loading layer')
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
weight_up = elems['lora_up.weight'].to(dtype)
weight_down = elems['lora_down.weight'].to(dtype)
if 'alpha' in elems.keys():
alpha = elems['alpha'].item() / weight_up.shape[1]
else:
alpha = 1.0
curr_layer.weight.data = curr_layer.weight.data.to(device)
if len(weight_up.shape) == 4:
curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
weight_down.squeeze(3).squeeze(2)).unsqueeze(
2).unsqueeze(3)
else:
curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
return pipeline
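# Usage sketch (path and multiplier are placeholders, not part of this module):
#   pipe = merge_lora(pipe, '/path/to/style_lora.safetensors', multiplier=0.25, from_safetensor=True)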
# coding=utf-8
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
import argparse
import base64
import itertools
import json
import logging
import math
import os
import random
import shutil
from glob import glob
from pathlib import Path
import cv2
import datasets
import diffusers
import numpy as np
import onnxruntime
import PIL.Image
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor
from typing import List, Optional, Tuple, Union
import torchvision.transforms.functional as Ft
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from diffusers import (AutoencoderKL, DDPMScheduler, DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionInpaintPipeline, UNet2DConditionModel)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import create_repo, upload_folder
import sys
parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if parent_path not in sys.path:
sys.path.append(parent_path)
from facechain.utils import snapshot_download
from modelscope.utils.import_utils import is_swift_available
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from torch import multiprocessing
from transformers import CLIPTextModel, CLIPTokenizer
from facechain.inference import data_process_fn
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.14.0.dev0")
logger = get_logger(__name__, log_level="INFO")
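# Training-time augmentation: randomly crop a square around the expected face region of an (already square) input image.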
class FaceCrop(torch.nn.Module):
@staticmethod
def get_params(img: Tensor) -> Tuple[int, int, int, int]:
_, h, w = Ft.get_dimensions(img)
if h != w:
raise ValueError(f"The input image is not square.")
ratio = torch.rand(size=(1,)).item() * 0.1 + 0.35
yc = torch.rand(size=(1,)).item() * 0.15 + 0.35
th = int(h / 1.15 * 0.35 / ratio)
tw = th
cx = int(0.5 * w)
cy = int(0.5 / 1.15 * h)
i = min(max(int(cy - yc * th), 0), h - th)
j = int(cx - 0.5 * tw)
return i, j, th, tw
def __init__(self):
super().__init__()
def forward(self, img):
i, j, h, w = self.get_params(img)
return Ft.crop(img, i, j, h, w)
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
def save_model_card(repo_id: str, images=None, base_model: str = "", dataset_name: str = "", repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
"""
model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def softmax(x):
x -= np.max(x, axis=0, keepdims=True)
x = np.exp(x) / np.sum(np.exp(x), axis=0, keepdims=True)
return x
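# Estimate the photo's rotation (0/90/180/270 degrees) with a small ONNX classifier and rotate it upright before training.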
def get_rot(image):
model_dir = snapshot_download('Cherrytest/rot_bgr',
revision='v1.0.0')
model_path = os.path.join(model_dir, 'rot_bgr.onnx')
providers = ['CPUExecutionProvider']
if torch.cuda.is_available():
providers.insert(0, 'CUDAExecutionProvider')
# providers.insert(0, 'ROCMExecutionProvider')
ort_session = onnxruntime.InferenceSession(model_path, providers=providers)
img_cv = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
img_clone = img_cv.copy()
img_np = cv2.resize(img_cv, (224, 224))
img_np = img_np.astype(np.float32)
mean = np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape((1, 1, 3))
norm = np.array([0.01742919, 0.017507, 0.01712475], dtype=np.float32).reshape((1, 1, 3))
img_np = (img_np - mean) * norm
img_tensor = torch.from_numpy(img_np)
img_tensor = img_tensor.unsqueeze(0)
img_nchw = img_tensor.permute(0, 3, 1, 2)
ort_inputs = {ort_session.get_inputs()[0].name: img_nchw.numpy()}
outputs = ort_session.run(None, ort_inputs)
logits = outputs[0].reshape((-1,))
probs = softmax(logits)
rot_idx = np.argmax(probs)
if rot_idx == 1:
print('rot 90')
img_clone = cv2.transpose(img_clone)
img_clone = np.flip(img_clone, 1)
return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
elif rot_idx == 2:
print('rot 180')
img_clone = cv2.flip(img_clone, -1)
return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
elif rot_idx == 3:
print('rot 270')
img_clone = cv2.transpose(img_clone)
img_clone = np.flip(img_clone, 0)
return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
else:
return image
def prepare_dataset(instance_images: list, output_dataset_dir):
if not os.path.exists(output_dataset_dir):
os.makedirs(output_dataset_dir)
for i, temp_path in enumerate(instance_images):
image = PIL.Image.open(temp_path)
# image = PIL.Image.open(temp_path.name)
'''
w, h = image.size
max_size = max(w, h)
ratio = 1024 / max_size
new_w = round(w * ratio)
new_h = round(h * ratio)
'''
image = image.convert('RGB')
image = get_rot(image)
# image = image.resize((new_w, new_h))
# image = image.resize((new_w, new_h), PIL.Image.ANTIALIAS)
out_path = f'{output_dataset_dir}/{i:03d}.jpg'
image.save(out_path, format='JPEG', quality=100)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier.",
)
parser.add_argument(
"--sub_path",
type=str,
default=None,
required=False,
help="The sub model path of the `pretrained_model_name_or_path`",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The data images dir"
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--output_dataset_name",
type=str,
default=None,
help=(
"The dataset dir after processing"
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
)
parser.add_argument(
"--num_validation_images",
type=int,
default=1,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_epochs",
type=int,
default=1,
help=(
"Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-model-finetuned-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
# lora args
parser.add_argument("--use_peft", action="store_true", help="Whether to use peft to support lora")
parser.add_argument("--use_swift", action="store_true", help="Whether to use swift to support lora")
parser.add_argument("--lora_r", type=int, default=4, help="Lora rank, only used if use_lora is True")
parser.add_argument("--lora_alpha", type=int, default=32, help="Lora alpha, only used if lora is True")
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Lora dropout, only used if use_lora is True")
parser.add_argument(
"--lora_bias",
type=str,
default="none",
help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
)
parser.add_argument(
"--lora_text_encoder_r",
type=int,
default=4,
help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
)
parser.add_argument(
"--lora_text_encoder_alpha",
type=int,
default=32,
help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
)
parser.add_argument(
"--lora_text_encoder_dropout",
type=float,
default=0.0,
help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
)
parser.add_argument(
"--lora_text_encoder_bias",
type=str,
default="none",
help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=(
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None and args.output_dataset_name is None:
raise ValueError("Need either a dataset name or a training folder.")
return args
DATASET_NAME_MAPPING = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}
def main():
args = parse_args()
logging_dir = os.path.join(args.output_dir, args.logging_dir)
shutil.rmtree(args.output_dir, ignore_errors=True)
os.makedirs(args.output_dir)
if args.dataset_name is not None:
# if dataset_name is None, then it's called from the gradio
# the data processing will be executed in the app.py to save the gpu memory.
print('All input images:', args.dataset_name)
args.dataset_name = [os.path.join(args.dataset_name, x) for x in os.listdir(args.dataset_name)]
shutil.rmtree(args.output_dataset_name, ignore_errors=True)
prepare_dataset(args.dataset_name, args.output_dataset_name)
## Our data process fn
data_process_fn(input_img_dir=args.output_dataset_name, use_data_process=True)
args.dataset_name = args.output_dataset_name + '_labeled'
accelerator_project_config = ProjectConfiguration(
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
## Download foundation Model
model_dir = snapshot_download(args.pretrained_model_name_or_path,
revision=args.revision,
user_agent={'invoked_by': 'trainer', 'third_party': 'facechain'})
if args.sub_path is not None and len(args.sub_path) > 0:
model_dir = os.path.join(model_dir, args.sub_path)
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(model_dir, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
model_dir, subfolder="tokenizer"
)
text_encoder = CLIPTextModel.from_pretrained(
model_dir, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(model_dir, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(
model_dir, subfolder="unet"
)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if args.use_peft:
from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict
UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=UNET_TARGET_MODULES,
lora_dropout=args.lora_dropout,
bias=args.lora_bias,
)
unet = LoraModel(config, unet)
vae.requires_grad_(False)
if args.train_text_encoder:
config = LoraConfig(
r=args.lora_text_encoder_r,
lora_alpha=args.lora_text_encoder_alpha,
target_modules=TEXT_ENCODER_TARGET_MODULES,
lora_dropout=args.lora_text_encoder_dropout,
bias=args.lora_text_encoder_bias,
)
text_encoder = LoraModel(config, text_encoder)
elif args.use_swift:
if not is_swift_available():
raise ValueError(
'Please install swift by `pip install ms-swift` to use efficient_tuners.'
)
from swift import LoRAConfig, Swift
UNET_TARGET_MODULES = ['to_q', 'to_k', 'to_v', 'query', 'key', 'value', 'to_out.0']
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
lora_config = LoRAConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias=args.lora_bias,
target_modules=UNET_TARGET_MODULES)
unet = Swift.prepare_model(unet, lora_config)
if args.train_text_encoder:
lora_config = LoRAConfig(
r=args.lora_text_encoder_r,
lora_alpha=args.lora_text_encoder_alpha,
target_modules=TEXT_ENCODER_TARGET_MODULES,
lora_dropout=args.lora_text_encoder_dropout,
bias=args.lora_text_encoder_bias,
)
            text_encoder = Swift.prepare_model(text_encoder, lora_config)
else:
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
# => 32 layers
# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.lora_r)
unet.set_attn_processor(lora_attn_procs)
# Move unet, vae and text_encoder to device and cast to weight_dtype
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
                logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
if args.use_peft or args.use_swift:
# Optimizer creation
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
optimizer = optimizer_cls(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
else:
lora_layers = AttnProcsLayers(unet.attn_processors)
optimizer = optimizer_cls(
lora_layers.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
dataset = load_dataset("imagefolder", data_dir=args.dataset_name)
# if args.dataset_name is not None:
# # Downloading and loading a dataset from the hub.
# dataset = load_dataset(
# args.dataset_name,
# args.dataset_config_name,
# cache_dir=args.cache_dir,
# num_proc=8,
# )
# else:
# # This branch will not be called
# data_files = {}
# if args.train_data_dir is not None:
# data_files["train"] = os.path.join(args.train_data_dir, "**")
# dataset = load_dataset(
# "imagefolder",
# data_files=data_files,
# cache_dir=args.cache_dir,
# )
# # See more about loading custom images at
# # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
return inputs.input_ids
# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
#transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
#transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
FaceCrop(),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [train_transforms(image) for image in images]
examples["input_ids"] = tokenize_captions(examples)
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = torch.stack([example["input_ids"] for example in examples])
return {"pixel_values": pixel_values, "input_ids": input_ids}
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
# Prepare everything with our `accelerator`.
if args.use_peft or args.use_swift:
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
else:
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
)
unet = unet.cuda()
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint == 'fromfacecommon':
weight_model_dir = snapshot_download('damo/face_frombase_c4',
revision='v1.0.0',
user_agent={'invoked_by': 'trainer', 'third_party': 'facechain'})
path = os.path.join(weight_model_dir, 'face_frombase_c4.bin')
elif args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
if args.resume_from_checkpoint == 'fromfacecommon':
accelerator.print(f"Resuming from checkpoint {path}")
unet_state_dict = torch.load(path, map_location='cpu')
accelerator._models[-1].load_state_dict(unet_state_dict)
global_step = 0
resume_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
if args.use_peft or args.use_swift:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
else:
params_to_clip = lora_layers.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = DiffusionPipeline.from_pretrained(
model_dir,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)
if accelerator.is_main_process:
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
del pipeline
torch.cuda.empty_cache()
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if args.use_peft:
lora_config = {}
unwarpped_unet = accelerator.unwrap_model(unet)
state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))
lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True)
if args.train_text_encoder:
unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(
unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
)
text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
state_dict.update(text_encoder_state_dict)
lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(
inference=True
)
accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt"))
with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f:
json.dump(lora_config, f)
elif args.use_swift:
unwarpped_unet = accelerator.unwrap_model(unet)
unwarpped_unet.save_pretrained(os.path.join(args.output_dir, 'swift'))
if args.train_text_encoder:
unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)
unwarpped_text_encoder.save_pretrained(os.path.join(args.output_dir, 'text_encoder'))
else:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir, safe_serialization=False)
if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=model_dir,
dataset_name=args.dataset_name,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
model_dir, torch_dtype=weight_dtype
)
if args.use_peft:
def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):
with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f:
lora_config = json.load(f)
print(lora_config)
checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt")
lora_checkpoint_sd = torch.load(checkpoint)
unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
text_encoder_lora_ds = {
k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
}
unet_config = LoraConfig(**lora_config["peft_config"])
# TODO: To be fixed !
pipe.unet = LoraModel(unet_config, pipe.unet)
set_peft_model_state_dict(pipe.unet, unet_lora_ds)
if "text_encoder_peft_config" in lora_config:
text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)
if dtype in (torch.float16, torch.bfloat16):
pipe.unet.half()
pipe.text_encoder.half()
pipe.to(device)
return pipe
pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)
elif args.use_swift:
if not is_swift_available():
raise ValueError(
'Please install swift by `pip install ms-swift` to use efficient_tuners.'
)
from swift import Swift
pipeline = pipeline.to(accelerator.device)
pipeline.unet = Swift.from_pretrained(pipeline.unet, os.path.join(args.output_dir, 'swift'))
if args.train_text_encoder:
pipeline.text_encoder = Swift.from_pretrained(pipeline.text_encoder, os.path.join(args.output_dir, 'text_encoder'))
else:
pipeline = pipeline.to(accelerator.device)
# load attention processors
pipeline.unet.load_attn_procs(args.output_dir)
# run inference
if args.seed is not None:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
else:
generator = None
images = []
accelerator.end_training()
if __name__ == "__main__":
multiprocessing.set_start_method('spawn', force=True)
main()
# coding=utf-8
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
import argparse
import base64
import itertools
import json
import logging
import math
import os
import random
import shutil
from glob import glob
from pathlib import Path
import cv2
import datasets
import diffusers
import numpy as np
import onnxruntime
import PIL.Image
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor
from typing import List, Optional, Tuple, Union
import torchvision.transforms.functional as Ft
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from diffusers import (AutoencoderKL, DDPMScheduler, DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionInpaintPipeline, UNet2DConditionModel, StableDiffusionXLPipeline)
from torchvision.transforms.functional import crop
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import create_repo, upload_folder
import sys
parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if parent_path not in sys.path:
sys.path.append(parent_path)
from facechain.utils import snapshot_download
from modelscope.utils.import_utils import is_swift_available
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from torch import multiprocessing
from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, CLIPTextModelWithProjection
from facechain.inference import data_process_fn
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.14.0.dev0")
logger = get_logger(__name__, log_level="INFO")
class FaceCrop(torch.nn.Module):
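"""Randomly crop a square region around the upper-central (face) area of a square input image; the crop size and vertical offset are jittered as a form of data augmentation."""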
@staticmethod
def get_params(img: Tensor) -> Tuple[int, int, int, int]:
_, h, w = Ft.get_dimensions(img)
if h != w:
raise ValueError(f"The input image is not square.")
ratio = torch.rand(size=(1,)).item() * 0.1 + 0.35
yc = torch.rand(size=(1,)).item() * 0.15 + 0.35
th = int(h / 1.15 * 0.35 / ratio)
tw = th
cx = int(0.5 * w)
cy = int(0.5 / 1.15 * h)
i = min(max(int(cy - yc * th), 0), h - th)
j = int(cx - 0.5 * tw)
return i, j, th, tw
def __init__(self):
super().__init__()
def forward(self, img):
i, j, h, w = self.get_params(img)
return Ft.crop(img, i, j, h, w)
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
def save_model_card(repo_id: str, images=None, base_model: str = "", dataset_name: str = "", repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
"""
model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def softmax(x):
x -= np.max(x, axis=0, keepdims=True)
x = np.exp(x) / np.sum(np.exp(x), axis=0, keepdims=True)
return x
def get_rot(image):
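"""Detect whether the image is rotated by 90/180/270 degrees with the 'Cherrytest/rot_bgr' ONNX classifier and return the image rotated back to the upright orientation."""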
model_dir = snapshot_download('Cherrytest/rot_bgr',
revision='v1.0.0')
model_path = os.path.join(model_dir, 'rot_bgr.onnx')
providers = ['CPUExecutionProvider']
if torch.cuda.is_available():
providers.insert(0, 'CUDAExecutionProvider')
# providers.insert(0, 'ROCMExecutionProvider')
ort_session = onnxruntime.InferenceSession(model_path, providers=providers)
img_cv = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
img_clone = img_cv.copy()
img_np = cv2.resize(img_cv, (224, 224))
img_np = img_np.astype(np.float32)
mean = np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape((1, 1, 3))
norm = np.array([0.01742919, 0.017507, 0.01712475], dtype=np.float32).reshape((1, 1, 3))
img_np = (img_np - mean) * norm
img_tensor = torch.from_numpy(img_np)
img_tensor = img_tensor.unsqueeze(0)
img_nchw = img_tensor.permute(0, 3, 1, 2)
ort_inputs = {ort_session.get_inputs()[0].name: img_nchw.numpy()}
outputs = ort_session.run(None, ort_inputs)
logits = outputs[0].reshape((-1,))
probs = softmax(logits)
rot_idx = np.argmax(probs)
if rot_idx == 1:
print('rot 90')
img_clone = cv2.transpose(img_clone)
img_clone = np.flip(img_clone, 1)
return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
elif rot_idx == 2:
print('rot 180')
img_clone = cv2.flip(img_clone, -1)
return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
elif rot_idx == 3:
print('rot 270')
img_clone = cv2.transpose(img_clone)
img_clone = np.flip(img_clone, 0)
return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
else:
return image
def prepare_dataset(instance_images: list, output_dataset_dir):
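"""Convert the input images to RGB, correct their orientation with get_rot, and save them as JPEG files named 000.jpg, 001.jpg, ... under output_dataset_dir."""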
if not os.path.exists(output_dataset_dir):
os.makedirs(output_dataset_dir)
for i, temp_path in enumerate(instance_images):
image = PIL.Image.open(temp_path)
# image = PIL.Image.open(temp_path.name)
'''
w, h = image.size
max_size = max(w, h)
ratio = 1024 / max_size
new_w = round(w * ratio)
new_h = round(h * ratio)
'''
image = image.convert('RGB')
image = get_rot(image)
# image = image.resize((new_w, new_h))
# image = image.resize((new_w, new_h), PIL.Image.ANTIALIAS)
out_path = f'{output_dataset_dir}/{i:03d}.jpg'
image.save(out_path, format='JPEG', quality=100)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier.",
)
parser.add_argument(
"--sub_path",
type=str,
default=None,
required=False,
help="The sub model path of the `pretrained_model_name_or_path`",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The data images dir"
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--output_dataset_name",
type=str,
default=None,
help=(
"The dataset dir after processing"
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
)
parser.add_argument(
"--num_validation_images",
type=int,
default=1,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_epochs",
type=int,
default=1,
help=(
"Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-model-finetuned-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
# lora args
parser.add_argument("--use_peft", action="store_true", help="Whether to use peft to support lora")
parser.add_argument("--use_swift", action="store_true", help="Whether to use swift to support lora")
parser.add_argument("--lora_r", type=int, default=4, help="Lora rank, only used if use_lora is True")
parser.add_argument("--lora_alpha", type=int, default=32, help="Lora alpha, only used if lora is True")
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Lora dropout, only used if use_lora is True")
parser.add_argument(
"--lora_bias",
type=str,
default="none",
help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
)
parser.add_argument(
"--lora_text_encoder_r",
type=int,
default=4,
help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
)
parser.add_argument(
"--lora_text_encoder_alpha",
type=int,
default=32,
help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
)
parser.add_argument(
"--lora_text_encoder_dropout",
type=float,
default=0.0,
help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
)
parser.add_argument(
"--lora_text_encoder_bias",
type=str,
default="none",
help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=(
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None and args.output_dataset_name is None:
raise ValueError("Need either a dataset name or a training folder.")
return args
DATASET_NAME_MAPPING = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}
def main():
args = parse_args()
logging_dir = os.path.join(args.output_dir, args.logging_dir)
shutil.rmtree(args.output_dir, ignore_errors=True)
os.makedirs(args.output_dir)
if args.dataset_name is not None:
# If dataset_name is None, the script was launched from the Gradio app;
# in that case the data processing is executed in app.py to save GPU memory.
print('All input images:', args.dataset_name)
args.dataset_name = [os.path.join(args.dataset_name, x) for x in os.listdir(args.dataset_name)]
shutil.rmtree(args.output_dataset_name, ignore_errors=True)
prepare_dataset(args.dataset_name, args.output_dataset_name)
## Our data process fn
data_process_fn(input_img_dir=args.output_dataset_name, use_data_process=True)
args.dataset_name = args.output_dataset_name + '_labeled'
accelerator_project_config = ProjectConfiguration(
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
## Download foundation Model
model_dir = snapshot_download(args.pretrained_model_name_or_path,
revision=args.revision,
user_agent={'invoked_by': 'trainer', 'third_party': 'facechain'})
# if args.sub_path is not None and len(args.sub_path) > 0:
# model_dir = os.path.join(model_dir, args.sub_path)
# Load scheduler, tokenizer and models.
print(f'>>> Loading model from model_dir: {model_dir}')
noise_scheduler = DDPMScheduler.from_pretrained(model_dir, subfolder="scheduler")
tokenizer_one = AutoTokenizer.from_pretrained(
model_dir, subfolder="tokenizer", use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
model_dir,
subfolder='tokenizer_2',
use_fast=False)
text_encoder_one = CLIPTextModel.from_pretrained(
model_dir, subfolder="text_encoder"
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
model_dir, subfolder='text_encoder_2'
)
vae = AutoencoderKL.from_pretrained(model_dir, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(
model_dir, subfolder="unet"
)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if args.use_peft:
from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict
UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=UNET_TARGET_MODULES,
lora_dropout=args.lora_dropout,
bias=args.lora_bias,
)
unet = LoraModel(model=unet, config=config, adapter_name='default')
vae.requires_grad_(False)
# TODO: to be implemented
# if args.train_text_encoder:
# config = LoraConfig(
# r=args.lora_text_encoder_r,
# lora_alpha=args.lora_text_encoder_alpha,
# target_modules=TEXT_ENCODER_TARGET_MODULES,
# lora_dropout=args.lora_text_encoder_dropout,
# bias=args.lora_text_encoder_bias,
# )
# text_encoder = LoraModel(config, text_encoder)
elif args.use_swift:
if not is_swift_available():
raise ValueError(
'Please install swift by `pip install ms-swift` to use efficient_tuners.'
)
from swift import LoRAConfig, Swift
UNET_TARGET_MODULES = ['to_q', 'to_k', 'to_v', 'query', 'key', 'value', 'to_out.0']
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
lora_config = LoRAConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias=args.lora_bias,
target_modules=UNET_TARGET_MODULES)
unet = Swift.prepare_model(unet, lora_config)
if args.train_text_encoder:
lora_config = LoRAConfig(
r=args.lora_text_encoder_r,
lora_alpha=args.lora_text_encoder_alpha,
target_modules=TEXT_ENCODER_TARGET_MODULES,
lora_dropout=args.lora_text_encoder_dropout,
bias=args.lora_text_encoder_bias,
)
text_encoder_one = Swift.prepare_model(text_encoder_one, lora_config)
text_encoder_two = Swift.prepare_model(text_encoder_two, lora_config)
else:
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
# => 32 layers
# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.lora_r)
unet.set_attn_processor(lora_attn_procs)
# Move unet, vae and text_encoder to device and cast to weight_dtype
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
if args.use_peft or args.use_swift:
# Optimizer creation
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters())
if args.train_text_encoder
else unet.parameters()
)
optimizer = optimizer_cls(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
else:
lora_layers = AttnProcsLayers(unet.attn_processors)
optimizer = optimizer_cls(
lora_layers.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
dataset = load_dataset("imagefolder", data_dir=args.dataset_name)
# if args.dataset_name is not None:
# # Downloading and loading a dataset from the hub.
# dataset = load_dataset(
# args.dataset_name,
# args.dataset_config_name,
# cache_dir=args.cache_dir,
# num_proc=8,
# )
# else:
# # This branch will not be called
# data_files = {}
# if args.train_data_dir is not None:
# data_files["train"] = os.path.join(args.train_data_dir, "**")
# dataset = load_dataset(
# "imagefolder",
# data_files=data_files,
# cache_dir=args.cache_dir,
# )
# # See more about loading custom images at
# # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_prompt(tokenizer, prompt):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
return text_input_ids
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
tokens_one = tokenize_prompt(tokenizer_one, captions)
tokens_two = tokenize_prompt(tokenizer_two, captions)
return tokens_one, tokens_two
def compute_time_ids(original_size, crops_coords_top_left):
target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left
+ target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
return add_time_ids
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
prompt_embeds_list = []
for i, text_encoder in enumerate(text_encoders):
if tokenizers is not None:
tokenizer = tokenizers[i]
text_input_ids = tokenize_prompt(tokenizer, prompt)
else:
assert text_input_ids_list is not None
text_input_ids = text_input_ids_list[i]
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
)
# We are always interested only in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
return prompt_embeds, pooled_prompt_embeds
# Preprocessing the datasets.
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
#transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
#transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
FaceCrop(),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
# image aug
original_sizes = []
all_images = []
crop_top_lefts = []
for image in images:
original_sizes.append((image.height, image.width))
image = train_resize(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
# flip
x1 = image.width - x1
image = train_flip(image)
crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left)
image = train_transforms(image)
all_images.append(image)
examples["original_sizes"] = original_sizes
examples["crop_top_lefts"] = crop_top_lefts
examples["pixel_values"] = all_images
tokens_one, tokens_two = tokenize_captions(examples)
examples["input_ids_one"] = tokens_one
examples["input_ids_two"] = tokens_two
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
original_sizes = [example["original_sizes"] for example in examples]
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
return {
"pixel_values": pixel_values,
"input_ids_one": input_ids_one,
"input_ids_two": input_ids_two,
"original_sizes": original_sizes,
"crop_top_lefts": crop_top_lefts,
}
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
# Prepare everything with our `accelerator`.
if args.use_peft or args.use_swift:
if args.train_text_encoder:
unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
else:
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
)
unet = unet.cuda()
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint == 'fromfacecommon':
weight_model_dir = snapshot_download('damo/face_frombase_c4',
revision='v1.0.0',
user_agent={'invoked_by': 'trainer', 'third_party': 'facechain'})
path = os.path.join(weight_model_dir, 'face_frombase_c4.bin')
elif args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
if args.resume_from_checkpoint == 'fromfacecommon':
accelerator.print(f"Resuming from checkpoint {path}")
unet_state_dict = torch.load(path, map_location='cpu')
accelerator._models[-1].load_state_dict(unet_state_dict)
global_step = 0
resume_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder_one.train()
text_encoder_two.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# time ids
def compute_time_ids(original_size, crops_coords_top_left):
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
return add_time_ids
add_time_ids = torch.cat(
[compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
)
# Predict the noise residual
unet_added_conditions = {"time_ids": add_time_ids}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=None,
prompt=None,
text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
model_pred = unet(
noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Compute the loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
if args.use_peft or args.use_swift:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters())
if args.train_text_encoder
else unet.parameters()
)
else:
params_to_clip = lora_layers.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = DiffusionPipeline.from_pretrained(
model_dir,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)
if accelerator.is_main_process:
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
del pipeline
torch.cuda.empty_cache()
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if args.use_peft:
lora_config = {}
unwarpped_unet = accelerator.unwrap_model(unet)
state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))
lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True)
if args.train_text_encoder:
unwarpped_text_encoder = accelerator.unwrap_model(text_encoder_one)
text_encoder_state_dict = get_peft_model_state_dict(
unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder_one)
)
text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
state_dict.update(text_encoder_state_dict)
lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(
inference=True
)
accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt"))
with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f:
json.dump(lora_config, f)
elif args.use_swift:
unwarpped_unet = accelerator.unwrap_model(unet)
unwarpped_unet.save_pretrained(os.path.join(args.output_dir, 'swift'))
if args.train_text_encoder:
unwarpped_text_encoder_one = accelerator.unwrap_model(text_encoder_one)
unwarpped_text_encoder_one.save_pretrained(os.path.join(args.output_dir, 'text_encoder_one'))
unwarpped_text_encoder_two = accelerator.unwrap_model(text_encoder_two)
unwarpped_text_encoder_two.save_pretrained(os.path.join(args.output_dir, 'text_encoder_two'))
else:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir, safe_serialization=False)
if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=model_dir,
dataset_name=args.dataset_name,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
# Final inference
# Load previous pipeline
# pipeline = DiffusionPipeline.from_pretrained(
# model_dir, torch_dtype=weight_dtype
# )
# if args.use_peft:
# def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):
# with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f:
# lora_config = json.load(f)
# print(lora_config)
# checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt")
# lora_checkpoint_sd = torch.load(checkpoint)
# unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
# text_encoder_lora_ds = {
# k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
# }
# unet_config = LoraConfig(**lora_config["peft_config"])
# pipe.unet = LoraModel(unet_config, pipe.unet)
# set_peft_model_state_dict(pipe.unet, unet_lora_ds)
# if "text_encoder_peft_config" in lora_config:
# text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
# pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
# set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)
# if dtype in (torch.float16, torch.bfloat16):
# pipe.unet.half()
# pipe.text_encoder.half()
# pipe.to(device)
# return pipe
# pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)
# elif args.use_swift:
# if not is_swift_available():
# raise ValueError(
# 'Please install swift by `pip install ms-swift` to use efficient_tuners.'
# )
# from swift import Swift
# pipeline = pipeline.to(accelerator.device)
# pipeline.unet = Swift.from_pretrained(pipeline.unet, os.path.join(args.output_dir, 'swift'))
# if args.train_text_encoder:
# pipeline.text_encoder = Swift.from_pretrained(pipeline.text_encoder, os.path.join(args.output_dir, 'text_encoder'))
# else:
# pipeline = pipeline.to(accelerator.device)
# # load attention processors
# pipeline.unet.load_attn_procs(args.output_dir)
# # run inference
# if args.seed is not None:
# generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
# else:
# generator = None
# images = []
accelerator.end_training()
if __name__ == "__main__":
multiprocessing.set_start_method('spawn', force=True)
main()
# Copyright (c) Alibaba, Inc. and its affiliates.
import time
import subprocess
from modelscope import snapshot_download as ms_snapshot_download
import multiprocessing as mp
import os
project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def max_retries(max_attempts):
def decorator(func):
def wrapper(*args, **kwargs):
attempts = 0
while attempts < max_attempts:
try:
return func(*args, **kwargs)
except Exception as e:
attempts += 1
print(f"Retry {attempts}/{max_attempts}: {e}")
# wait 1 sec
time.sleep(1)
raise Exception(f"Max retries ({max_attempts}) exceeded.")
return wrapper
return decorator
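# snapshot_download: ModelScope snapshot download wrapped with up to 3 retries to tolerate transient network failures.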
@max_retries(3)
def snapshot_download(*args, **kwargs):
return ms_snapshot_download(*args, **kwargs)
def pre_download_models():
snapshot_download('ly261666/cv_portrait_model', revision='v4.0')
snapshot_download('YorickHe/majicmixRealistic_v6', revision='v1.0.0')
snapshot_download('damo/face_chain_control_model', revision='v1.0.1')
snapshot_download('ly261666/cv_wanx_style_model', revision='v1.0.3')
snapshot_download('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', revision='v1.0.0')
snapshot_download('Cherrytest/rot_bgr', revision='v1.0.0')
snapshot_download('damo/face_frombase_c4', revision='v1.0.0')
def set_spawn_method():
try:
mp.set_start_method('spawn', force=True)
except RuntimeError:
print("spawn method already set")
def check_install(*args):
try:
subprocess.check_output(args, stderr=subprocess.STDOUT)
return True
except OSError:
# raised when the executable is not installed / not on PATH
return False
def check_ffmpeg():
"""
Check if ffmpeg is installed.
"""
return check_install("ffmpeg", "-version")
def get_worker_data_dir() -> str:
"""
Get the worker data directory.
"""
return os.path.join(project_dir, "worker_data")
def join_worker_data_dir(*paths) -> str:
"""
Join the worker data directory with the given path components.
"""
return os.path.join(get_worker_data_dir(), *paths)
#### The code and paper will be released soon
# Introduction
If you are more comfortable with Chinese, you can read the [Chinese version of this README](./README_ZH.md).
This sub-project provides a platform for animating a user's reference image(s) with a specified motion sequence. Currently we use the [DensePose](https://densepose.org/) model to estimate human pose and [MagicAnimate](https://showlab.github.io/magicanimate/) to generate the video.
# Usage
Currently only inference is supported; training is not supported yet. We recommend reading the [Installation of MagicAnimate Tab](https://github.com/modelscope/facechain/tree/main/facechain_animate/resources/MagicAnimate/installation_for_magic_animate.md) guide first. After installation, go to the root of this project, i.e. `cd /path/to/facechain/`, and then run `python -m facechain_animate.app`. A minimal programmatic sketch is shown below.
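If you prefer to drive the models from Python rather than through the Gradio app, the sketch below chains the two wrappers the app itself uses. It is a minimal, untested sketch: the user id `'qw'` and the input file paths are placeholders, and it assumes the `DensePose` and `MagicAnimate` callables behave outside the app exactly as the app invokes them (called positionally and returning the path of the generated mp4).

```python
# Minimal sketch of a programmatic DensePose -> MagicAnimate pipeline.
# Assumptions: 'qw' is a placeholder user id and the input paths are illustrative.
import numpy as np
from PIL import Image

from facechain_animate.inference_densepose import DensePose
from facechain_animate.inference_animate import MagicAnimate

uuid = 'qw'  # placeholder user id

# 1) Turn an ordinary mp4 into a DensePose motion-sequence video.
densepose = DensePose(uuid)
motion_sequence = densepose('input/source_video.mp4')  # path of the generated .mp4

# 2) Animate a reference image with that motion sequence.
source_image = np.array(Image.open('input/reference.png').resize((512, 512)))
animator = MagicAnimate(uuid)
# positional arguments: source image, motion sequence path, random seed, sampling steps, guidance scale
output_video = animator(source_image, motion_sequence, 1, 25, 7.5)
print(output_video)
```

The Gradio app runs the same calls inside a `ProcessPoolExecutor` so that long-running generation does not block the UI.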
# To-Do List
- Support the OpenPose model for extracting motion sequences from videos.
- Add AnimateDiff into this sub-project.
- Add AnimateAnyone into this sub-project.
# Installation
We will support different animation models in the future. Each model may have different dependencies, so please refer to the corresponding guide when using a specific model:
- MagicAnimate: [Installation of MagicAnimate](https://github.com/modelscope/facechain/tree/main/facechain_animate/resources/MagicAnimate/installation_for_magic_animate.md)
- To be done...
# Acknowledgements
We would like to thank the following projects for their open research and foundational work:
- [MagicAnimate](https://showlab.github.io/magicanimate/)
- [DensePose](https://densepose.org/)
- [Vid2DensePose](https://github.com/Flode-Labs/vid2densepose/tree/main)
# Introduction
This sub-project provides a platform for generating a video from a specified image and a given motion sequence. It currently integrates the [DensePose](https://densepose.org/) model for pose estimation and the [MagicAnimate](https://showlab.github.io/magicanimate/) model for video generation.
# Usage
Currently only inference is supported; training is not supported yet. We recommend reading the [Installation of MagicAnimate Tab](https://github.com/modelscope/facechain/tree/main/facechain_animate/resources/MagicAnimate/installation_for_magic_animate_ZH.md) guide first. After installing the dependencies, go to the root directory of this project, i.e. `cd /path/to/facechain/`, and then run `python -m facechain_animate.app`.
# To-Do List
- Support the OpenPose model
- Support the AnimateDiff model
- Support the AnimateAnyone model
# Installation
We will support different video generation models in the future. Each model may have different dependencies, so please refer to the corresponding guide when using a specific model:
- MagicAnimate: [Installation of MagicAnimate Tab](https://github.com/modelscope/facechain/tree/main/facechain_animate/resources/MagicAnimate/installation_for_magic_animate_ZH.md)
- To be continued...
# Acknowledgements
We would like to thank the following projects for their open-source contributions:
- [MagicAnimate](https://showlab.github.io/magicanimate/)
- [DensePose](https://densepose.org/)
- [Vid2DensePose](https://github.com/Flode-Labs/vid2densepose/tree/main)
# Copyright (c) Alibaba, Inc. and its affiliates.
import enum
import os
import json
import shutil
import slugify
import time
from concurrent.futures import ProcessPoolExecutor
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import imageio
import torch
from glob import glob
import platform
from facechain.utils import snapshot_download, check_ffmpeg, set_spawn_method, project_dir, join_worker_data_dir
from facechain_animate.inference_animate import MagicAnimate
from facechain_animate.inference_densepose import DensePose
import tempfile
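# Global counters of finished jobs; used below to estimate how many tasks are queued ahead of the current request.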
training_done_count = 0
inference_done_count = 0
def get_selected_video(state_video_list, evt: gr.SelectData):
return state_video_list[evt.index]
def get_previous_video_result(uuid):
if not uuid:
if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio':
return "请登陆后使用! (Please login first)"
else:
uuid = 'qw'
save_dir = join_worker_data_dir(uuid, 'animate', 'densepose')
gen_videos = glob(os.path.join(save_dir, '*.mp4'), recursive=True)
return gen_videos
def update_output_video_result(uuid):
video_list = get_previous_video_result(uuid)
return video_list
def launch_pipeline_densepose(uuid, source_video):
# check if source_video is end with .mp4
if not source_video or not source_video.endswith('.mp4'):
raise gr.Error('请提供一段mp4视频(Please provide 1 mp4 video)')
before_queue_size = 0
before_done_count = inference_done_count
user_directory = os.path.expanduser("~")
if not os.path.exists(os.path.join(user_directory, '.cache', 'modelscope', 'hub', 'eavesy', 'vid2densepose')):
gr.Info("第一次初始化会比较耗时,请耐心等待(The first time initialization will take time, please wait)")
gen_video = DensePose(uuid)
with ProcessPoolExecutor(max_workers=1) as executor:
future = executor.submit(gen_video, source_video)
while not future.done():
is_processing = future.running()
if not is_processing:
cur_done_count = inference_done_count
to_wait = before_queue_size - (cur_done_count - before_done_count)
yield ["排队等待资源中,前方还有{}个生成任务(Queueing, there are {} tasks ahead)".format(to_wait, to_wait),
None]
else:
yield ["生成中, 请耐心等待(Generating, please wait)...", None]
time.sleep(1)
output = future.result()
print(f'生成文件位于路径(Generated file path):{output}')
yield ["生成完毕(Generation done)!", output]
def launch_pipeline_animate(uuid, source_image, motion_sequence, random_seed, sampling_steps, guidance_scale=7.5):
before_queue_size = 0
before_done_count = inference_done_count
if not source_image:
raise gr.Error('请选择一张源图片(Please select 1 source image)')
if not motion_sequence or not motion_sequence.endswith('.mp4'):
raise gr.Error('请提供一段mp4视频(Please provide 1 mp4 video)')
def read_image(image, size=512):
return np.array(Image.open(image).resize((size, size)))
def read_video(video):
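# Opening the file verifies it is a readable video; MagicAnimate consumes the file path, so the path itself is returned.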
reader = imageio.get_reader(video)
fps = reader.get_meta_data()['fps']
return video
source_image = read_image(source_image)
motion_sequence = read_video(motion_sequence)
user_directory = os.path.expanduser("~")
if not os.path.exists(os.path.join(user_directory, '.cache', 'modelscope', 'hub', 'AI-ModelScope', 'MagicAnimate')):
gr.Info("第一次初始化会比较耗时,请耐心等待(The first time initialization will take time, please wait)")
gen_video = MagicAnimate(uuid)
with ProcessPoolExecutor(max_workers=1) as executor:
future = executor.submit(gen_video, source_image, motion_sequence, random_seed, sampling_steps, guidance_scale)
while not future.done():
is_processing = future.running()
if not is_processing:
cur_done_count = inference_done_count
to_wait = before_queue_size - (cur_done_count - before_done_count)
yield ["排队等待资源中,前方还有{}个生成任务(Queueing, there are {} tasks ahead)".format(to_wait, to_wait),
None]
else:
yield ["生成中, 请耐心等待(Generating, please wait)...", None]
time.sleep(1)
output = future.result()
print(f'生成文件位于路径(Generated file path):{output}')
yield ["生成完毕(Generation done)!", output]
def inference_animate():
def identity_function(inp):
return inp
with gr.Blocks() as demo:
uuid = gr.Text(label="modelscope_uuid", visible=False)
video_result_list = get_previous_video_result(uuid.value)
print(video_result_list)
state_video_list = gr.State(value=video_result_list)
gr.Markdown("""该标签页的功能基于[MagicAnimate](https://showlab.github.io/magicanimate/)实现,要使用该标签页,请按照[教程](https://github.com/modelscope/facechain/tree/main/facechain_animate/resources/MagicAnimate/installation_for_magic_animate_ZH.md)安装相关依赖。\n
The function of this tab is implemented based on [MagicAnimate](https://showlab.github.io/magicanimate/), to use this tab, you should follow the installation [guide](https://github.com/modelscope/facechain/tree/main/facechain_animate/resources/MagicAnimate/installation_for_magic_animate.md) """)
with gr.Row(equal_height=False):
with gr.Column(variant='panel'):
with gr.Box():
source_image = gr.Image(label="源图片(source image)", source="upload", type="filepath")
with gr.Column():
examples_image=[
["facechain_animate/resources/MagicAnimate/source_image/demo4.png"],
["facechain_animate/resources/MagicAnimate/source_image/0002.png"],
["facechain_animate/resources/MagicAnimate/source_image/dalle8.jpeg"],
]
gr.Examples(examples=examples_image, inputs=[source_image],
outputs=[source_image], fn=identity_function, cache_examples=os.getenv('SYSTEM') == 'spaces', label='Image Example')
motion_sequence = gr.Video(format="mp4", label="动作序列视频(Motion Sequence)", source="upload", height=400)
with gr.Column():
examples_video=[
["facechain_animate/resources/MagicAnimate/driving/densepose/running.mp4"],
["facechain_animate/resources/MagicAnimate/driving/densepose/demo4.mp4"],
["facechain_animate/resources/MagicAnimate/driving/densepose/running2.mp4"],
["facechain_animate/resources/MagicAnimate/driving/densepose/dancing2.mp4"],
]
gr.Examples(examples=examples_video, inputs=[motion_sequence],
outputs=[motion_sequence], fn=identity_function, cache_examples=os.getenv('SYSTEM') == 'spaces', label='Video Example')
with gr.Box():
gr.Markdown("""
注意:
- 如果没有动作序列视频,可以提供原视频文件进行动作序列视频生成(If you don't have motion sequence, you may generate motion sequence from a source video.)
- 动作序列视频生成基于DensePose实现(Motion sequence generation is based on DensePose.)
""")
source_video = gr.Video(label="原始视频(Original Video)", format="mp4", width=256)
gen_motion = gr.Button("生成动作序列视频(Generate motion sequence)", variant='primary')
gen_progress = gr.Textbox(value="当前无生成动作序列视频任务(No motion sequence generation task currently)")
gen_motion.click(fn=launch_pipeline_densepose, inputs=[uuid, source_video],
outputs=[gen_progress, motion_sequence])
with gr.Column(variant='panel'):
with gr.Box():
gr.Markdown("设置(Settings)")
with gr.Column(variant='panel'):
random_seed = gr.Textbox(label="随机种子(Random seed)", value=1, info="default: -1")
sampling_steps = gr.Textbox(label="采样步数(Sampling steps)", value=25, info="default: 25")
submit = gr.Button("生成(Generate)", variant='primary')
with gr.Box():
infer_progress = gr.Textbox(value="当前无任务(No task currently)")
gen_video = gr.Video(label="Generated video", format="mp4")
submit.click(fn=launch_pipeline_animate, inputs=[uuid, source_image, motion_sequence, random_seed, sampling_steps],
outputs=[infer_progress, gen_video])
return demo
with gr.Blocks(css='style.css') as demo:
from importlib.util import find_spec
if find_spec('webui'):
# if running as a webui extension, don't display banner self-advertisement
gr.Markdown("# <center> \N{fire} FaceChain Potrait Generation (\N{whale} [Paper cite it here](https://arxiv.org/abs/2308.14256) \N{whale})</center>")
else:
gr.Markdown("# <center> \N{fire} FaceChain Potrait Generation ([Github star it here](https://github.com/modelscope/facechain/tree/main) \N{whale}, [Paper](https://arxiv.org/abs/2308.14256) \N{whale}, [API](https://help.aliyun.com/zh/dashscope/developer-reference/facechain-quick-start) \N{whale}, [API's Example App](https://tongyi.aliyun.com/wanxiang/app/portrait-gallery) \N{whale})</center>")
gr.Markdown("##### <center> 本项目仅供学习交流,请勿将模型及其制作内容用于非法活动或违反他人隐私的场景。(This project is intended solely for the purpose of technological discussion, and should not be used for illegal activities and violating privacy of individuals.)</center>")
with gr.Tabs():
with gr.TabItem('\N{clapper board}人物动画生成(Human animate)'):
inference_animate()
if __name__ == "__main__":
set_spawn_method()
demo.queue(status_update_rate=1).launch(share=True)
VERSION: 2
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
BACKBONE:
NAME: "build_resnet_fpn_backbone"
RESNETS:
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
FPN:
IN_FEATURES: ["res2", "res3", "res4", "res5"]
ANCHOR_GENERATOR:
SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
RPN:
IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
PRE_NMS_TOPK_TEST: 1000 # Per FPN level
# Detectron1 uses 2000 proposals per-batch,
# (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
# which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
POST_NMS_TOPK_TRAIN: 1000
POST_NMS_TOPK_TEST: 1000
DENSEPOSE_ON: True
ROI_HEADS:
NAME: "DensePoseROIHeads"
IN_FEATURES: ["p2", "p3", "p4", "p5"]
NUM_CLASSES: 1
ROI_BOX_HEAD:
NAME: "FastRCNNConvFCHead"
NUM_FC: 2
POOLER_RESOLUTION: 7
POOLER_SAMPLING_RATIO: 2
POOLER_TYPE: "ROIAlign"
ROI_DENSEPOSE_HEAD:
NAME: "DensePoseV1ConvXHead"
POOLER_TYPE: "ROIAlign"
NUM_COARSE_SEGM_CHANNELS: 2
DATASETS:
TRAIN: ("densepose_coco_2014_train", "densepose_coco_2014_valminusminival")
TEST: ("densepose_coco_2014_minival",)
SOLVER:
IMS_PER_BATCH: 16
BASE_LR: 0.01
STEPS: (60000, 80000)
MAX_ITER: 90000
WARMUP_FACTOR: 0.1
INPUT:
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)