"...resnet50_tensorflow.git" did not exist on "4b8f80c3d31dffce153e4954b309cf52afc2c51a"
Unverified commit 54abc67a, authored by Thomas Wolf, committed by GitHub

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices
parents 645713e2 c11b3e29
import logging
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional, Union


try:
    from uvicorn import run
    from fastapi import FastAPI, HTTPException, Body
    from pydantic import BaseModel

    _serve_dependancies_installed = True
except (ImportError, AttributeError):
    BaseModel = object

    def Body(*x, **y):
        pass

    _serve_dependancies_installed = False

from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline


logger = logging.getLogger("transformers-cli/serving")


def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate serving server from provided command line arguments.
    :return: ServeCommand
    """
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    return ServeCommand(nlp, args.host, args.port)
@@ -36,6 +44,7 @@ class ServeModelInfoResult(BaseModel):
    """
    Expose model information
    """

    infos: dict

@@ -43,6 +52,7 @@ class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    tokens: List[str]
    tokens_ids: Optional[List[int]]

@@ -51,6 +61,7 @@ class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model
    """

    text: str

@@ -58,11 +69,11 @@ class ServeForwardResult(BaseModel):
    """
    Forward result model
    """

    output: Any


class ServeCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
@@ -70,14 +81,23 @@ class ServeCommand(BaseTransformersCLICommand):
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int):
@@ -87,18 +107,22 @@ class ServeCommand(BaseTransformersCLICommand):
        self._host = host
        self._port = port

        if not _serve_dependancies_installed:
            raise ImportError(
                "Using serve command requires FastAPI and uvicorn. "
                "Please install transformers with [serving]: pip install transformers[serving]. "
                "Or install FastAPI and uvicorn separately."
            )
        else:
            logger.info("Serving model over {}:{}".format(host, port))
            self._app = FastAPI()

            # Register routes
            self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
            self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
            self._app.add_api_route(
                "/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
            )
            self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
    def run(self):
        run(self._app, host=self._host, port=self._port)

@@ -122,11 +146,14 @@ class ServeCommand(BaseTransformersCLICommand):
            return ServeTokenizeResult(tokens=tokens_txt)

        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text:
        - **tokens_ids**: List of tokens ids
@@ -135,9 +162,9 @@ class ServeCommand(BaseTransformersCLICommand):
        """
        try:
            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
            return ServeDeTokenizeResult(model="", text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
        """
...
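For orientation, the serve subcommand wired up above is reached through the transformers CLI entry point. A minimal sketch of driving the factory directly (not part of this PR; the task name is an assumption about the keys available in SUPPORTED_TASKS at the time):

    from argparse import Namespace

    args = Namespace(
        task="sentiment-analysis",  # assumed to be one of SUPPORTED_TASKS
        model=None,                 # None falls back to the task's default model
        config=None,
        tokenizer=None,
        device=-1,                  # CPU
        host="localhost",
        port=8888,
    )
    server = serve_command_factory(args)  # builds the pipeline and the FastAPI app
    server.run()                          # exposes /, /tokenize, /detokenize and /forward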
@@ -2,10 +2,10 @@ import os
from argparse import ArgumentParser, Namespace
from logging import getLogger

from transformers import SingleSentenceClassificationProcessor as Processor
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand


if not is_tf_available() and not is_torch_available():
    raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
@@ -14,6 +14,7 @@ if not is_tf_available() and not is_torch_available():
USE_XLA = False
USE_AMP = False


def train_command_factory(args: Namespace):
    """
    Factory function used to instantiate the training command from provided command line arguments.
@@ -23,7 +24,6 @@ def train_command_factory(args: Namespace):
class TrainCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
@@ -31,47 +31,54 @@ class TrainCommand(BaseTransformersCLICommand):
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
        train_parser.add_argument(
            "--train_data",
            type=str,
            required=True,
            help="path to train (and optionally evaluation) dataset as a csv with "
            "tab separated labels and sentences.",
        )
        train_parser.add_argument(
            "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
        )
        train_parser.add_argument(
            "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
        )
        train_parser.add_argument(
            "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
        )
        train_parser.add_argument(
            "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
        )
        train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
        train_parser.add_argument(
            "--validation_split",
            type=float,
            default=0.1,
            help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
        )
        train_parser.add_argument("--output", type=str, default="./", help="path to save the trained model.")
        train_parser.add_argument(
            "--task", type=str, default="text_classification", help="Task to train the model on."
        )
        train_parser.add_argument(
            "--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
        )
        train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
        train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
        train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
        train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
        train_parser.set_defaults(func=train_command_factory)
    def __init__(self, args: Namespace):
        self.logger = getLogger("transformers-cli/training")

        self.framework = "tf" if is_tf_available() else "torch"

        os.makedirs(args.output, exist_ok=True)
        assert os.path.isdir(args.output)
@@ -81,28 +88,32 @@ class TrainCommand(BaseTransformersCLICommand):
        self.column_text = args.column_text
        self.column_id = args.column_id

        self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
        if args.task == "text_classification":
            self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
        elif args.task == "token_classification":
            raise NotImplementedError
        elif args.task == "question_answering":
            raise NotImplementedError

        self.logger.info("Loading dataset from {}".format(args.train_data))
        self.train_dataset = Processor.create_from_csv(
            args.train_data,
            column_label=args.column_label,
            column_text=args.column_text,
            column_id=args.column_id,
            skip_first_row=args.skip_first_row,
        )
        self.valid_dataset = None
        if args.validation_data:
            self.logger.info("Loading validation dataset from {}".format(args.validation_data))
            self.valid_dataset = Processor.create_from_csv(
                args.validation_data,
                column_label=args.column_label,
                column_text=args.column_text,
                column_id=args.column_id,
                skip_first_row=args.skip_first_row,
            )

        self.validation_split = args.validation_split
        self.train_batch_size = args.train_batch_size
@@ -111,7 +122,7 @@ class TrainCommand(BaseTransformersCLICommand):
        self.adam_epsilon = args.adam_epsilon

    def run(self):
        if self.framework == "tf":
            return self.run_tf()
        return self.run_torch()
@@ -119,13 +130,15 @@ class TrainCommand(BaseTransformersCLICommand):
        raise NotImplementedError

    def run_tf(self):
        self.pipeline.fit(
            self.train_dataset,
            validation_data=self.valid_dataset,
            validation_split=self.validation_split,
            learning_rate=self.learning_rate,
            adam_epsilon=self.adam_epsilon,
            train_batch_size=self.train_batch_size,
            valid_batch_size=self.valid_batch_size,
        )

        # Save trained pipeline
        self.pipeline.save_pretrained(self.output)
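For context, the train command above can also be driven programmatically with the same fields the "train" subparser defines. A minimal sketch (not part of this PR; the paths and values are illustrative):

    from argparse import Namespace

    args = Namespace(
        train_data="./data/train.csv",   # tab-separated labels and sentences
        column_label=0,
        column_text=1,
        column_id=2,
        skip_first_row=True,
        validation_data="",
        validation_split=0.1,
        output="./trained_model/",
        task="text_classification",
        model="bert-base-uncased",
        train_batch_size=32,
        valid_batch_size=64,
        learning_rate=3e-5,
        adam_epsilon=1e-08,
    )
    TrainCommand(args).run()  # dispatches to run_tf() or run_torch() depending on the installed framework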
import os
from argparse import ArgumentParser
from getpass import getpass
from typing import List, Union

from requests.exceptions import HTTPError

from transformers.commands import BaseTransformersCLICommand
from transformers.hf_api import HfApi, HfFolder
class UserCommands(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        login_parser = parser.add_parser("login")
        login_parser.set_defaults(func=lambda args: LoginCommand(args))
        whoami_parser = parser.add_parser("whoami")
        whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
        logout_parser = parser.add_parser("logout")
        logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
        list_parser = parser.add_parser("ls")
        list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
        # upload
        upload_parser = parser.add_parser("upload")
        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
        upload_parser.add_argument(
            "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
        )
        upload_parser.set_defaults(func=lambda args: UploadCommand(args))
class ANSI:
    """
    Helper for en.wikipedia.org/wiki/ANSI_escape_code
    """

    _bold = u"\u001b[1m"
    _reset = u"\u001b[0m"

    @classmethod
    def bold(cls, s):
        return "{}{}{}".format(cls._bold, s, cls._reset)

@@ -44,14 +50,16 @@ class BaseUserCommand:
class LoginCommand(BaseUserCommand):
    def run(self):
        print(
            """
        _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
        _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
        _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
        _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
        _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

        """
        )
        username = input("Username: ")
        password = getpass()
        try:
@@ -91,8 +99,7 @@ class LogoutCommand(BaseUserCommand):
class ListObjsCommand(BaseUserCommand):
    def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
        """
        Inspired by:
        stackoverflow.com/a/8356620/593036
@@ -101,16 +108,10 @@ class ListObjsCommand(BaseUserCommand):
        col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
        row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
        lines = []
        lines.append(row_format.format(*headers))
        lines.append(row_format.format(*["-" * w for w in col_widths]))
        for row in rows:
            lines.append(row_format.format(*row))
        return "\n".join(lines)

    def run(self):
@@ -126,15 +127,8 @@ class ListObjsCommand(BaseUserCommand):
        if len(objs) == 0:
            print("No shared file yet")
            exit()
        rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
        print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))


class UploadCommand(BaseUserCommand):
@@ -143,13 +137,7 @@ class UploadCommand(BaseUserCommand):
        Recursively list all files in a folder.
        """
        entries: List[os.DirEntry] = list(os.scandir(rel_path))
        files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()]  # (filepath, filename)
        for f in entries:
            if f.is_dir():
                files += self.walk_dir(f.path)
@@ -173,22 +161,14 @@ class UploadCommand(BaseUserCommand):
            raise ValueError("Not a valid file or directory: {}".format(local_path))

        for filepath, filename in files:
            print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))

        choice = input("Proceed? [Y/n] ").lower()
        if not (choice == "" or choice == "y" or choice == "yes"):
            print("Abort")
            exit()
        print(ANSI.bold("Uploading... This might take a while if files are large"))
        for filepath, filename in files:
            access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
            print("Your file now lives at:")
            print(access_url)
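As a small aside (not part of this PR), the condensed tabulate helper above pads every column to the width of its longest cell and draws a dashed rule under the headers; the object fields below are illustrative:

    rows = [["model.bin", "2019-12-20", "abc123", 437985250]]
    headers = ["Filename", "LastModified", "ETag", "Size"]
    # ListObjsCommand.tabulate(rows, headers) then returns roughly:
    # Filename  LastModified ETag   Size
    # --------- ------------ ------ ---------
    # model.bin 2019-12-20   abc123 437985250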
@@ -17,17 +17,19 @@
from .configuration_utils import PretrainedConfig


ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
    "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
    "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
    "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
    "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
    "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
    "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
    "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
}


class AlbertConfig(PretrainedConfig):
    """Configuration for `AlbertModel`.
@@ -36,22 +38,25 @@ class AlbertConfig(PretrainedConfig):
    pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=30000,
        embedding_size=128,
        hidden_size=4096,
        num_hidden_layers=12,
        num_hidden_groups=1,
        num_attention_heads=64,
        intermediate_size=16384,
        inner_group_num=1,
        hidden_act="gelu_new",
        hidden_dropout_prob=0,
        attention_probs_dropout_prob=0,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        **kwargs
    ):
        """Constructs AlbertConfig.

        Args:
...
@@ -18,24 +18,26 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import logging

from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig


logger = logging.getLogger(__name__)
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
    (key, value)
    for pretrained_map in [
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -50,8 +52,9 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ]
    for key, value, in pretrained_map.items()
)
class AutoConfig(object):
@@ -79,37 +82,42 @@ class AutoConfig(object):
        - contains `ctrl` : CTRLConfig (CTRL model)

    This class cannot be instantiated using `__init__()` (throw an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type, *args, **kwargs):
        if "distilbert" in model_type:
            return DistilBertConfig(*args, **kwargs)
        elif "roberta" in model_type:
            return RobertaConfig(*args, **kwargs)
        elif "bert" in model_type:
            return BertConfig(*args, **kwargs)
        elif "openai-gpt" in model_type:
            return OpenAIGPTConfig(*args, **kwargs)
        elif "gpt2" in model_type:
            return GPT2Config(*args, **kwargs)
        elif "transfo-xl" in model_type:
            return TransfoXLConfig(*args, **kwargs)
        elif "xlnet" in model_type:
            return XLNetConfig(*args, **kwargs)
        elif "xlm" in model_type:
            return XLMConfig(*args, **kwargs)
        elif "ctrl" in model_type:
            return CTRLConfig(*args, **kwargs)
        elif "albert" in model_type:
            return AlbertConfig(*args, **kwargs)
        elif "camembert" in model_type:
            return CamembertConfig(*args, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contain one of "
            "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
            "'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type)
        )
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
@@ -176,32 +184,36 @@ class AutoConfig(object):
            assert unused_kwargs == {'foo': False}

        """
        if "t5" in pretrained_model_name_or_path:
            return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "distilbert" in pretrained_model_name_or_path:
            return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "albert" in pretrained_model_name_or_path:
            return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "camembert" in pretrained_model_name_or_path:
            return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "bert" in pretrained_model_name_or_path:
            return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "openai-gpt" in pretrained_model_name_or_path:
            return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "gpt2" in pretrained_model_name_or_path:
            return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "transfo-xl" in pretrained_model_name_or_path:
            return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "xlm" in pretrained_model_name_or_path:
            return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "ctrl" in pretrained_model_name_or_path:
            return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contain one of "
            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
            "'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
                pretrained_model_name_or_path
            )
        )
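A quick sketch (not part of this PR) of how the substring matching above plays out; the identifiers are illustrative:

    config = AutoConfig.from_pretrained("bert-base-uncased")       # hits the "bert" branch -> BertConfig
    config = AutoConfig.for_model("distilbert", vocab_size=30522)  # hits the "distilbert" branch -> DistilBertConfig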
@@ -17,37 +17,35 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
    "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
    "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
    "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
    "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
    "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
    "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
    "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
    "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
    "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
    "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
    "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
    "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
    "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
    "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
    "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
    "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
    "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
}
@@ -82,20 +80,22 @@ class BertConfig(PretrainedConfig):
    """

    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        **kwargs
    ):
        super(BertConfig, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
...
@@ -15,17 +15,17 @@
# limitations under the License.
""" CamemBERT configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .configuration_roberta import RobertaConfig


logger = logging.getLogger(__name__)

CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
}
...
@@ -16,17 +16,16 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)

CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}


class CTRLConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `CTRLModel`.
@@ -48,6 +47,7 @@ class CTRLConfig(PretrainedConfig):
        initializer_range: The stdev of the truncated_normal_initializer for
            initializing all weight matrices.
    """

    pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
    def __init__(
@@ -64,7 +64,7 @@ class CTRLConfig(PretrainedConfig):
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-6,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
...
@@ -13,45 +13,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" DistilBERT model configuration """
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)

DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
    "distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
    "distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
    "distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
}


class DistilBertConfig(PretrainedConfig):
    pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=512,
        sinusoidal_pos_embds=False,
        n_layers=6,
        n_heads=12,
        dim=768,
        hidden_dim=4 * 768,
        dropout=0.1,
        attention_dropout=0.1,
        activation="gelu",
        initializer_range=0.02,
        tie_weights_=True,
        qa_dropout=0.1,
        seq_classif_dropout=0.2,
        **kwargs
    ):
        super(DistilBertConfig, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
...
@@ -17,20 +17,21 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)

GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
    "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json",
    "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",
}
class GPT2Config(PretrainedConfig):
    """Configuration class to store the configuration of a `GPT2Model`.
@@ -52,6 +53,7 @@ class GPT2Config(PretrainedConfig):
        initializer_range: The stdev of the truncated_normal_initializer for
            initializing all weight matrices.
    """

    pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
@@ -67,7 +69,7 @@ class GPT2Config(PretrainedConfig):
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
...
@@ -15,11 +15,11 @@
# limitations under the License.
""" MMBT configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging


logger = logging.getLogger(__name__)
@@ -31,6 +31,7 @@ class MMBTConfig(object):
        num_labels: Size of final Linear layer for classification.
        modal_hidden_size: Embedding dimension of the non-text modality encoder.
    """

    def __init__(self, config, num_labels=None, modal_hidden_size=2048):
        self.__dict__ = config.__dict__
        self.modal_hidden_size = modal_hidden_size
...
...@@ -17,19 +17,18 @@ ...@@ -17,19 +17,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
} }
class OpenAIGPTConfig(PretrainedConfig): class OpenAIGPTConfig(PretrainedConfig):
""" """
Configuration class to store the configuration of a `OpenAIGPTModel`. Configuration class to store the configuration of a `OpenAIGPTModel`.
...@@ -54,6 +53,7 @@ class OpenAIGPTConfig(PretrainedConfig): ...@@ -54,6 +53,7 @@ class OpenAIGPTConfig(PretrainedConfig):
initializing all weight matrices. initializing all weight matrices.
predict_special_tokens: should we predict special tokens (when the model has an LM head) predict_special_tokens: should we predict special tokens (when the model has an LM head)
""" """
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__( def __init__(
...@@ -71,7 +71,7 @@ class OpenAIGPTConfig(PretrainedConfig): ...@@ -71,7 +71,7 @@ class OpenAIGPTConfig(PretrainedConfig):
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
initializer_range=0.02, initializer_range=0.02,
predict_special_tokens=True, predict_special_tokens=True,
summary_type='cls_index', summary_type="cls_index",
summary_use_proj=True, summary_use_proj=True,
summary_activation=None, summary_activation=None,
summary_proj_to_labels=True, summary_proj_to_labels=True,
......
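As with GPT-2, the constructor keywords mirror the signature above; the values in this sketch are just the documented defaults, spelled out for clarity.

from transformers import OpenAIGPTConfig

config = OpenAIGPTConfig(
    predict_special_tokens=True,  # predict special tokens when the model has an LM head
    summary_type="cls_index",
    layer_norm_epsilon=1e-5,
)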
...@@ -15,22 +15,22 @@ ...@@ -15,22 +15,22 @@
# limitations under the License. # limitations under the License.
""" RoBERTa configuration """ """ RoBERTa configuration """
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import logging import logging
from .configuration_bert import BertConfig from .configuration_bert import BertConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json", "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json", "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json", "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
} }
......
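RobertaConfig itself only subclasses BertConfig and swaps in the archive map above, so the usual loading path applies. A sketch (shortcut names resolve to the URLs listed and require network access; a local directory or config.json path works as well):

from transformers import RobertaConfig

config = RobertaConfig.from_pretrained("roberta-base")      # resolved via the archive map
# config = RobertaConfig.from_pretrained("./my_model_dir")  # or a local path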
...@@ -16,22 +16,19 @@ ...@@ -16,22 +16,19 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging import logging
import sys
import six
from io import open
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = { T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json", "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json", "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json", "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json", "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json", "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
} }
...@@ -65,19 +62,21 @@ class T5Config(PretrainedConfig): ...@@ -65,19 +62,21 @@ class T5Config(PretrainedConfig):
""" """
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self, def __init__(
vocab_size=32128, self,
n_positions=512, vocab_size=32128,
d_model=512, n_positions=512,
d_kv=64, d_model=512,
d_ff=2048, d_kv=64,
num_layers=6, d_ff=2048,
num_heads=8, num_layers=6,
relative_attention_num_buckets=32, num_heads=8,
dropout_rate=0.1, relative_attention_num_buckets=32,
layer_norm_epsilon=1e-6, dropout_rate=0.1,
initializer_factor=1.0, layer_norm_epsilon=1e-6,
**kwargs): initializer_factor=1.0,
**kwargs
):
super(T5Config, self).__init__(**kwargs) super(T5Config, self).__init__(**kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.n_positions = n_positions self.n_positions = n_positions
......
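A construction sketch using the keyword names from the signature above (the values shown are simply its defaults):

from transformers import T5Config

config = T5Config(
    vocab_size=32128,
    d_model=512,   # hidden size
    d_kv=64,       # size of each attention head
    d_ff=2048,     # feed-forward inner dimension
    num_layers=6,
    num_heads=8,
    relative_attention_num_buckets=32,
)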
...@@ -17,19 +17,18 @@ ...@@ -17,19 +17,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = { TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
} }
class TransfoXLConfig(PretrainedConfig): class TransfoXLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `TransfoXLModel`. """Configuration class to store the configuration of a `TransfoXLModel`.
...@@ -65,38 +64,41 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -65,38 +64,41 @@ class TransfoXLConfig(PretrainedConfig):
proj_init_std: projection parameters initialized by N(0, proj_init_std) proj_init_std: projection parameters initialized by N(0, proj_init_std)
init_std: parameters initialized by N(0, init_std) init_std: parameters initialized by N(0, init_std)
""" """
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self, def __init__(
vocab_size=267735, self,
cutoffs=[20000, 40000, 200000], vocab_size=267735,
d_model=1024, cutoffs=[20000, 40000, 200000],
d_embed=1024, d_model=1024,
n_head=16, d_embed=1024,
d_head=64, n_head=16,
d_inner=4096, d_head=64,
div_val=4, d_inner=4096,
pre_lnorm=False, div_val=4,
n_layer=18, pre_lnorm=False,
tgt_len=128, n_layer=18,
ext_len=0, tgt_len=128,
mem_len=1600, ext_len=0,
clamp_len=1000, mem_len=1600,
same_length=True, clamp_len=1000,
proj_share_all_but_first=True, same_length=True,
attn_type=0, proj_share_all_but_first=True,
sample_softmax=-1, attn_type=0,
adaptive=True, sample_softmax=-1,
tie_weight=True, adaptive=True,
dropout=0.1, tie_weight=True,
dropatt=0.0, dropout=0.1,
untie_r=True, dropatt=0.0,
init="normal", untie_r=True,
init_range=0.01, init="normal",
proj_init_std=0.01, init_range=0.01,
init_std=0.02, proj_init_std=0.01,
layer_norm_epsilon=1e-5, init_std=0.02,
**kwargs): layer_norm_epsilon=1e-5,
**kwargs
):
"""Constructs TransfoXLConfig. """Constructs TransfoXLConfig.
""" """
super(TransfoXLConfig, self).__init__(**kwargs) super(TransfoXLConfig, self).__init__(**kwargs)
......
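One detail worth noting in the signature above: `cutoffs` uses a mutable list as its default value, so callers typically pass their own list. A sketch using the defaults shown:

from transformers import TransfoXLConfig

config = TransfoXLConfig(
    cutoffs=[20000, 40000, 200000],  # adaptive softmax / embedding cutoffs
    d_model=1024,
    mem_len=1600,                    # length of the retained memory
    clamp_len=1000,
)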
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
# limitations under the License. # limitations under the License.
""" Configuration base class and utilities.""" """ Configuration base class and utilities."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import copy import copy
import json import json
...@@ -24,10 +23,12 @@ import logging ...@@ -24,10 +23,12 @@ import logging
import os import os
from io import open from io import open
from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PretrainedConfig(object): class PretrainedConfig(object):
r""" Base class for all configuration classes. r""" Base class for all configuration classes.
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
...@@ -50,36 +51,36 @@ class PretrainedConfig(object): ...@@ -50,36 +51,36 @@ class PretrainedConfig(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
# Attributes with defaults # Attributes with defaults
self.output_attentions = kwargs.pop('output_attentions', False) self.output_attentions = kwargs.pop("output_attentions", False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False) self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_past = kwargs.pop('output_past', True) # Not used by all models self.output_past = kwargs.pop("output_past", True) # Not used by all models
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop('use_bfloat16', False) self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.pruned_heads = kwargs.pop('pruned_heads', {}) self.pruned_heads = kwargs.pop("pruned_heads", {})
# is_decoder is used in encoder-decoder models to differentiate the encoder from the decoder # is_decoder is used in encoder-decoder models to differentiate the encoder from the decoder
self.is_decoder = kwargs.pop('is_decoder', False) self.is_decoder = kwargs.pop("is_decoder", False)
# Parameters for sequence generation # Parameters for sequence generation
self.max_length = kwargs.pop('max_length', 20) self.max_length = kwargs.pop("max_length", 20)
self.do_sample = kwargs.pop('do_sample', False) self.do_sample = kwargs.pop("do_sample", False)
self.num_beams = kwargs.pop('num_beams', 1) self.num_beams = kwargs.pop("num_beams", 1)
self.temperature = kwargs.pop('temperature', 1.0) self.temperature = kwargs.pop("temperature", 1.0)
self.top_k = kwargs.pop('top_k', 50) self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop('top_p', 1.0) self.top_p = kwargs.pop("top_p", 1.0)
self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0) self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.bos_token_id = kwargs.pop('bos_token_id', 0) self.bos_token_id = kwargs.pop("bos_token_id", 0)
self.pad_token_id = kwargs.pop('pad_token_id', 0) self.pad_token_id = kwargs.pop("pad_token_id", 0)
self.eos_token_ids = kwargs.pop('eos_token_ids', 0) self.eos_token_ids = kwargs.pop("eos_token_ids", 0)
self.length_penalty = kwargs.pop('length_penalty', 1.) self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.num_return_sequences = kwargs.pop('num_return_sequences', 1) self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
# Fine-tuning task arguments # Fine-tuning task arguments
self.finetuning_task = kwargs.pop('finetuning_task', None) self.finetuning_task = kwargs.pop("finetuning_task", None)
self.num_labels = kwargs.pop('num_labels', 2) self.num_labels = kwargs.pop("num_labels", 2)
self.id2label = kwargs.pop('id2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)}) self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
self.id2label = dict((int(key), value) for key, value in self.id2label.items()) self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = kwargs.pop('label2id', dict(zip(self.id2label.values(), self.id2label.keys()))) self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items()) self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
# Additional attributes without default values # Additional attributes without default values
...@@ -94,7 +95,9 @@ class PretrainedConfig(object): ...@@ -94,7 +95,9 @@ class PretrainedConfig(object):
""" Save a configuration object to the directory `save_directory`, so that it """ Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
""" """
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" assert os.path.isdir(
save_directory
), "Saving path should be a directory where the model and configuration can be saved"
# If we save using the predefined names, we can load using `from_pretrained` # If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME) output_config_file = os.path.join(save_directory, CONFIG_NAME)
...@@ -153,11 +156,11 @@ class PretrainedConfig(object): ...@@ -153,11 +156,11 @@ class PretrainedConfig(object):
assert unused_kwargs == {'foo': False} assert unused_kwargs == {'foo': False}
""" """
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop('force_download', False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop('resume_download', False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop("proxies", None)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
if pretrained_model_name_or_path in cls.pretrained_config_archive_map: if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
...@@ -170,37 +173,48 @@ class PretrainedConfig(object): ...@@ -170,37 +173,48 @@ class PretrainedConfig(object):
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, resolved_config_file = cached_path(
proxies=proxies, resume_download=resume_download) config_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
)
# Load config # Load config
config = cls.from_json_file(resolved_config_file) config = cls.from_json_file(resolved_config_file)
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map: if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format( msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file) config_file
)
else: else:
msg = "Model name '{}' was not found in model name list ({}). " \ msg = (
"We assumed '{}' was a path or url to a configuration file named {} or " \ "Model name '{}' was not found in model name list ({}). "
"a directory containing such a file but couldn't find any such file at this path or url.".format( "We assumed '{}' was a path or url to a configuration file named {} or "
"a directory containing such a file but couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path, pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()), ", ".join(cls.pretrained_config_archive_map.keys()),
config_file, CONFIG_NAME) config_file,
CONFIG_NAME,
)
)
raise EnvironmentError(msg) raise EnvironmentError(msg)
except json.JSONDecodeError: except json.JSONDecodeError:
msg = "Couldn't reach server at '{}' to download configuration file or " \ msg = (
"configuration file is not a valid JSON file. " \ "Couldn't reach server at '{}' to download configuration file or "
"Please check network or file content here: {}.".format(config_file, resolved_config_file) "configuration file is not a valid JSON file. "
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
)
raise EnvironmentError(msg) raise EnvironmentError(msg)
if resolved_config_file == config_file: if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file)) logger.info("loading configuration file {}".format(config_file))
else: else:
logger.info("loading configuration file {} from cache at {}".format( logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
config_file, resolved_config_file))
if hasattr(config, 'pruned_heads'): if hasattr(config, "pruned_heads"):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
# Update config with kwargs if needed # Update config with kwargs if needed
...@@ -226,7 +240,7 @@ class PretrainedConfig(object): ...@@ -226,7 +240,7 @@ class PretrainedConfig(object):
@classmethod @classmethod
def from_json_file(cls, json_file): def from_json_file(cls, json_file):
"""Constructs a `Config` from a json file of parameters.""" """Constructs a `Config` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader: with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read() text = reader.read()
dict_obj = json.loads(text) dict_obj = json.loads(text)
return cls(**dict_obj) return cls(**dict_obj)
...@@ -248,5 +262,5 @@ class PretrainedConfig(object): ...@@ -248,5 +262,5 @@ class PretrainedConfig(object):
def to_json_file(self, json_file_path): def to_json_file(self, json_file_path):
""" Save this instance to a json file.""" """ Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer: with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string()) writer.write(self.to_json_string())
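To make the save/load contract above concrete, a sketch of the round trip, including the `return_unused_kwargs` behaviour the docstring describes (GPT2Config stands in for any concrete subclass; the temporary directory is only for the example):

import tempfile

from transformers import GPT2Config

with tempfile.TemporaryDirectory() as save_directory:
    GPT2Config().save_pretrained(save_directory)  # writes CONFIG_NAME into the directory

    # Known kwargs update the config, unknown ones are handed back.
    config, unused = GPT2Config.from_pretrained(
        save_directory, output_attentions=True, foo=False, return_unused_kwargs=True
    )
    assert config.output_attentions is True
    assert unused == {"foo": False}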
...@@ -15,26 +15,24 @@ ...@@ -15,26 +15,24 @@
""" XLM configuration """ """ XLM configuration """
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", "xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json", "xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json", "xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json", "xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json", "xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json", "xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json", "xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json", "xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json", "xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json", "xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
} }
...@@ -78,41 +76,44 @@ class XLMConfig(PretrainedConfig): ...@@ -78,41 +76,44 @@ class XLMConfig(PretrainedConfig):
-1 means no clamping. -1 means no clamping.
same_length: bool, whether to use the same attention length for each token. same_length: bool, whether to use the same attention length for each token.
""" """
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self, def __init__(
vocab_size=30145, self,
emb_dim=2048, vocab_size=30145,
n_layers=12, emb_dim=2048,
n_heads=16, n_layers=12,
dropout=0.1, n_heads=16,
attention_dropout=0.1, dropout=0.1,
gelu_activation=True, attention_dropout=0.1,
sinusoidal_embeddings=False, gelu_activation=True,
causal=False, sinusoidal_embeddings=False,
asm=False, causal=False,
n_langs=1, asm=False,
use_lang_emb=True, n_langs=1,
max_position_embeddings=512, use_lang_emb=True,
embed_init_std=2048 ** -0.5, max_position_embeddings=512,
layer_norm_eps=1e-12, embed_init_std=2048 ** -0.5,
init_std=0.02, layer_norm_eps=1e-12,
bos_index=0, init_std=0.02,
eos_index=1, bos_index=0,
pad_index=2, eos_index=1,
unk_index=3, pad_index=2,
mask_index=5, unk_index=3,
is_encoder=True, mask_index=5,
summary_type='first', is_encoder=True,
summary_use_proj=True, summary_type="first",
summary_activation=None, summary_use_proj=True,
summary_proj_to_labels=True, summary_activation=None,
summary_first_dropout=0.1, summary_proj_to_labels=True,
start_n_top=5, summary_first_dropout=0.1,
end_n_top=5, start_n_top=5,
mask_token_id=0, end_n_top=5,
lang_id=0, mask_token_id=0,
**kwargs): lang_id=0,
**kwargs
):
"""Constructs XLMConfig. """Constructs XLMConfig.
""" """
super(XLMConfig, self).__init__(**kwargs) super(XLMConfig, self).__init__(**kwargs)
......
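A sketch of the multilingual-specific knobs in the signature above; the values here are illustrative, not tied to a particular checkpoint.

from transformers import XLMConfig

config = XLMConfig(
    emb_dim=1024,
    n_langs=15,         # number of languages handled by the language embeddings
    use_lang_emb=True,
    max_position_embeddings=512,
)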
...@@ -15,22 +15,22 @@ ...@@ -15,22 +15,22 @@
# limitations under the License. # limitations under the License.
""" XLM-RoBERTa configuration """ """ XLM-RoBERTa configuration """
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import logging import logging
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json", "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
'xlm-roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json", "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
'xlm-roberta-large-finetuned-conll02-dutch': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json", "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
'xlm-roberta-large-finetuned-conll02-spanish': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json", "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
'xlm-roberta-large-finetuned-conll03-english': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json", "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
'xlm-roberta-large-finetuned-conll03-german': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json", "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
} }
......
...@@ -16,18 +16,16 @@ ...@@ -16,18 +16,16 @@
""" XLNet configuration """ """ XLNet configuration """
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json", "xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json", "xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
} }
...@@ -69,32 +67,35 @@ class XLNetConfig(PretrainedConfig): ...@@ -69,32 +67,35 @@ class XLNetConfig(PretrainedConfig):
same_length: bool, whether to use the same attention length for each token. same_length: bool, whether to use the same attention length for each token.
finetuning_task: name of the glue task on which the model was fine-tuned if any finetuning_task: name of the glue task on which the model was fine-tuned if any
""" """
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self, def __init__(
vocab_size=32000, self,
d_model=1024, vocab_size=32000,
n_layer=24, d_model=1024,
n_head=16, n_layer=24,
d_inner=4096, n_head=16,
ff_activation="gelu", d_inner=4096,
untie_r=True, ff_activation="gelu",
attn_type="bi", untie_r=True,
initializer_range=0.02, attn_type="bi",
layer_norm_eps=1e-12, initializer_range=0.02,
dropout=0.1, layer_norm_eps=1e-12,
mem_len=None, dropout=0.1,
reuse_len=None, mem_len=None,
bi_data=False, reuse_len=None,
clamp_len=-1, bi_data=False,
same_length=False, clamp_len=-1,
summary_type='last', same_length=False,
summary_use_proj=True, summary_type="last",
summary_activation='tanh', summary_use_proj=True,
summary_last_dropout=0.1, summary_activation="tanh",
start_n_top=5, summary_last_dropout=0.1,
end_n_top=5, start_n_top=5,
**kwargs): end_n_top=5,
**kwargs
):
"""Constructs XLNetConfig. """Constructs XLNetConfig.
""" """
super(XLNetConfig, self).__init__(**kwargs) super(XLNetConfig, self).__init__(**kwargs)
......
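And the same pattern for XLNet: `mem_len` enables the Transformer-XL style memory, and the summary_* fields configure the sequence-summary head (keyword names from the signature above).

from transformers import XLNetConfig

config = XLNetConfig(
    mem_len=1024,              # tokens kept in memory across segments
    attn_type="bi",
    summary_type="last",
    summary_activation="tanh",
)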
...@@ -14,16 +14,16 @@ ...@@ -14,16 +14,16 @@
# limitations under the License. # limitations under the License.
"""Convert ALBERT checkpoint.""" """Convert ALBERT checkpoint."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse import argparse
import logging
import torch import torch
from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -43,25 +43,20 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt ...@@ -43,25 +43,20 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters # Required parameters
parser.add_argument("--tf_checkpoint_path", parser.add_argument(
default = None, "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
type = str, )
required = True, parser.add_argument(
help = "Path to the TensorFlow checkpoint path.") "--albert_config_file",
parser.add_argument("--albert_config_file", default=None,
default = None, type=str,
type = str, required=True,
required = True, help="The config json file corresponding to the pre-trained ALBERT model. \n"
help = "The config json file corresponding to the pre-trained ALBERT model. \n" "This specifies the model architecture.",
"This specifies the model architecture.") )
parser.add_argument("--pytorch_dump_path", parser.add_argument(
default = None, "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
type = str, )
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args() args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
args.albert_config_file,
args.pytorch_dump_path)
\ No newline at end of file
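The reshaped argparse block keeps the same three required arguments, so the script is still driven by --tf_checkpoint_path, --albert_config_file and --pytorch_dump_path. Programmatically, the call at the bottom amounts to the sketch below; the module name and all paths are placeholders, not taken from the diff.

# Assuming the conversion script is importable under this (hypothetical) module name:
from convert_albert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

convert_tf_checkpoint_to_pytorch(
    "/path/to/albert/model.ckpt-best",  # --tf_checkpoint_path (placeholder)
    "/path/to/albert_config.json",      # --albert_config_file (placeholder)
    "/path/to/pytorch_model.bin",       # --pytorch_dump_path (placeholder)
)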