Commit ce9eade2 authored by Lysandre's avatar Lysandre Committed by Lysandre Debut
Browse files

Initializer range using BertPreTrainedModel

parent 5680a110
......@@ -6,8 +6,7 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers.configuration_albert import AlbertConfig
from transformers.modeling_bert import BertEmbeddings, BertModel, BertSelfAttention, prune_linear_layer, ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_bert import BertEmbeddings, BertPreTrainedModel, BertModel, BertSelfAttention, prune_linear_layer, ACT2FN
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
......@@ -362,7 +361,7 @@ class AlbertModel(BertModel):
@add_start_docstrings("Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
class AlbertForMaskedLM(PreTrainedModel):
class AlbertForMaskedLM(BertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment