"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "cfc3cf0eeccee2978a03e3e002e8e174962b1b51"
Commit f2d1d16f authored by comfyanonymous's avatar comfyanonymous
Browse files

Support Stable Cascade Stage B lite.

parent 0b3c5048
......@@ -46,6 +46,18 @@ def detect_unet_config(state_dict, key_prefix):
unet_config['c_cond'] = 2048
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
unet_config['stable_cascade_stage'] = 'b'
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
if w.shape[-1] == 640:
unet_config['c_hidden'] = [320, 640, 1280, 1280]
unet_config['nhead'] = [-1, -1, 20, 20]
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
elif w.shape[-1] == 576: #stage b lite
unet_config['c_hidden'] = [320, 576, 1152, 1152]
unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
unet_config = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment