"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "637aa811d4d19708369e8fb6c532b4c420564e05"
Commit b3d834ae authored by Lysandre
Browse files

Reorganize ALBERT conversion script

parent b0ee7c7d
......@@ -68,14 +68,36 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
for name, array in zip(names, arrays):
original_name = name
# If saved from the TF HUB module
name = name.replace("module/", "")
# Renaming and simplifying
name = name.replace("ffn_1", "ffn")
name = name.replace("/bert/", "/albert/")
name = name.replace("ffn/intermediate/output", "ffn_output")
name = name.replace("bert/", "albert/")
name = name.replace("attention_1", "attention")
name = name.replace("cls/predictions", "predictions")
name = name.replace("transform/", "")
name = name.replace("LayerNorm_1", "full_layer_layer_norm")
name = name.replace("LayerNorm", "attention/LayerNorm")
name = name.replace("transformer/", "")
# The feed forward layer had an 'intermediate' step which has been abstracted away
name = name.replace("intermediate/dense/", "")
name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")
# ALBERT attention was split between self and output which have been abstracted away
name = name.replace("/output/", "/")
name = name.replace("/self/", "/")
# The pooler is a linear layer
name = name.replace("pooler/dense", "pooler")
# The classifier was simplified to predictions from cls/predictions
name = name.replace("cls/predictions", "predictions")
name = name.replace("predictions/attention", "predictions")
# Naming was changed to be more explicit
name = name.replace("embeddings/attention", "embeddings")
name = name.replace("inner_group_", "albert_layers/")
name = name.replace("group_", "albert_layer_groups/")
name = name.split('/')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment