Commit c9cb7f8a authored by Lysandre's avatar Lysandre Committed by Lysandre Debut
Browse files

Torch 1.1.0 compatibility + FP16 O1 + TF checkpoints

Co-authored-by: wassname
parent b18509c2
......@@ -203,8 +203,8 @@ class AlbertAttention(BertSelfAttention):
# Should find a better way to do this
w = self.dense.weight.T.view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
b = self.dense.bias
w = self.dense.weight.t().view(self.num_attention_heads, self.attention_head_size, self.hidden_size).to(context_layer.dtype)
b = self.dense.bias.to(context_layer.dtype)
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
projected_context_layer_dropout = self.dropout(projected_context_layer)
......
......@@ -36,7 +36,14 @@ import logging
logger = logging.getLogger(__name__)
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
# TODO FILL THAT UP
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-tf_model.h5",
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-tf_model.h5",
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-tf_model.h5",
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-tf_model.h5",
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-tf_model.h5",
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-tf_model.h5",
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-tf_model.h5",
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-tf_model.h5",
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment