Commit 26b49aab authored by mohammad
Browse files

fixed compatibility with v0 checkpoints

parent 2ff004ac
......@@ -59,7 +59,7 @@ We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nv
Alternatively, you can directly download the checkpoints using:
<pre>
BERT-345M-uncased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_cased/zip -O megatron_bert_345m_v0.1_uncased.zip
BERT-345M-uncased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_uncased/zip -O megatron_bert_345m_v0.1_uncased.zip
BERT-345M-cased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_cased/zip -O megatron_bert_345m_v0.1_cased.zip
GPT-345M: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip
</pre>
......
......@@ -211,6 +211,9 @@ def fix_query_key_value_ordering(model, checkpoint_version):
version is smaller than 2.0
"""
if checkpoint_version < 2.0:
if isinstance(model, list):
assert len(model)==1
model = model[0]
for name, param in model.named_parameters():
if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
if checkpoint_version == 0:
......@@ -291,8 +294,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
except BaseException:
except BaseException as e:
print_rank_0('could not load the checkpoint')
print_rank_0(e)
sys.exit()
# set checkpoint version
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment