Commit 1df591b0 authored by Jennifer's avatar Jennifer
Browse files

updates zero_to_fp32.py for new deepspeed version and import_weight bugfix

parent bb3f51e5
...@@ -688,8 +688,9 @@ def convert_deprecated_v1_keys(state_dict): ...@@ -688,8 +688,9 @@ def convert_deprecated_v1_keys(state_dict):
new_key = convert_key_re.sub(lambda m: replacements[m.group()], key) new_key = convert_key_re.sub(lambda m: replacements[m.group()], key)
# Add prefix for template modules # Add prefix for template modules
if new_key.startswith('template'): subheader = re.search('(?<=model.).*$', new_key).group()
new_key = f'template_embedder.{new_key}' if subheader.startswith('template'):
new_key = f'model.template_embedder.{subheader}'
converted_state_dict[new_key] = value converted_state_dict[new_key] = value
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment