"docs/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "7a7ac6bea15fdffc4b078dd81c485b517534389e"
Unverified commit 3c1b6f59, authored by Julien Chaumond, committed by GitHub

Merge branch 'master' into fix_top_k_top_p_filtering

parents a9f24a16 fa735208
@@ -14,14 +14,15 @@
 # limitations under the License.
 """
 Preprocessing script before training DistilBERT.
+Specific to BERT -> DistilBERT.
 """
-from pytorch_transformers import BertForMaskedLM, RobertaForMaskedLM
+from transformers import BertForMaskedLM, RobertaForMaskedLM
 import torch
 import argparse
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation")
-    parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"])
+    parser.add_argument("--model_type", default="bert", choices=["bert"])
     parser.add_argument("--model_name", default='bert-base-uncased', type=str)
     parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
     parser.add_argument("--vocab_transform", action='store_true')
@@ -31,9 +32,8 @@ if __name__ == '__main__':
     if args.model_type == 'bert':
         model = BertForMaskedLM.from_pretrained(args.model_name)
         prefix = 'bert'
-    elif args.model_type == 'roberta':
-        model = RobertaForMaskedLM.from_pretrained(args.model_name)
-        prefix = 'roberta'
+    else:
+        raise ValueError(f'args.model_type should be "bert".')
     state_dict = model.state_dict()
     compressed_sd = {}
@@ -68,20 +68,12 @@ if __name__ == '__main__':
             state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
         std_idx += 1
-    if args.model_type == 'bert':
-        compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
-        compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
-        if args.vocab_transform:
-            for w in ['weight', 'bias']:
-                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
-                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
-    elif args.model_type == 'roberta':
-        compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight']
-        compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias']
-        if args.vocab_transform:
-            for w in ['weight', 'bias']:
-                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}']
-                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
+    compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
+    compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
+    if args.vocab_transform:
+        for w in ['weight', 'bias']:
+            compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
+            compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
     print(f'N layers selected for distillation: {std_idx}')
     print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
......
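For context, the hunks above show only fragments of the extraction loop. A minimal end-to-end sketch of the idea, copying a subset of teacher encoder layers into a flat student state dict; the layer selection [0, 2, 4, 7, 9, 11] and the student key names are illustrative assumptions, not taken from this diff:

import torch
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
state_dict = model.state_dict()
compressed_sd = {}

std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:  # assumed layer choice
    for w in ['weight', 'bias']:
        # Teacher layer `teacher_idx` becomes student layer `std_idx`.
        compressed_sd[f'transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
            state_dict[f'bert.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
    std_idx += 1

torch.save(compressed_sd, 'extracted_student.pth')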
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Preprocessing script before training DistilBERT.
+Preprocessing script before training the distilled model.
 """
 from collections import Counter
 import argparse
......
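This second script is shown only in fragments; the Counter import suggests a token-frequency pass over the binarized corpus. A hedged sketch of such a pass (the file names and pickle layout here are assumptions for illustration):

from collections import Counter
import pickle

# Assumed input: a pickled list of token-id sequences produced by an
# earlier binarization step (the file name is hypothetical).
with open('data/binarized_text.pickle', 'rb') as fp:
    sequences = pickle.load(fp)

counter = Counter()
for token_ids in sequences:
    counter.update(token_ids)

# Densify into a per-id count vector (30522 is BERT's vocabulary size,
# matching the DistilBERT config later in this diff).
counts = [0] * 30522
for token_id, count in counter.items():
    counts[token_id] = count

with open('data/token_counts.pickle', 'wb') as fp:
    pickle.dump(counts, fp)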
{
"activation": "gelu",
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"n_heads": 12,
"n_layers": 6,
"sinusoidal_pos_embds": true,
"tie_weights_": true,
"vocab_size": 30522
}
\ No newline at end of file
{
"initializer_range": 0.02,
"layer_norm_epsilon": 0.00001,
"n_ctx": 1024,
"n_embd": 768,
"n_head": 12,
"n_layer": 6,
"n_positions": 1024,
"vocab_size": 50257
}
\ No newline at end of file
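The two new JSON files above define the student architectures: a 6-layer, 768-dim DistilBERT over BERT's 30,522-token vocabulary, and a 6-layer distilled GPT-2 over GPT-2's 50,257-token vocabulary. A quick sanity check using only the standard library (the file name is hypothetical, since the diff does not show it):

import json

with open('distilbert_config.json') as f:  # hypothetical file name
    cfg = json.load(f)

# The student keeps BERT-base's width (768 hidden dim, 12 heads) but
# halves its depth: 6 layers instead of 12.
assert cfg['n_layers'] == 6 and cfg['dim'] == 768
print(cfg['vocab_size'])  # 30522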
 tensorboardX
+tensorboard
 scikit-learn
+seqeval
@@ -32,7 +32,7 @@ from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, Subset
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn import CrossEntropyLoss, MSELoss
-from pytorch_transformers import (WEIGHTS_NAME,
+from transformers import (WEIGHTS_NAME,
                           BertConfig, BertForSequenceClassification, BertTokenizer,
                           XLMConfig, XLMForSequenceClassification, XLMTokenizer,
                           XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer)
......
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
+""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
 """
 from __future__ import absolute_import, division, print_function, unicode_literals
@@ -26,12 +26,14 @@ import torch
 import torch.nn.functional as F
 import numpy as np
-from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig
-from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer
-from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
-from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
-from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
+from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
+from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
+from transformers import XLNetLMHeadModel, XLNetTokenizer
+from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
+from transformers import CTRLLMHeadModel, CTRLTokenizer
+from transformers import XLMWithLMHeadModel, XLMTokenizer
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -41,13 +43,15 @@ logger = logging.getLogger(__name__)
 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
 MODEL_CLASSES = {
     'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
+    'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
     'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
     'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
     'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
+    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
 }
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
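MODEL_CLASSES maps each --model_type value to a (model class, tokenizer class) pair. A minimal sketch of how the script consumes the table downstream, abridged to GPT-2 (the checkpoint shortcut is an assumption):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

MODEL_CLASSES = {'gpt2': (GPT2LMHeadModel, GPT2Tokenizer)}  # abridged

model_class, tokenizer_class = MODEL_CLASSES['gpt2']
tokenizer = tokenizer_class.from_pretrained('gpt2')
model = model_class.from_pretrained('gpt2')
model.eval()  # inference only: disable dropout before sampling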
@@ -103,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     return logits
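The hunk above shows only the tail of top_k_top_p_filtering. For reference, a sketch of the widely circulated single-sequence (1D logits) formulation this branch builds on; treat it as illustrative rather than the merged code, since the fix_top_k_top_p_filtering branch exists precisely to revise this function:

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a 1D distribution of logits using top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))  # safety check
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens once cumulative probability exceeds the threshold...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ...shifted right by one so the first token above the threshold survives
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits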
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
+def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
+                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
     context = torch.tensor(context, dtype=torch.long, device=device)
     context = context.unsqueeze(0).repeat(num_samples, 1)
     generated = context
@@ -121,9 +126,26 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0,
             target_mapping[0, 0, -1] = 1.0  # predict last token
             inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
-        outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
-        next_token_logits = outputs[0][:, -1, :] / temperature
+        if is_xlm_mlm and xlm_mask_token:
+            # XLM MLM models are direct models (predict same token, not next token)
+            # => need one additional dummy token in the input (will be masked and guessed)
+            input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
+            inputs = {'input_ids': input_ids}
+
+        if xlm_lang is not None:
+            inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
+
+        outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
+        next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
+
+        # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
+        for _ in set(generated.view(-1).tolist()):
+            next_token_logits[_] /= repetition_penalty
+
         filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-        next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+        if temperature == 0:  # greedy sampling
+            next_token = torch.argmax(filtered_logits).unsqueeze(0)
+        else:
+            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
         generated = torch.cat((generated, next_token), dim=1)
     return generated
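For intuition on the two sampling knobs introduced here, a self-contained worked example on a toy vocabulary (independent of the diff above):

import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 2.0, 1.0, 0.5, 0.1])  # toy 5-token vocabulary
already_generated = {0, 2}      # token ids produced so far
repetition_penalty = 1.2        # the value the CTRL paper recommends

# Division shrinks a positive logit: token 0 goes 4.0 -> ~3.33, lowering
# its probability relative to tokens that have not been produced yet.
for token_id in already_generated:
    logits[token_id] /= repetition_penalty

temperature = 0.0
if temperature == 0:
    # Temperature 0 is interpreted as greedy decoding, which avoids the
    # division by zero that a literal softmax(logits / 0) would imply.
    next_token = torch.argmax(logits)
else:
    next_token = torch.multinomial(F.softmax(logits / temperature, dim=-1), 1)
print(int(next_token))  # still 0: the penalty reshapes the distribution, it does not forbid repeats

One caveat of this formulation: dividing only dampens positive logits, while a negative logit divided by a penalty above 1 actually moves up, a known quirk of the CTRL-style penalty.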
@@ -137,15 +159,21 @@ def main():
                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--prompt", type=str, default="")
     parser.add_argument("--padding_text", type=str, default="")
+    parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20)
-    parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--num_samples", type=int, default=1)
+    parser.add_argument("--temperature", type=float, default=1.0,
+                        help="temperature of 0 implies greedy sampling")
+    parser.add_argument("--repetition_penalty", type=float, default=1.0,
+                        help="primarily useful for CTRL model; in that case, use 1.2")
     parser.add_argument("--top_k", type=int, default=0)
     parser.add_argument("--top_p", type=float, default=0.9)
     parser.add_argument("--no_cuda", action='store_true',
                         help="Avoid using CUDA when available")
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
+    parser.add_argument('--stop_token', type=str, default=None,
+                        help="Token at which text generation is stopped")
     args = parser.parse_args()
     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
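Taken together, the new flags make a CTRL run look roughly like `python run_generation.py --model_type ctrl --model_name_or_path ctrl --temperature 0.3 --repetition_penalty 1.2`; the script path and checkpoint shortcut are assumptions here, since neither is shown in this diff.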
@@ -167,13 +195,39 @@ def main():
     elif args.length < 0:
         args.length = MAX_LENGTH  # avoid infinite loop
-    print(args)
+    logger.info(args)
+
+    if args.model_type in ["ctrl"]:
+        if args.temperature > 0.7:
+            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
+
     while True:
+        xlm_lang = None
+        # XLM language usage is detailed in issue #1414
+        if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
+                and model.config.use_lang_emb:
+            if args.xlm_lang:
+                language = args.xlm_lang
+            else:
+                language = None
+                while language not in tokenizer.lang2id.keys():
+                    language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
+            xlm_lang = tokenizer.lang2id[language]
+
+        # XLM masked-language modeling (MLM) models need a masked token (see details in sample_sequence)
+        is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
+        if is_xlm_mlm:
+            xlm_mask_token = tokenizer.mask_token_id
+        else:
+            xlm_mask_token = None
+
         raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
         if args.model_type in ["transfo-xl", "xlnet"]:
             # Models with memory like to have a long prompt for short inputs.
             raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
-        context_tokens = tokenizer.encode(raw_text)
+        context_tokens = tokenizer.encode(raw_text, add_special_tokens=False)
+        if args.model_type == "ctrl":
+            if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
+                logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
         out = sample_sequence(
             model=model,
             context=context_tokens,
@@ -182,13 +236,20 @@ def main():
             temperature=args.temperature,
             top_k=args.top_k,
             top_p=args.top_p,
-            device=args.device,
+            repetition_penalty=args.repetition_penalty,
             is_xlnet=bool(args.model_type == "xlnet"),
+            is_xlm_mlm=is_xlm_mlm,
+            xlm_mask_token=xlm_mask_token,
+            xlm_lang=xlm_lang,
+            device=args.device,
         )
         out = out[:, len(context_tokens):].tolist()
         for o in out:
             text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
+            text = text[: text.find(args.stop_token) if args.stop_token else None]
             print(text)
         if args.prompt:
             break
     return text
......
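One subtlety in the new stop-token truncation above: str.find returns -1 when the token is absent, so the slice then silently drops the final character. A small demonstration of the expression's three cases:

text = "Hello world <eot> trailing junk"

stop_token = "<eot>"
print(text[: text.find(stop_token) if stop_token else None])  # 'Hello world '

stop_token = None
print(text[: text.find(stop_token) if stop_token else None])  # full string; find is never called

stop_token = "<absent>"
print(text[: text.find(stop_token) if stop_token else None])  # drops the last character (find == -1)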
@@ -24,7 +24,7 @@ import math
 import collections
 from io import open
-from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
+from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
 # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
 from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
......