Unverified commit 54abc67a, authored by Thomas Wolf and committed by GitHub
Browse files

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices
parents 645713e2 c11b3e29
import logging
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Any
from typing import Any, List, Optional, Union
from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
import logging
# The serving stack (uvicorn, FastAPI, pydantic) is optional. When it cannot be
# imported, inert placeholders are bound to the same names so that code using
# `BaseModel` as a base class or `Body(...)` as a default value still evaluates;
# `_serve_dependancies_installed` records which of the two situations applies.
try:
    from uvicorn import run
    from fastapi import FastAPI, HTTPException, Body
    from pydantic import BaseModel

    _serve_dependancies_installed = True
except (ImportError, AttributeError):
    BaseModel = object

    def Body(*x, **y):
        """Stand-in for `fastapi.Body`: accept any arguments and return None."""

    _serve_dependancies_installed = False
from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
logger = logging.getLogger('transformers-cli/serving')
logger = logging.getLogger("transformers-cli/serving")
def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate serving server from provided command line arguments.

    Builds a pipeline from `args.task`, `args.model`, `args.config`, `args.tokenizer` and
    `args.device`, then wraps it in a `ServeCommand` bound to `args.host` / `args.port`.

    :param args: Parsed command line arguments of the `serve` sub-command
    :return: ServeCommand
    """
    nlp = pipeline(
        task=args.task,
        # A falsy model value (e.g. empty string) is normalised to None.
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    return ServeCommand(nlp, args.host, args.port)
......@@ -36,6 +44,7 @@ class ServeModelInfoResult(BaseModel):
"""
Expose model information
"""
infos: dict
......@@ -43,6 +52,7 @@ class ServeTokenizeResult(BaseModel):
"""
Tokenize result model
"""
tokens: List[str]
tokens_ids: Optional[List[int]]
......@@ -51,6 +61,7 @@ class ServeDeTokenizeResult(BaseModel):
"""
DeTokenize result model
"""
text: str
......@@ -58,11 +69,11 @@ class ServeForwardResult(BaseModel):
"""
Forward result model
"""
output: Any
class ServeCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
    """
    Register the `serve` sub-command and its command line arguments.

    The sub-command's `func` default is set to `serve_command_factory`.

    :param parser: Root parser to register command-specific arguments
    :return:
    """
    serve_parser = parser.add_parser(
        "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
    )
    serve_parser.add_argument(
        "--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
    )
    serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
    serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
    serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
    serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
    serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
    serve_parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
    )
    serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int):
......@@ -87,18 +107,22 @@ class ServeCommand(BaseTransformersCLICommand):
self._host = host
self._port = port
if not _serve_dependancies_installed:
raise ImportError("Using serve command requires FastAPI and unicorn. "
raise ImportError(
"Using serve command requires FastAPI and unicorn. "
"Please install transformers with [serving]: pip install transformers[serving]."
"Or install FastAPI and unicorn separatly.")
"Or install FastAPI and unicorn separatly."
)
else:
logger.info('Serving model over {}:{}'.format(host, port))
logger.info("Serving model over {}:{}".format(host, port))
self._app = FastAPI()
# Register routes
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
self._app.add_api_route(
"/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
)
self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
def run(self):
    """
    Serve `self._app` on `self._host`:`self._port`.

    The `run` called in the body is the module-level name imported from uvicorn,
    not this method.
    """
    run(self._app, host=self._host, port=self._port)
......@@ -122,11 +146,14 @@ class ServeCommand(BaseTransformersCLICommand):
return ServeTokenizeResult(tokens=tokens_txt)
except Exception as e:
raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
def detokenize(
self,
tokens_ids: List[int] = Body(None, embed=True),
skip_special_tokens: bool = Body(False, embed=True),
cleanup_tokenization_spaces: bool = Body(True, embed=True)):
cleanup_tokenization_spaces: bool = Body(True, embed=True),
):
"""
Detokenize the provided tokens ids to readable text:
- **tokens_ids**: List of tokens ids
......@@ -135,9 +162,9 @@ class ServeCommand(BaseTransformersCLICommand):
"""
try:
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
return ServeDeTokenizeResult(model='', text=decoded_str)
return ServeDeTokenizeResult(model="", text=decoded_str)
except Exception as e:
raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
"""
......
......@@ -2,10 +2,10 @@ import os
from argparse import ArgumentParser, Namespace
from logging import getLogger
from transformers import SingleSentenceClassificationProcessor as Processor
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand
from transformers import (is_tf_available, is_torch_available,
TextClassificationPipeline,
SingleSentenceClassificationProcessor as Processor)
if not is_tf_available() and not is_torch_available():
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
......@@ -14,6 +14,7 @@ if not is_tf_available() and not is_torch_available():
USE_XLA = False
USE_AMP = False
def train_command_factory(args: Namespace):
"""
Factory function used to instantiate serving server from provided command line arguments.
......@@ -23,7 +24,6 @@ def train_command_factory(args: Namespace):
class TrainCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
    """
    Register the `train` sub-command and its command line arguments.

    The sub-command's `func` default is set to `train_command_factory`.

    :param parser: Root parser to register command-specific arguments
    :return:
    """
    train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")

    # Dataset location and layout
    train_parser.add_argument(
        "--train_data",
        type=str,
        required=True,
        help="path to train (and optionally evaluation) dataset as a csv with "
        "tab separated labels and sentences.",
    )
    train_parser.add_argument(
        "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
    )
    train_parser.add_argument(
        "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
    )
    train_parser.add_argument(
        "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
    )
    train_parser.add_argument(
        "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
    )

    # Validation data
    train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
    train_parser.add_argument(
        "--validation_split",
        type=float,
        default=0.1,
        help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
    )

    # Output, task and model selection
    train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")
    train_parser.add_argument(
        "--task", type=str, default="text_classification", help="Task to train the model on."
    )
    train_parser.add_argument(
        "--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
    )

    # Optimisation hyper-parameters
    train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
    train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
    train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
    train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
    train_parser.set_defaults(func=train_command_factory)
def __init__(self, args: Namespace):
self.logger = getLogger('transformers-cli/training')
self.logger = getLogger("transformers-cli/training")
self.framework = 'tf' if is_tf_available() else 'torch'
self.framework = "tf" if is_tf_available() else "torch"
os.makedirs(args.output, exist_ok=True)
assert os.path.isdir(args.output)
......@@ -81,28 +88,32 @@ class TrainCommand(BaseTransformersCLICommand):
self.column_text = args.column_text
self.column_id = args.column_id
self.logger.info('Loading {} pipeline for {}'.format(args.task, args.model))
if args.task == 'text_classification':
self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
if args.task == "text_classification":
self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
elif args.task == 'token_classification':
elif args.task == "token_classification":
raise NotImplementedError
elif args.task == 'question_answering':
elif args.task == "question_answering":
raise NotImplementedError
self.logger.info('Loading dataset from {}'.format(args.train_data))
self.train_dataset = Processor.create_from_csv(args.train_data,
self.logger.info("Loading dataset from {}".format(args.train_data))
self.train_dataset = Processor.create_from_csv(
args.train_data,
column_label=args.column_label,
column_text=args.column_text,
column_id=args.column_id,
skip_first_row=args.skip_first_row)
skip_first_row=args.skip_first_row,
)
self.valid_dataset = None
if args.validation_data:
self.logger.info('Loading validation dataset from {}'.format(args.validation_data))
self.valid_dataset = Processor.create_from_csv(args.validation_data,
self.logger.info("Loading validation dataset from {}".format(args.validation_data))
self.valid_dataset = Processor.create_from_csv(
args.validation_data,
column_label=args.column_label,
column_text=args.column_text,
column_id=args.column_id,
skip_first_row=args.skip_first_row)
skip_first_row=args.skip_first_row,
)
self.validation_split = args.validation_split
self.train_batch_size = args.train_batch_size
......@@ -111,7 +122,7 @@ class TrainCommand(BaseTransformersCLICommand):
self.adam_epsilon = args.adam_epsilon
def run(self):
    """
    Dispatch training to the backend named by `self.framework`.

    :return: whatever `run_tf()` returns when the framework is "tf",
        otherwise whatever `run_torch()` returns
    """
    if self.framework == "tf":
        return self.run_tf()
    return self.run_torch()
......@@ -119,13 +130,15 @@ class TrainCommand(BaseTransformersCLICommand):
raise NotImplementedError
def run_tf(self):
    """
    Fit `self.pipeline` on the training dataset, then save it to `self.output`.

    The validation dataset, validation split, learning rate, Adam epsilon and both
    batch sizes stored on the command are forwarded to `fit` as keyword arguments.
    """
    self.pipeline.fit(
        self.train_dataset,
        validation_data=self.valid_dataset,
        validation_split=self.validation_split,
        learning_rate=self.learning_rate,
        adam_epsilon=self.adam_epsilon,
        train_batch_size=self.train_batch_size,
        valid_batch_size=self.valid_batch_size,
    )

    # Save trained pipeline
    self.pipeline.save_pretrained(self.output)
import os
from argparse import ArgumentParser
from getpass import getpass
import os
from typing import List, Union
from requests.exceptions import HTTPError
from transformers.commands import BaseTransformersCLICommand
from transformers.hf_api import HfApi, HfFolder, HTTPError
from transformers.hf_api import HfApi, HfFolder
class UserCommands(BaseTransformersCLICommand):
    """Registers the user-facing sub-commands: login, whoami, logout, ls and upload."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Add one sub-parser per user command and bind each to its command class.

        Each sub-parser's `func` default builds the matching command object from
        the parsed arguments.

        :param parser: Root parser to register command-specific arguments
        """
        login_parser = parser.add_parser("login")
        login_parser.set_defaults(func=lambda args: LoginCommand(args))
        whoami_parser = parser.add_parser("whoami")
        whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
        logout_parser = parser.add_parser("logout")
        logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
        list_parser = parser.add_parser("ls")
        list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
        # upload
        upload_parser = parser.add_parser("upload")
        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
        upload_parser.add_argument(
            "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
        )
        upload_parser.set_defaults(func=lambda args: UploadCommand(args))
class ANSI:
    """
    Helper for en.wikipedia.org/wiki/ANSI_escape_code
    """

    # Escape sequence that switches the terminal to bold text.
    _bold = "\u001b[1m"
    # Escape sequence that resets all text attributes.
    _reset = "\u001b[0m"

    @classmethod
    def bold(cls, s):
        """Return `s` (formatted with `str.format`) wrapped in the bold and reset sequences."""
        return "{}{}{}".format(cls._bold, s, cls._reset)
......@@ -44,14 +50,16 @@ class BaseUserCommand:
class LoginCommand(BaseUserCommand):
def run(self):
print("""
print(
"""
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
""")
"""
)
username = input("Username: ")
password = getpass()
try:
......@@ -91,8 +99,7 @@ class LogoutCommand(BaseUserCommand):
class ListObjsCommand(BaseUserCommand):
def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
    """
    Render `rows` under `headers` as a plain-text table.

    Inspired by:
    stackoverflow.com/a/8356620/593036

    Each column is as wide as its longest cell (header included), every cell is
    followed by one space, and a row of dashes separates the header from the data.

    :param rows: Table body, one list of cells per row
    :param headers: Column titles, one per column
    :return: The table as a single newline-joined string
    """
    # Transpose rows together with the header row so each tuple holds one column.
    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
    lines = [
        row_format.format(*headers),
        row_format.format(*["-" * w for w in col_widths]),
    ]
    lines.extend(row_format.format(*row) for row in rows)
    return "\n".join(lines)
def run(self):
......@@ -126,15 +127,8 @@ class ListObjsCommand(BaseUserCommand):
if len(objs) == 0:
print("No shared file yet")
exit()
rows = [ [
obj.filename,
obj.LastModified,
obj.ETag,
obj.Size
] for obj in objs ]
print(
self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])
)
rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
class UploadCommand(BaseUserCommand):
......@@ -143,13 +137,7 @@ class UploadCommand(BaseUserCommand):
Recursively list all files in a folder.
"""
entries: List[os.DirEntry] = list(os.scandir(rel_path))
files = [
(
os.path.join(os.getcwd(), f.path), # filepath
f.path # filename
)
for f in entries if f.is_file()
]
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # filepath # filename
for f in entries:
if f.is_dir():
files += self.walk_dir(f.path)
......@@ -173,22 +161,14 @@ class UploadCommand(BaseUserCommand):
raise ValueError("Not a valid file or directory: {}".format(local_path))
for filepath, filename in files:
print(
"About to upload file {} to S3 under filename {}".format(
ANSI.bold(filepath), ANSI.bold(filename)
)
)
print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))
choice = input("Proceed? [Y/n] ").lower()
if not(choice == "" or choice == "y" or choice == "yes"):
if not (choice == "" or choice == "y" or choice == "yes"):
print("Abort")
exit()
print(
ANSI.bold("Uploading... This might take a while if files are large")
)
print(ANSI.bold("Uploading... This might take a while if files are large"))
for filepath, filename in files:
access_url = self._api.presign_and_upload(
token=token, filename=filename, filepath=filepath
)
access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
print("Your file now lives at:")
print(access_url)
......@@ -17,17 +17,19 @@
from .configuration_utils import PretrainedConfig
# Maps each ALBERT shortcut name to the URL of its configuration JSON file.
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
    "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
    "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
    "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
    "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
    "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
    "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
    "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
}
class AlbertConfig(PretrainedConfig):
"""Configuration for `AlbertModel`.
......@@ -36,7 +38,8 @@ class AlbertConfig(PretrainedConfig):
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
def __init__(
self,
vocab_size=30000,
embedding_size=128,
hidden_size=4096,
......@@ -51,7 +54,9 @@ class AlbertConfig(PretrainedConfig):
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12, **kwargs):
layer_norm_eps=1e-12,
**kwargs
):
"""Constructs AlbertConfig.
Args:
......
......@@ -18,24 +18,26 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import logging
from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
logger = logging.getLogger(__name__)
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -51,7 +53,8 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items())
for key, value, in pretrained_map.items()
)
class AutoConfig(object):
......@@ -79,37 +82,42 @@ class AutoConfig(object):
- contains `ctrl` : CTRLConfig (CTRL model)
This class cannot be instantiated using `__init__()` (throw an error).
"""
def __init__(self):
raise EnvironmentError("AutoConfig is designed to be instantiated "
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.")
raise EnvironmentError(
"AutoConfig is designed to be instantiated "
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
)
@classmethod
def for_model(cls, model_type, *args, **kwargs):
if 'distilbert' in model_type:
if "distilbert" in model_type:
return DistilBertConfig(*args, **kwargs)
elif 'roberta' in model_type:
elif "roberta" in model_type:
return RobertaConfig(*args, **kwargs)
elif 'bert' in model_type:
elif "bert" in model_type:
return BertConfig(*args, **kwargs)
elif 'openai-gpt' in model_type:
elif "openai-gpt" in model_type:
return OpenAIGPTConfig(*args, **kwargs)
elif 'gpt2' in model_type:
elif "gpt2" in model_type:
return GPT2Config(*args, **kwargs)
elif 'transfo-xl' in model_type:
elif "transfo-xl" in model_type:
return TransfoXLConfig(*args, **kwargs)
elif 'xlnet' in model_type:
elif "xlnet" in model_type:
return XLNetConfig(*args, **kwargs)
elif 'xlm' in model_type:
elif "xlm" in model_type:
return XLMConfig(*args, **kwargs)
elif 'ctrl' in model_type:
elif "ctrl" in model_type:
return CTRLConfig(*args, **kwargs)
elif 'albert' in model_type:
elif "albert" in model_type:
return AlbertConfig(*args, **kwargs)
elif 'camembert' in model_type:
elif "camembert" in model_type:
return CamembertConfig(*args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type))
"'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type)
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
......@@ -176,32 +184,36 @@ class AutoConfig(object):
assert unused_kwargs == {'foo': False}
"""
if 't5' in pretrained_model_name_or_path:
if "t5" in pretrained_model_name_or_path:
return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'albert' in pretrained_model_name_or_path:
elif "albert" in pretrained_model_name_or_path:
return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'camembert' in pretrained_model_name_or_path:
elif "camembert" in pretrained_model_name_or_path:
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'xlm-roberta' in pretrained_model_name_or_path:
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
elif "roberta" in pretrained_model_name_or_path:
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
elif "bert" in pretrained_model_name_or_path:
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'openai-gpt' in pretrained_model_name_or_path:
elif "openai-gpt" in pretrained_model_name_or_path:
return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'gpt2' in pretrained_model_name_or_path:
elif "gpt2" in pretrained_model_name_or_path:
return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'transfo-xl' in pretrained_model_name_or_path:
elif "transfo-xl" in pretrained_model_name_or_path:
return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
elif "xlnet" in pretrained_model_name_or_path:
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
elif "xlm" in pretrained_model_name_or_path:
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path:
elif "ctrl" in pretrained_model_name_or_path:
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path))
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
pretrained_model_name_or_path
)
)
......@@ -17,37 +17,35 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
}
......@@ -82,7 +80,8 @@ class BertConfig(PretrainedConfig):
"""
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
......@@ -95,7 +94,8 @@ class BertConfig(PretrainedConfig):
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
**kwargs):
**kwargs
):
super(BertConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
......
......@@ -15,17 +15,17 @@
# limitations under the License.
""" CamemBERT configuration """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
from .configuration_roberta import RobertaConfig
logger = logging.getLogger(__name__)
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
}
......
......@@ -16,17 +16,16 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}
class CTRLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `CTRLModel`.
......@@ -48,6 +47,7 @@ class CTRLConfig(PretrainedConfig):
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
......@@ -64,7 +64,7 @@ class CTRLConfig(PretrainedConfig):
attn_pdrop=0.1,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
summary_type='cls_index',
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
......
......@@ -13,45 +13,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" DistilBERT model configuration """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
import json
import logging
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
"distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
"distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
"distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
"distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
}
class DistilBertConfig(PretrainedConfig):
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
def __init__(
self,
vocab_size=30522,
max_position_embeddings=512,
sinusoidal_pos_embds=False,
n_layers=6,
n_heads=12,
dim=768,
hidden_dim=4*768,
hidden_dim=4 * 768,
dropout=0.1,
attention_dropout=0.1,
activation='gelu',
activation="gelu",
initializer_range=0.02,
tie_weights_=True,
qa_dropout=0.1,
seq_classif_dropout=0.2,
**kwargs):
**kwargs
):
super(DistilBertConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
......
......@@ -17,20 +17,21 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json",
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",}
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",
}
class GPT2Config(PretrainedConfig):
"""Configuration class to store the configuration of a `GPT2Model`.
......@@ -52,6 +53,7 @@ class GPT2Config(PretrainedConfig):
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
......@@ -67,7 +69,7 @@ class GPT2Config(PretrainedConfig):
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
summary_type='cls_index',
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
......
......@@ -15,11 +15,11 @@
# limitations under the License.
""" MMBT configuration """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
logger = logging.getLogger(__name__)
......@@ -31,6 +31,7 @@ class MMBTConfig(object):
num_labels: Size of final Linear layer for classification.
modal_hidden_size: Embedding dimension of the non-text modality encoder.
"""
def __init__(self, config, num_labels=None, modal_hidden_size=2048):
self.__dict__ = config.__dict__
self.modal_hidden_size = modal_hidden_size
......
......@@ -17,19 +17,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
}
class OpenAIGPTConfig(PretrainedConfig):
"""
Configuration class to store the configuration of an `OpenAIGPTModel`.
......@@ -54,6 +53,7 @@ class OpenAIGPTConfig(PretrainedConfig):
initializing all weight matrices.
predict_special_tokens: should we predict special tokens (when the model has a LM head)
"""
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
......@@ -71,7 +71,7 @@ class OpenAIGPTConfig(PretrainedConfig):
layer_norm_epsilon=1e-5,
initializer_range=0.02,
predict_special_tokens=True,
summary_type='cls_index',
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
......
......@@ -15,22 +15,22 @@
# limitations under the License.
""" RoBERTa configuration """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
from .configuration_bert import BertConfig
logger = logging.getLogger(__name__)
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
"roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
"roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
"roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
"distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
"roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
"roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
}
......
......@@ -16,22 +16,19 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
import six
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
"t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
"t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
"t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
"t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
"t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
}
......@@ -65,7 +62,8 @@ class T5Config(PretrainedConfig):
"""
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
def __init__(
self,
vocab_size=32128,
n_positions=512,
d_model=512,
......@@ -77,7 +75,8 @@ class T5Config(PretrainedConfig):
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
**kwargs):
**kwargs
):
super(T5Config, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.n_positions = n_positions
......
......@@ -17,19 +17,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
}
class TransfoXLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `TransfoXLModel`.
......@@ -65,9 +64,11 @@ class TransfoXLConfig(PretrainedConfig):
proj_init_std: parameters initialized by N(0, init_std)
init_std: parameters initialized by N(0, init_std)
"""
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
def __init__(
self,
vocab_size=267735,
cutoffs=[20000, 40000, 200000],
d_model=1024,
......@@ -96,7 +97,8 @@ class TransfoXLConfig(PretrainedConfig):
proj_init_std=0.01,
init_std=0.02,
layer_norm_epsilon=1e-5,
**kwargs):
**kwargs
):
"""Constructs TransfoXLConfig.
"""
super(TransfoXLConfig, self).__init__(**kwargs)
......
......@@ -15,8 +15,7 @@
# limitations under the License.
""" Configuration base class and utilities."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import json
......@@ -24,10 +23,12 @@ import logging
import os
from io import open
from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
logger = logging.getLogger(__name__)
class PretrainedConfig(object):
r""" Base class for all configuration classes.
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
......@@ -50,36 +51,36 @@ class PretrainedConfig(object):
def __init__(self, **kwargs):
# Attributes with defaults
self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.output_past = kwargs.pop('output_past', True) # Not used by all models
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {})
self.output_attentions = kwargs.pop("output_attentions", False)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_past = kwargs.pop("output_past", True) # Not used by all models
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.pruned_heads = kwargs.pop("pruned_heads", {})
# `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
self.is_decoder = kwargs.pop('is_decoder', False)
self.is_decoder = kwargs.pop("is_decoder", False)
# Parameters for sequence generation
self.max_length = kwargs.pop('max_length', 20)
self.do_sample = kwargs.pop('do_sample', False)
self.num_beams = kwargs.pop('num_beams', 1)
self.temperature = kwargs.pop('temperature', 1.0)
self.top_k = kwargs.pop('top_k', 50)
self.top_p = kwargs.pop('top_p', 1.0)
self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
self.bos_token_id = kwargs.pop('bos_token_id', 0)
self.pad_token_id = kwargs.pop('pad_token_id', 0)
self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
self.length_penalty = kwargs.pop('length_penalty', 1.)
self.num_return_sequences = kwargs.pop('num_return_sequences', 1)
self.max_length = kwargs.pop("max_length", 20)
self.do_sample = kwargs.pop("do_sample", False)
self.num_beams = kwargs.pop("num_beams", 1)
self.temperature = kwargs.pop("temperature", 1.0)
self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop("top_p", 1.0)
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.bos_token_id = kwargs.pop("bos_token_id", 0)
self.pad_token_id = kwargs.pop("pad_token_id", 0)
self.eos_token_ids = kwargs.pop("eos_token_ids", 0)
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
# Fine-tuning task arguments
self.finetuning_task = kwargs.pop('finetuning_task', None)
self.num_labels = kwargs.pop('num_labels', 2)
self.id2label = kwargs.pop('id2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)})
self.finetuning_task = kwargs.pop("finetuning_task", None)
self.num_labels = kwargs.pop("num_labels", 2)
self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = kwargs.pop('label2id', dict(zip(self.id2label.values(), self.id2label.keys())))
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
# Additional attributes without default values
......@@ -94,7 +95,9 @@ class PretrainedConfig(object):
""" Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
"""
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
assert os.path.isdir(
save_directory
), "Saving path should be a directory where the model and configuration can be saved"
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME)
......@@ -153,11 +156,11 @@ class PretrainedConfig(object):
assert unused_kwargs == {'foo': False}
"""
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
......@@ -170,37 +173,48 @@ class PretrainedConfig(object):
try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
resolved_config_file = cached_path(
config_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
)
# Load config
config = cls.from_json_file(resolved_config_file)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file)
config_file
)
else:
msg = "Model name '{}' was not found in model name list ({}). " \
"We assumed '{}' was a path or url to a configuration file named {} or " \
msg = (
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url to a configuration file named {} or "
"a directory containing such a file but couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()),
config_file, CONFIG_NAME)
", ".join(cls.pretrained_config_archive_map.keys()),
config_file,
CONFIG_NAME,
)
)
raise EnvironmentError(msg)
except json.JSONDecodeError:
msg = "Couldn't reach server at '{}' to download configuration file or " \
"configuration file is not a valid JSON file. " \
msg = (
"Couldn't reach server at '{}' to download configuration file or "
"configuration file is not a valid JSON file. "
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
)
raise EnvironmentError(msg)
if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
if hasattr(config, 'pruned_heads'):
if hasattr(config, "pruned_heads"):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
# Update config with kwargs if needed
......@@ -226,7 +240,7 @@ class PretrainedConfig(object):
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `Config` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
dict_obj = json.loads(text)
return cls(**dict_obj)
......@@ -248,5 +262,5 @@ class PretrainedConfig(object):
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
......@@ -15,26 +15,24 @@
""" XLM configuration """
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
"xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
"xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
"xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
"xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
"xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
"xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
"xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
"xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
"xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
}
......@@ -78,9 +76,11 @@ class XLMConfig(PretrainedConfig):
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
"""
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
def __init__(
self,
vocab_size=30145,
emb_dim=2048,
n_layers=12,
......@@ -103,7 +103,7 @@ class XLMConfig(PretrainedConfig):
unk_index=3,
mask_index=5,
is_encoder=True,
summary_type='first',
summary_type="first",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
......@@ -112,7 +112,8 @@ class XLMConfig(PretrainedConfig):
end_n_top=5,
mask_token_id=0,
lang_id=0,
**kwargs):
**kwargs
):
"""Constructs XLMConfig.
"""
super(XLMConfig, self).__init__(**kwargs)
......
......@@ -15,22 +15,22 @@
# limitations under the License.
""" XLM-RoBERTa configuration """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
from .configuration_roberta import RobertaConfig
logger = logging.getLogger(__name__)
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
'xlm-roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
'xlm-roberta-large-finetuned-conll02-dutch': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
'xlm-roberta-large-finetuned-conll02-spanish': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
'xlm-roberta-large-finetuned-conll03-english': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
'xlm-roberta-large-finetuned-conll03-german': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
"xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
"xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
"xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
"xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
"xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
"xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
}
......
......@@ -16,18 +16,16 @@
""" XLNet configuration """
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
"xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}
......@@ -69,9 +67,11 @@ class XLNetConfig(PretrainedConfig):
same_length: bool, whether to use the same attention length for each token.
finetuning_task: name of the glue task on which the model was fine-tuned if any
"""
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
def __init__(
self,
vocab_size=32000,
d_model=1024,
n_layer=24,
......@@ -88,13 +88,14 @@ class XLNetConfig(PretrainedConfig):
bi_data=False,
clamp_len=-1,
same_length=False,
summary_type='last',
summary_type="last",
summary_use_proj=True,
summary_activation='tanh',
summary_activation="tanh",
summary_last_dropout=0.1,
start_n_top=5,
end_n_top=5,
**kwargs):
**kwargs
):
"""Constructs XLNetConfig.
"""
super(XLNetConfig, self).__init__(**kwargs)
......
......@@ -14,16 +14,16 @@
# limitations under the License.
"""Convert ALBERT checkpoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
import argparse
import logging
import torch
from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
import logging
logging.basicConfig(level=logging.INFO)
......@@ -43,25 +43,20 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path to the TensorFlow checkpoint path.")
parser.add_argument("--albert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained ALBERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--albert_config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained ALBERT model. \n"
"This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.albert_config_file,
args.pytorch_dump_path)
\ No newline at end of file
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment