"...git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "b168b0161987e95f9571c6fdaaef448c03ac396c"
Commit ce9eade2 authored by Lysandre, committed by Lysandre Debut

Initializer range using BertPreTrainedModel

parent 5680a110
@@ -6,8 +6,7 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 from transformers.configuration_albert import AlbertConfig
-from transformers.modeling_bert import BertEmbeddings, BertModel, BertSelfAttention, prune_linear_layer, ACT2FN
-from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_bert import BertEmbeddings, BertPreTrainedModel, BertModel, BertSelfAttention, prune_linear_layer, ACT2FN
 from .file_utils import add_start_docstrings
 
 logger = logging.getLogger(__name__)
@@ -362,7 +361,7 @@ class AlbertModel(BertModel):
 @add_start_docstrings("Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
-class AlbertForMaskedLM(PreTrainedModel):
+class AlbertForMaskedLM(BertPreTrainedModel):
     r"""
         **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Labels for computing the masked language modeling loss.
...
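
The point of the change: the generic `PreTrainedModel` base class defines no model-specific weight initialization, while `BertPreTrainedModel` initializes `nn.Linear` and `nn.Embedding` weights from a normal distribution with standard deviation `config.initializer_range`. Subclassing `BertPreTrainedModel` therefore gives `AlbertForMaskedLM` the configured initializer range named in the commit title. Below is a minimal sketch of that initialization logic, mirroring `BertPreTrainedModel._init_weights` in transformers 2.x; the `init_like_bert_pretrained` helper and the `Config` stub are hypothetical stand-ins for illustration, not library API.

```python
# Sketch (assumptions labeled): why subclassing BertPreTrainedModel matters.
# BertPreTrainedModel's weight init draws Linear/Embedding weights from
# N(0, config.initializer_range); the generic PreTrainedModel does not.
import torch.nn as nn

class Config:
    initializer_range = 0.02  # hypothetical stand-in for AlbertConfig

def init_like_bert_pretrained(module, config):
    """Hypothetical helper mirroring BertPreTrainedModel._init_weights."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # Draw weights from a normal distribution scaled by the config.
        module.weight.data.normal_(mean=0.0, std=config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
        # LayerNorm starts as the identity transform.
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

# Apply to every submodule, as PreTrainedModel.init_weights() does internally.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
model.apply(lambda m: init_like_bert_pretrained(m, Config()))
```

In the library itself this logic runs automatically when a `BertPreTrainedModel` subclass calls `self.init_weights()` in its constructor, which is why the inheritance change alone is enough to pick up `initializer_range`.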