"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a4c75f149269099a98613f51b76cd0b579a109ee"
Commit 629b22ad authored by Julien Chaumond

[run_lm_finetuning] mask_tokens: document types

parent 594ca6de
@@ -28,6 +28,7 @@ import pickle
 import random
 import re
 import shutil
+from typing import Tuple
 import numpy as np
 import torch
@@ -53,6 +54,7 @@ from transformers import (
     OpenAIGPTConfig,
     OpenAIGPTLMHeadModel,
     OpenAIGPTTokenizer,
+    PreTrainedTokenizer,
     RobertaConfig,
     RobertaForMaskedLM,
     RobertaTokenizer,
@@ -164,7 +166,7 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
         shutil.rmtree(checkpoint)


-def mask_tokens(inputs, tokenizer, args):
+def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
     """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability, defaults to 0.15 in BERT/RoBERTa)
...
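For context, the body of mask_tokens is truncated in this diff. The docstring describes the standard masked-LM corruption scheme (80% [MASK], 10% random token, 10% unchanged), and a minimal sketch of that logic is shown below. This is an illustration under assumptions, not the exact code from the commit: it assumes args.mlm_probability holds the masking rate and uses -100 as the loss ignore index, which may differ from the script at this revision.

from typing import Tuple

import torch
from transformers import PreTrainedTokenizer


def mask_tokens_sketch(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sketch of masked-LM input/label preparation: 80% MASK, 10% random, 10% original."""
    labels = inputs.clone()

    # Sample positions to corrupt with probability args.mlm_probability, never masking special tokens.
    probability_matrix = torch.full(labels.shape, args.mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # loss is only computed on masked positions (assumed ignore index)

    # 80% of masked positions: replace the input token with [MASK].
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of masked positions: replace with a random token from the vocabulary.
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # Remaining 10% of masked positions: keep the original token unchanged.
    return inputs, labels

The type annotations added in this commit make that contract explicit: the function takes a batch of token ids as a torch.Tensor plus a PreTrainedTokenizer and returns a (corrupted inputs, labels) tensor pair.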