chenpangpang / transformers · Commits · 54abc67a

Unverified commit 54abc67a, authored Dec 22, 2019 by Thomas Wolf, committed by GitHub on Dec 22, 2019.

    Merge pull request #2255 from aaugustin/implement-best-practices

    Implement some Python best practices

Parents: 645713e2, c11b3e29
Changes: 205 files in total. Showing 20 changed files with 586 additions and 535 deletions (+586 / -535).
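The diff below is almost entirely mechanical: single-quoted strings become double-quoted, long calls are rewrapped with one argument per line and a trailing comma, and imports are sorted alphabetically. The commit message only says "Implement some Python best practices"; the pattern is consistent with running an auto-formatter in the style of black plus an import sorter in the style of isort, but that tooling is an inference from the diff, not something the commit states. A minimal, runnable before/after sketch of the dominant change:

from argparse import ArgumentParser

parser = ArgumentParser()

# Before: single-quoted strings, arguments hand-wrapped mid-call:
# parser.add_argument('--task', type=str,
#                     help='The task to run the pipeline on')

# After: double quotes, and when a call must wrap, one argument per line
# with a trailing comma (the style black produces).
parser.add_argument(
    "--task",
    type=str,
    help="The task to run the pipeline on",
)
print(parser.parse_args(["--task", "ner"]).task)  # -> ner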
transformers/commands/serving.py                                    +63   -36
transformers/commands/train.py                                      +78   -65
transformers/commands/user.py                                       +37   -57
transformers/configuration_albert.py                                +29   -24
transformers/configuration_auto.py                                  +60   -48
transformers/configuration_bert.py                                  +38   -38
transformers/configuration_camembert.py                             +3    -3
transformers/configuration_ctrl.py                                  +4    -4
transformers/configuration_distilbert.py                            +24   -25
transformers/configuration_gpt2.py                                  +11   -9
transformers/configuration_mmbt.py                                  +3    -2
transformers/configuration_openai.py                                +4    -4
transformers/configuration_roberta.py                               +8    -8
transformers/configuration_t5.py                                    +21   -22
transformers/configuration_transfo_xl.py                            +36   -34
transformers/configuration_utils.py                                 +62   -48
transformers/configuration_xlm.py                                   +47   -46
transformers/configuration_xlm_roberta.py                           +8    -8
transformers/configuration_xlnet.py                                 +30   -29
transformers/convert_albert_original_tf_checkpoint_to_pytorch.py    +20   -25
transformers/commands/serving.py  (view file @ 54abc67a)

+import logging
 from argparse import ArgumentParser, Namespace
-from typing import List, Optional, Union, Any
+from typing import Any, List, Optional, Union

-from transformers import Pipeline
-from transformers.commands import BaseTransformersCLICommand
-from transformers.pipelines import SUPPORTED_TASKS, pipeline
-
-import logging

 try:
     from uvicorn import run
     from fastapi import FastAPI, HTTPException, Body
     from pydantic import BaseModel

     _serve_dependancies_installed = True
 except (ImportError, AttributeError):
     BaseModel = object
-    Body = lambda *x, **y: None
+
+    def Body(*x, **y):
+        pass
+
     _serve_dependancies_installed = False

+from transformers import Pipeline
+from transformers.commands import BaseTransformersCLICommand
+from transformers.pipelines import SUPPORTED_TASKS, pipeline

-logger = logging.getLogger('transformers-cli/serving')
+logger = logging.getLogger("transformers-cli/serving")


 def serve_command_factory(args: Namespace):
     """
     Factory function used to instantiate serving server from provided command line arguments.
     :return: ServeCommand
     """
-    nlp = pipeline(task=args.task,
+    nlp = pipeline(
+        task=args.task,
         model=args.model if args.model else None,
         config=args.config,
         tokenizer=args.tokenizer,
-        device=args.device)
+        device=args.device,
+    )
     return ServeCommand(nlp, args.host, args.port)

...

@@ -36,6 +44,7 @@ class ServeModelInfoResult(BaseModel):
     """
     Expose model information
     """
+
     infos: dict

...

@@ -43,6 +52,7 @@ class ServeTokenizeResult(BaseModel):
     """
     Tokenize result model
     """
+
     tokens: List[str]
     tokens_ids: Optional[List[int]]

...

@@ -51,6 +61,7 @@ class ServeDeTokenizeResult(BaseModel):
     """
     DeTokenize result model
     """
+
     text: str

...

@@ -58,11 +69,11 @@ class ServeForwardResult(BaseModel):
     """
     Forward result model
     """
+
     output: Any


 class ServeCommand(BaseTransformersCLICommand):
     @staticmethod
     def register_subcommand(parser: ArgumentParser):
         """

...

@@ -70,14 +81,23 @@ class ServeCommand(BaseTransformersCLICommand):
         :param parser: Root parser to register command-specific arguments
         :return:
         """
-        serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
-        serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
-        serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
-        serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
-        serve_parser.add_argument('--model', type=str, help='Model\'s name or path to stored model.')
-        serve_parser.add_argument('--config', type=str, help='Model\'s config name or path to stored model.')
-        serve_parser.add_argument('--tokenizer', type=str, help='Tokenizer name to use.')
-        serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
+        serve_parser = parser.add_parser(
+            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
+        )
+        serve_parser.add_argument("--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on")
+        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
+        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
+        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
+        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
+        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
+        serve_parser.add_argument(
+            "--device",
+            type=int,
+            default=-1,
+            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
+        )
         serve_parser.set_defaults(func=serve_command_factory)

     def __init__(self, pipeline: Pipeline, host: str, port: int):

...

@@ -87,18 +107,22 @@ class ServeCommand(BaseTransformersCLICommand):
         self._host = host
         self._port = port

         if not _serve_dependancies_installed:
-            raise ImportError("Using serve command requires FastAPI and unicorn. "
-                              "Please install transformers with [serving]: pip install transformers[serving]."
-                              "Or install FastAPI and unicorn separatly.")
+            raise ImportError(
+                "Using serve command requires FastAPI and unicorn. "
+                "Please install transformers with [serving]: pip install transformers[serving]."
+                "Or install FastAPI and unicorn separatly."
+            )
         else:
-            logger.info('Serving model over {}:{}'.format(host, port))
+            logger.info("Serving model over {}:{}".format(host, port))
             self._app = FastAPI()

             # Register routes
-            self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
-            self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
-            self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
-            self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
+            self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
+            self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
+            self._app.add_api_route(
+                "/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
+            )
+            self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])

     def run(self):
         run(self._app, host=self._host, port=self._port)

...

@@ -122,11 +146,14 @@ class ServeCommand(BaseTransformersCLICommand):
             return ServeTokenizeResult(tokens=tokens_txt)
         except Exception as e:
-            raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
+            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

-    def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
-                   skip_special_tokens: bool = Body(False, embed=True),
-                   cleanup_tokenization_spaces: bool = Body(True, embed=True)):
+    def detokenize(
+        self,
+        tokens_ids: List[int] = Body(None, embed=True),
+        skip_special_tokens: bool = Body(False, embed=True),
+        cleanup_tokenization_spaces: bool = Body(True, embed=True),
+    ):
         """
         Detokenize the provided tokens ids to readable text:
         - **tokens_ids**: List of tokens ids

...

@@ -135,9 +162,9 @@ class ServeCommand(BaseTransformersCLICommand):
         """
         try:
             decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
-            return ServeDeTokenizeResult(model='', text=decoded_str)
+            return ServeDeTokenizeResult(model="", text=decoded_str)
         except Exception as e:
-            raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
+            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

     def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
         """
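Beyond quoting, one substantive change in serving.py is replacing `Body = lambda *x, **y: None` with a real `def`: PEP 8 (flake8's E731) discourages binding lambdas to names because the resulting function stays anonymous in tracebacks. A self-contained sketch of the optional-dependency guard this file uses, so the module still imports when FastAPI/pydantic are absent (the flag name mirrors the file's own spelling):

try:
    from fastapi import Body
    from pydantic import BaseModel

    _serve_dependancies_installed = True
except (ImportError, AttributeError):
    BaseModel = object  # inert stand-in so response models can still be declared

    def Body(*x, **y):  # a def instead of a lambda: named in tracebacks (E731)
        return None

    _serve_dependancies_installed = False

print(_serve_dependancies_installed)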
transformers/commands/train.py  (view file @ 54abc67a)

...

@@ -2,10 +2,10 @@ import os
 from argparse import ArgumentParser, Namespace
 from logging import getLogger

+from transformers import SingleSentenceClassificationProcessor as Processor
+from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
 from transformers.commands import BaseTransformersCLICommand
-from transformers import (is_tf_available,
-                          is_torch_available,
-                          TextClassificationPipeline,
-                          SingleSentenceClassificationProcessor as Processor)

 if not is_tf_available() and not is_torch_available():
     raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

...

@@ -14,6 +14,7 @@ if not is_tf_available() and not is_torch_available():
 USE_XLA = False
 USE_AMP = False

+
 def train_command_factory(args: Namespace):
     """
     Factory function used to instantiate serving server from provided command line arguments.

...

@@ -23,7 +24,6 @@ def train_command_factory(args: Namespace):
 class TrainCommand(BaseTransformersCLICommand):
     @staticmethod
     def register_subcommand(parser: ArgumentParser):
         """

...

@@ -31,47 +31,54 @@ class TrainCommand(BaseTransformersCLICommand):
         :param parser: Root parser to register command-specific arguments
         :return:
         """
-        train_parser = parser.add_parser('train', help='CLI tool to train a model on a task.')
-        train_parser.add_argument('--train_data', type=str, required=True,
-                                  help="path to train (and optionally evaluation) dataset as a csv with "
-                                       "tab separated labels and sentences.")
-        train_parser.add_argument('--column_label', type=int, default=0, help='Column of the dataset csv file with example labels.')
-        train_parser.add_argument('--column_text', type=int, default=1, help='Column of the dataset csv file with example texts.')
-        train_parser.add_argument('--column_id', type=int, default=2, help='Column of the dataset csv file with example ids.')
-        train_parser.add_argument('--skip_first_row', action='store_true', help='Skip the first row of the csv file (headers).')
-        train_parser.add_argument('--validation_data', type=str, default='', help='path to validation dataset.')
-        train_parser.add_argument('--validation_split', type=float, default=0.1,
-                                  help="if validation dataset is not provided, fraction of train dataset "
-                                       "to use as validation dataset.")
-        train_parser.add_argument('--output', type=str, default='./', help='path to saved the trained model.')
-        train_parser.add_argument('--task', type=str, default='text_classification', help='Task to train the model on.')
-        train_parser.add_argument('--model', type=str, default='bert-base-uncased', help='Model\'s name or path to stored model.')
-        train_parser.add_argument('--train_batch_size', type=int, default=32, help='Batch size for training.')
-        train_parser.add_argument('--valid_batch_size', type=int, default=64, help='Batch size for validation.')
-        train_parser.add_argument('--learning_rate', type=float, default=3e-5, help="Learning rate.")
-        train_parser.add_argument('--adam_epsilon', type=float, default=1e-08, help="Epsilon for Adam optimizer.")
+        train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
+        train_parser.add_argument(
+            "--train_data",
+            type=str,
+            required=True,
+            help="path to train (and optionally evaluation) dataset as a csv with "
+            "tab separated labels and sentences.",
+        )
+        train_parser.add_argument("--column_label", type=int, default=0, help="Column of the dataset csv file with example labels.")
+        train_parser.add_argument("--column_text", type=int, default=1, help="Column of the dataset csv file with example texts.")
+        train_parser.add_argument("--column_id", type=int, default=2, help="Column of the dataset csv file with example ids.")
+        train_parser.add_argument("--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers).")
+        train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
+        train_parser.add_argument(
+            "--validation_split",
+            type=float,
+            default=0.1,
+            help="if validation dataset is not provided, fraction of train dataset "
+            "to use as validation dataset.",
+        )
+        train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")
+        train_parser.add_argument("--task", type=str, default="text_classification", help="Task to train the model on.")
+        train_parser.add_argument("--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model.")
+        train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
+        train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
+        train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
+        train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
         train_parser.set_defaults(func=train_command_factory)

     def __init__(self, args: Namespace):
-        self.logger = getLogger('transformers-cli/training')
+        self.logger = getLogger("transformers-cli/training")

-        self.framework = 'tf' if is_tf_available() else 'torch'
+        self.framework = "tf" if is_tf_available() else "torch"

         os.makedirs(args.output, exist_ok=True)
         assert os.path.isdir(args.output)

...

@@ -81,28 +88,32 @@ class TrainCommand(BaseTransformersCLICommand):
         self.column_text = args.column_text
         self.column_id = args.column_id

-        self.logger.info('Loading {} pipeline for {}'.format(args.task, args.model))
-        if args.task == 'text_classification':
+        self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
+        if args.task == "text_classification":
             self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
-        elif args.task == 'token_classification':
+        elif args.task == "token_classification":
             raise NotImplementedError
-        elif args.task == 'question_answering':
+        elif args.task == "question_answering":
             raise NotImplementedError

-        self.logger.info('Loading dataset from {}'.format(args.train_data))
-        self.train_dataset = Processor.create_from_csv(args.train_data,
+        self.logger.info("Loading dataset from {}".format(args.train_data))
+        self.train_dataset = Processor.create_from_csv(
+            args.train_data,
             column_label=args.column_label,
             column_text=args.column_text,
             column_id=args.column_id,
-            skip_first_row=args.skip_first_row)
+            skip_first_row=args.skip_first_row,
+        )
         self.valid_dataset = None
         if args.validation_data:
-            self.logger.info('Loading validation dataset from {}'.format(args.validation_data))
-            self.valid_dataset = Processor.create_from_csv(args.validation_data,
+            self.logger.info("Loading validation dataset from {}".format(args.validation_data))
+            self.valid_dataset = Processor.create_from_csv(
+                args.validation_data,
                 column_label=args.column_label,
                 column_text=args.column_text,
                 column_id=args.column_id,
-                skip_first_row=args.skip_first_row)
+                skip_first_row=args.skip_first_row,
+            )

         self.validation_split = args.validation_split
         self.train_batch_size = args.train_batch_size

...

@@ -111,7 +122,7 @@ class TrainCommand(BaseTransformersCLICommand):
         self.adam_epsilon = args.adam_epsilon

     def run(self):
-        if self.framework == 'tf':
+        if self.framework == "tf":
             return self.run_tf()
         return self.run_torch()

...

@@ -119,13 +130,15 @@ class TrainCommand(BaseTransformersCLICommand):
         raise NotImplementedError

     def run_tf(self):
-        self.pipeline.fit(self.train_dataset,
+        self.pipeline.fit(
+            self.train_dataset,
             validation_data=self.valid_dataset,
             validation_split=self.validation_split,
             learning_rate=self.learning_rate,
             adam_epsilon=self.adam_epsilon,
             train_batch_size=self.train_batch_size,
-            valid_batch_size=self.valid_batch_size)
+            valid_batch_size=self.valid_batch_size,
+        )

         # Save trained pipeline
         self.pipeline.save_pretrained(self.output)
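train.py also shows a common capability-guard idiom: fail fast at import time unless at least one backend is present, then pick the framework once and dispatch on it in run(). A standalone sketch of the same idea; the two probe functions below are simplified stand-ins for transformers' is_tf_available() / is_torch_available(), which do more careful version checks:

import importlib.util


def is_tf_available() -> bool:
    # Simplified probe; the real helper also validates the TF version.
    return importlib.util.find_spec("tensorflow") is not None


def is_torch_available() -> bool:
    return importlib.util.find_spec("torch") is not None


if not is_tf_available() and not is_torch_available():
    raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

# Mirrors TrainCommand.__init__: prefer TF when both backends are installed.
framework = "tf" if is_tf_available() else "torch"
print(framework)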
transformers/commands/user.py  (view file @ 54abc67a)

+import os
 from argparse import ArgumentParser
 from getpass import getpass
-import os
 from typing import List, Union

+from requests.exceptions import HTTPError
+
 from transformers.commands import BaseTransformersCLICommand
-from transformers.hf_api import HfApi, HfFolder, HTTPError
+from transformers.hf_api import HfApi, HfFolder


 class UserCommands(BaseTransformersCLICommand):
     @staticmethod
     def register_subcommand(parser: ArgumentParser):
-        login_parser = parser.add_parser('login')
+        login_parser = parser.add_parser("login")
         login_parser.set_defaults(func=lambda args: LoginCommand(args))
-        whoami_parser = parser.add_parser('whoami')
+        whoami_parser = parser.add_parser("whoami")
         whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
-        logout_parser = parser.add_parser('logout')
+        logout_parser = parser.add_parser("logout")
         logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
-        list_parser = parser.add_parser('ls')
+        list_parser = parser.add_parser("ls")
         list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
         # upload
-        upload_parser = parser.add_parser('upload')
-        upload_parser.add_argument('path', type=str, help='Local path of the folder or individual file to upload.')
-        upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override individual object filename on S3.')
+        upload_parser = parser.add_parser("upload")
+        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
+        upload_parser.add_argument(
+            "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
+        )
         upload_parser.set_defaults(func=lambda args: UploadCommand(args))


 class ANSI:
     """
     Helper for en.wikipedia.org/wiki/ANSI_escape_code
     """
+
     _bold = u"\u001b[1m"
     _reset = u"\u001b[0m"

     @classmethod
     def bold(cls, s):
         return "{}{}{}".format(cls._bold, s, cls._reset)

...

@@ -44,14 +50,16 @@ class BaseUserCommand:
 class LoginCommand(BaseUserCommand):
     def run(self):
-        print("""
+        print(
+            """
 _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
 _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
 _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
 _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
 _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
-        """)
+        """
+        )
         username = input("Username: ")
         password = getpass()
         try:

...

@@ -91,8 +99,7 @@ class LogoutCommand(BaseUserCommand):
 class ListObjsCommand(BaseUserCommand):
-    def tabulate(self, rows, headers):
-        # type: (List[List[Union[str, int]]], List[str]) -> str
+    def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
         """
         Inspired by:
         stackoverflow.com/a/8356620/593036

...

@@ -101,16 +108,10 @@ class ListObjsCommand(BaseUserCommand):
         col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
         row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
         lines = []
-        lines.append(
-            row_format.format(*headers)
-        )
-        lines.append(
-            row_format.format(*["-" * w for w in col_widths])
-        )
+        lines.append(row_format.format(*headers))
+        lines.append(row_format.format(*["-" * w for w in col_widths]))
         for row in rows:
-            lines.append(
-                row_format.format(*row)
-            )
+            lines.append(row_format.format(*row))
         return "\n".join(lines)

     def run(self):

...

@@ -126,15 +127,8 @@ class ListObjsCommand(BaseUserCommand):
         if len(objs) == 0:
             print("No shared file yet")
             exit()
-        rows = [
-            [obj.filename, obj.LastModified, obj.ETag, obj.Size]
-            for obj in objs
-        ]
-        print(
-            self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])
-        )
+        rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
+        print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))


 class UploadCommand(BaseUserCommand):

...

@@ -143,13 +137,7 @@ class UploadCommand(BaseUserCommand):
         Recursively list all files in a folder.
         """
         entries: List[os.DirEntry] = list(os.scandir(rel_path))
-        files = [
-            (
-                os.path.join(os.getcwd(), f.path),  # filepath
-                f.path  # filename
-            )
-            for f in entries
-            if f.is_file()
-        ]
+        files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()]  # filepath  # filename
         for f in entries:
             if f.is_dir():
                 files += self.walk_dir(f.path)

...

@@ -173,22 +161,14 @@ class UploadCommand(BaseUserCommand):
             raise ValueError("Not a valid file or directory: {}".format(local_path))

         for filepath, filename in files:
-            print(
-                "About to upload file {} to S3 under filename {}".format(
-                    ANSI.bold(filepath), ANSI.bold(filename)
-                )
-            )
+            print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))

         choice = input("Proceed? [Y/n] ").lower()
-        if not(choice == "" or choice == "y" or choice == "yes"):
+        if not (choice == "" or choice == "y" or choice == "yes"):
             print("Abort")
             exit()
-        print(
-            ANSI.bold("Uploading... This might take a while if files are large")
-        )
+        print(ANSI.bold("Uploading... This might take a while if files are large"))
         for filepath, filename in files:
-            access_url = self._api.presign_and_upload(
-                token=token, filename=filename, filepath=filepath
-            )
+            access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
             print("Your file now lives at:")
             print(access_url)
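The tabulate method above is a compact, dependency-free table renderer: one width per column (headers included), a single format string built from those widths, then headers, a dashed rule, and the rows. A standalone copy with sample output, lifted out of the class purely for illustration:

from typing import List, Union


def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
    # Each column is as wide as its widest cell, header included.
    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
    lines = []
    lines.append(row_format.format(*headers))
    lines.append(row_format.format(*["-" * w for w in col_widths]))
    for row in rows:
        lines.append(row_format.format(*row))
    return "\n".join(lines)


print(tabulate([["model.bin", 437983650]], headers=["Filename", "Size"]))
# Filename  Size
# --------- ---------
# model.bin 437983650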
transformers/configuration_albert.py  (view file @ 54abc67a)

...

@@ -17,17 +17,19 @@
 from .configuration_utils import PretrainedConfig

+
 ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
-    'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
-    'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
-    'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
-    'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
-    'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
-    'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
-    'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
+    "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
+    "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
+    "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
+    "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
+    "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
+    "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
+    "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
+    "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
 }

+
 class AlbertConfig(PretrainedConfig):
     """Configuration for `AlbertModel`.

...

@@ -36,7 +38,8 @@ class AlbertConfig(PretrainedConfig):
     pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

-    def __init__(self,
+    def __init__(
+        self,
         vocab_size=30000,
         embedding_size=128,
         hidden_size=4096,

...

@@ -51,7 +54,9 @@ class AlbertConfig(PretrainedConfig):
         max_position_embeddings=512,
         type_vocab_size=2,
         initializer_range=0.02,
-                 layer_norm_eps=1e-12,
-                 **kwargs):
+        layer_norm_eps=1e-12,
+        **kwargs
+    ):
         """Constructs AlbertConfig.

         Args:
transformers/configuration_auto.py  (view file @ 54abc67a)

...

@@ -18,24 +18,26 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import logging

-from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP
-from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
+from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
+from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
+from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
+from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
+from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
+from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
+from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
+from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
+from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
+from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
+from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
+from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
+from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig


 logger = logging.getLogger(__name__)

-ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
+ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
+    (key, value)
     for pretrained_map in [
         BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,

...

@@ -51,7 +53,8 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
         T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
         XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
     ]
-    for key, value, in pretrained_map.items())
+    for key, value, in pretrained_map.items()
+)


 class AutoConfig(object):

...

@@ -79,37 +82,42 @@ class AutoConfig(object):
         - contains `ctrl` : CTRLConfig (CTRL model)

         This class cannot be instantiated using `__init__()` (throw an error).
     """
+
     def __init__(self):
-        raise EnvironmentError("AutoConfig is designed to be instantiated "
-                               "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.")
+        raise EnvironmentError(
+            "AutoConfig is designed to be instantiated "
+            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
+        )

     @classmethod
     def for_model(cls, model_type, *args, **kwargs):
-        if 'distilbert' in model_type:
+        if "distilbert" in model_type:
             return DistilBertConfig(*args, **kwargs)
-        elif 'roberta' in model_type:
+        elif "roberta" in model_type:
             return RobertaConfig(*args, **kwargs)
-        elif 'bert' in model_type:
+        elif "bert" in model_type:
             return BertConfig(*args, **kwargs)
-        elif 'openai-gpt' in model_type:
+        elif "openai-gpt" in model_type:
             return OpenAIGPTConfig(*args, **kwargs)
-        elif 'gpt2' in model_type:
+        elif "gpt2" in model_type:
             return GPT2Config(*args, **kwargs)
-        elif 'transfo-xl' in model_type:
+        elif "transfo-xl" in model_type:
             return TransfoXLConfig(*args, **kwargs)
-        elif 'xlnet' in model_type:
+        elif "xlnet" in model_type:
             return XLNetConfig(*args, **kwargs)
-        elif 'xlm' in model_type:
+        elif "xlm" in model_type:
             return XLMConfig(*args, **kwargs)
-        elif 'ctrl' in model_type:
+        elif "ctrl" in model_type:
             return CTRLConfig(*args, **kwargs)
-        elif 'albert' in model_type:
+        elif "albert" in model_type:
             return AlbertConfig(*args, **kwargs)
-        elif 'camembert' in model_type:
+        elif "camembert" in model_type:
             return CamembertConfig(*args, **kwargs)
-        raise ValueError("Unrecognized model identifier in {}. Should contains one of "
-                         "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-                         "'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type))
+        raise ValueError(
+            "Unrecognized model identifier in {}. Should contains one of "
+            "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
+            "'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type)
+        )

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):

...

@@ -176,32 +184,36 @@ class AutoConfig(object):
             assert unused_kwargs == {'foo': False}

         """
-        if 't5' in pretrained_model_name_or_path:
+        if "t5" in pretrained_model_name_or_path:
             return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'distilbert' in pretrained_model_name_or_path:
+        elif "distilbert" in pretrained_model_name_or_path:
             return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'albert' in pretrained_model_name_or_path:
+        elif "albert" in pretrained_model_name_or_path:
             return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'camembert' in pretrained_model_name_or_path:
+        elif "camembert" in pretrained_model_name_or_path:
             return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'xlm-roberta' in pretrained_model_name_or_path:
+        elif "xlm-roberta" in pretrained_model_name_or_path:
             return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'roberta' in pretrained_model_name_or_path:
+        elif "roberta" in pretrained_model_name_or_path:
             return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'bert' in pretrained_model_name_or_path:
+        elif "bert" in pretrained_model_name_or_path:
             return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'openai-gpt' in pretrained_model_name_or_path:
+        elif "openai-gpt" in pretrained_model_name_or_path:
             return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'gpt2' in pretrained_model_name_or_path:
+        elif "gpt2" in pretrained_model_name_or_path:
             return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'transfo-xl' in pretrained_model_name_or_path:
+        elif "transfo-xl" in pretrained_model_name_or_path:
             return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'xlnet' in pretrained_model_name_or_path:
+        elif "xlnet" in pretrained_model_name_or_path:
             return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'xlm' in pretrained_model_name_or_path:
+        elif "xlm" in pretrained_model_name_or_path:
             return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        elif 'ctrl' in pretrained_model_name_or_path:
+        elif "ctrl" in pretrained_model_name_or_path:
             return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        raise ValueError("Unrecognized model identifier in {}. Should contains one of "
-                         "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-                         "'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path))
+        raise ValueError(
+            "Unrecognized model identifier in {}. Should contains one of "
+            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
+            "'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path)
+        )
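Both for_model and from_pretrained dispatch on substring matches in the model name, so the order of the elif chain is load-bearing: 'distilbert' must be tested before 'bert', and 'xlm-roberta' before 'roberta' and 'xlm', because the more specific names contain the shorter ones. A minimal sketch of the idea, with plain strings standing in for the config classes:

def guess_family(name: str) -> str:
    # Most specific substrings first, mirroring the order of AutoConfig's elif chain.
    for family in ("distilbert", "xlm-roberta", "roberta", "bert", "xlm"):
        if family in name:
            return family
    raise ValueError("Unrecognized model identifier in {}".format(name))


print(guess_family("distilbert-base-uncased"))  # distilbert (not bert)
print(guess_family("xlm-roberta-large"))        # xlm-roberta (not roberta or xlm)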
transformers/configuration_bert.py  (view file @ 54abc67a)

...

@@ -17,37 +17,35 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import json
 import logging
 import sys
 from io import open

 from .configuration_utils import PretrainedConfig

 logger = logging.getLogger(__name__)

 BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
-    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
-    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
-    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
-    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
-    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
-    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
-    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
-    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
-    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
-    'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
-    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
-    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
-    'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
-    'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
-    'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
-    'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
-    'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
-    'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
-    'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
-    'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
+    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
+    "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
+    "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
+    "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
+    "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
+    "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
+    "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
+    "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
+    "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
+    "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
+    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
+    "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
+    "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
+    "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
+    "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
+    "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
+    "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
+    "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
+    "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
+    "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
+    "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
 }

...

@@ -82,7 +80,8 @@ class BertConfig(PretrainedConfig):
     """
     pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

-    def __init__(self,
+    def __init__(
+        self,
         vocab_size=30522,
         hidden_size=768,
         num_hidden_layers=12,

...

@@ -95,7 +94,8 @@ class BertConfig(PretrainedConfig):
         type_vocab_size=2,
         initializer_range=0.02,
         layer_norm_eps=1e-12,
-        **kwargs):
+        **kwargs
+    ):
         super(BertConfig, self).__init__(**kwargs)
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
transformers/configuration_camembert.py  (view file @ 54abc67a)

...

@@ -15,17 +15,17 @@
 # limitations under the License.
 """ CamemBERT configuration """

-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
+from __future__ import absolute_import, division, print_function, unicode_literals

 import logging

 from .configuration_roberta import RobertaConfig

 logger = logging.getLogger(__name__)

 CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
+    "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
 }

...
transformers/configuration_ctrl.py  (view file @ 54abc67a)

...

@@ -16,17 +16,16 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import json
 import logging
 import sys
 from io import open

 from .configuration_utils import PretrainedConfig

 logger = logging.getLogger(__name__)

 CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}


 class CTRLConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `CTRLModel`.

...

@@ -48,6 +47,7 @@ class CTRLConfig(PretrainedConfig):
         initializer_range: The sttdev of the truncated_normal_initializer for
                            initializing all weight matrices.
     """
+
     pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(

...

@@ -64,7 +64,7 @@ class CTRLConfig(PretrainedConfig):
         attn_pdrop=0.1,
         layer_norm_epsilon=1e-6,
         initializer_range=0.02,
-        summary_type='cls_index',
+        summary_type="cls_index",
         summary_use_proj=True,
         summary_activation=None,
         summary_proj_to_labels=True,

...
transformers/configuration_distilbert.py  (view file @ 54abc67a)

...

@@ -13,45 +13,44 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ DistilBERT model configuration """
-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
+from __future__ import absolute_import, division, print_function, unicode_literals

 import sys
 import json
 import logging
 from io import open

 from .configuration_utils import PretrainedConfig

 logger = logging.getLogger(__name__)

 DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
-    'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
-    'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
-    'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
+    "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
+    "distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
+    "distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
+    "distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
 }


 class DistilBertConfig(PretrainedConfig):
     pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

-    def __init__(self,
+    def __init__(
+        self,
         vocab_size=30522,
         max_position_embeddings=512,
         sinusoidal_pos_embds=False,
         n_layers=6,
         n_heads=12,
         dim=768,
-        hidden_dim=4*768,
+        hidden_dim=4 * 768,
         dropout=0.1,
         attention_dropout=0.1,
-        activation='gelu',
+        activation="gelu",
         initializer_range=0.02,
         tie_weights_=True,
         qa_dropout=0.1,
         seq_classif_dropout=0.2,
-        **kwargs):
+        **kwargs
+    ):
         super(DistilBertConfig, self).__init__(**kwargs)
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings

...
transformers/configuration_gpt2.py  (view file @ 54abc67a)

...

@@ -17,20 +17,21 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import json
 import logging
 import sys
 from io import open

 from .configuration_utils import PretrainedConfig

 logger = logging.getLogger(__name__)

-GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
+GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
     "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
     "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
     "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json",
-    "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",}
+    "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",
+}


 class GPT2Config(PretrainedConfig):
     """Configuration class to store the configuration of a `GPT2Model`.

...

@@ -52,6 +53,7 @@ class GPT2Config(PretrainedConfig):
         initializer_range: The sttdev of the truncated_normal_initializer for
                            initializing all weight matrices.
     """
+
     pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(

...

@@ -67,7 +69,7 @@ class GPT2Config(PretrainedConfig):
         attn_pdrop=0.1,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
-        summary_type='cls_index',
+        summary_type="cls_index",
         summary_use_proj=True,
         summary_activation=None,
         summary_proj_to_labels=True,

...
transformers/configuration_mmbt.py  (view file @ 54abc67a)

...

@@ -15,11 +15,11 @@
 # limitations under the License.
 """ MMBT configuration """

-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
+from __future__ import absolute_import, division, print_function, unicode_literals

 import logging

 logger = logging.getLogger(__name__)

...

@@ -31,6 +31,7 @@ class MMBTConfig(object):
         num_labels: Size of final Linear layer for classification.
         modal_hidden_size: Embedding dimension of the non-text modality encoder.
     """
+
     def __init__(self, config, num_labels=None, modal_hidden_size=2048):
         self.__dict__ = config.__dict__
         self.modal_hidden_size = modal_hidden_size

...
transformers/configuration_openai.py  (view file @ 54abc67a)

...

@@ -17,19 +17,18 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import json
 import logging
 import sys
 from io import open

 from .configuration_utils import PretrainedConfig

 logger = logging.getLogger(__name__)

 OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}


 class OpenAIGPTConfig(PretrainedConfig):
     """
     Configuration class to store the configuration of a `OpenAIGPTModel`.

...

@@ -54,6 +53,7 @@ class OpenAIGPTConfig(PretrainedConfig):
                            initializing all weight matrices.
         predict_special_tokens: should we predict special tokens (when the model has a LM head)
     """
+
     pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(

...

@@ -71,7 +71,7 @@ class OpenAIGPTConfig(PretrainedConfig):
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         predict_special_tokens=True,
-        summary_type='cls_index',
+        summary_type="cls_index",
         summary_use_proj=True,
         summary_activation=None,
         summary_proj_to_labels=True,

...
transformers/configuration_roberta.py  (view file @ 54abc67a)

...

@@ -15,22 +15,22 @@
 # limitations under the License.
 """ RoBERTa configuration """

-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
+from __future__ import absolute_import, division, print_function, unicode_literals

 import logging

 from .configuration_bert import BertConfig

 logger = logging.getLogger(__name__)

 ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
-    'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
-    'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
-    'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
-    'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
-    'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
+    "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
+    "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
+    "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
+    "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
+    "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
+    "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
 }

...
transformers/configuration_t5.py  (view file @ 54abc67a)

...

@@ -16,22 +16,19 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import json
 import logging
 import sys
 import six
 from io import open

 from .configuration_utils import PretrainedConfig

 logger = logging.getLogger(__name__)

 T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
-    't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
-    't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
-    't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
-    't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
+    "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
+    "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
+    "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
+    "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
+    "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
 }

...

@@ -65,7 +62,8 @@ class T5Config(PretrainedConfig):
     """
     pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP

-    def __init__(self,
+    def __init__(
+        self,
         vocab_size=32128,
         n_positions=512,
         d_model=512,

...

@@ -77,7 +75,8 @@ class T5Config(PretrainedConfig):
         dropout_rate=0.1,
         layer_norm_epsilon=1e-6,
         initializer_factor=1.0,
-        **kwargs):
+        **kwargs
+    ):
         super(T5Config, self).__init__(**kwargs)
         self.vocab_size = vocab_size
         self.n_positions = n_positions

...
transformers/configuration_transfo_xl.py  (view file @ 54abc67a)

...

@@ -17,19 +17,18 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import json
 import logging
 import sys
 from io import open

 from .configuration_utils import PretrainedConfig

 logger = logging.getLogger(__name__)

 TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
+    "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
 }


 class TransfoXLConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `TransfoXLModel`.

...

@@ -65,9 +64,11 @@ class TransfoXLConfig(PretrainedConfig):
         proj_init_std: parameters initialized by N(0, init_std)
         init_std: parameters initialized by N(0, init_std)
     """
+
     pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP

-    def __init__(self,
+    def __init__(
+        self,
         vocab_size=267735,
         cutoffs=[20000, 40000, 200000],
         d_model=1024,

...

@@ -96,7 +97,8 @@ class TransfoXLConfig(PretrainedConfig):
         proj_init_std=0.01,
         init_std=0.02,
         layer_norm_epsilon=1e-5,
-        **kwargs):
+        **kwargs
+    ):
         """Constructs TransfoXLConfig.
         """
         super(TransfoXLConfig, self).__init__(**kwargs)

...
transformers/configuration_utils.py
View file @
54abc67a
...
...
@@ -15,8 +15,7 @@
# limitations under the License.
""" Configuration base class and utilities."""
from
__future__
import
(
absolute_import
,
division
,
print_function
,
unicode_literals
)
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
copy
import
json
...
...
@@ -24,10 +23,12 @@ import logging
import
os
from
io
import
open
from
.file_utils
import
CONFIG_NAME
,
cached_path
,
is_remote_url
,
hf_bucket_url
from
.file_utils
import
CONFIG_NAME
,
cached_path
,
hf_bucket_url
,
is_remote_url
logger
=
logging
.
getLogger
(
__name__
)
class
PretrainedConfig
(
object
):
r
""" Base class for all configuration classes.
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
...
...
@@ -50,36 +51,36 @@ class PretrainedConfig(object):
    def __init__(self, **kwargs):
        # Attributes with defaults
        self.output_attentions = kwargs.pop('output_attentions', False)
        self.output_hidden_states = kwargs.pop('output_hidden_states', False)
        self.output_past = kwargs.pop('output_past', True)  # Not used by all models
        self.torchscript = kwargs.pop('torchscript', False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
        self.pruned_heads = kwargs.pop('pruned_heads', {})
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_past = kwargs.pop("output_past", True)  # Not used by all models
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})

        # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
        self.is_decoder = kwargs.pop('is_decoder', False)
        self.is_decoder = kwargs.pop("is_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop('max_length', 20)
        self.do_sample = kwargs.pop('do_sample', False)
        self.num_beams = kwargs.pop('num_beams', 1)
        self.temperature = kwargs.pop('temperature', 1.0)
        self.top_k = kwargs.pop('top_k', 50)
        self.top_p = kwargs.pop('top_p', 1.0)
        self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
        self.bos_token_id = kwargs.pop('bos_token_id', 0)
        self.pad_token_id = kwargs.pop('pad_token_id', 0)
        self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
        self.length_penalty = kwargs.pop('length_penalty', 1.)
        self.num_return_sequences = kwargs.pop('num_return_sequences', 1)
        self.max_length = kwargs.pop("max_length", 20)
        self.do_sample = kwargs.pop("do_sample", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.bos_token_id = kwargs.pop("bos_token_id", 0)
        self.pad_token_id = kwargs.pop("pad_token_id", 0)
        self.eos_token_ids = kwargs.pop("eos_token_ids", 0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

        # Fine-tuning task arguments
        self.finetuning_task = kwargs.pop('finetuning_task', None)
        self.num_labels = kwargs.pop('num_labels', 2)
        self.id2label = kwargs.pop('id2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)})
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.num_labels = kwargs.pop("num_labels", 2)
        self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
        self.id2label = dict((int(key), value) for key, value in self.id2label.items())
        self.label2id = kwargs.pop('label2id', dict(zip(self.id2label.values(), self.id2label.keys())))
        self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
        self.label2id = dict((key, int(value)) for key, value in self.label2id.items())

        # Additional attributes without default values
...
...
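The kwargs.pop pattern above means every config subclass silently gains these defaults. A small sketch of the observable effect (BertConfig is chosen arbitrarily as a concrete subclass):

from transformers import BertConfig

config = BertConfig()
# Generation defaults set by PretrainedConfig.__init__ via kwargs.pop(...)
print(config.max_length, config.do_sample, config.num_beams)   # 20 False 1
# num_labels=2 produces the auto-generated label maps shown above
print(config.id2label)   # {0: 'LABEL_0', 1: 'LABEL_1'}
print(config.label2id)   # {'LABEL_0': 0, 'LABEL_1': 1}
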
@@ -94,7 +95,9 @@ class PretrainedConfig(object):
""" Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
"""
        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)
...
...
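Note that save_pretrained asserts the target is an existing directory rather than creating it. A sketch of the expected round trip (paths are illustrative):

import os
from transformers import BertConfig

config = BertConfig(num_labels=3)
os.makedirs("./my-model", exist_ok=True)   # the assert above requires an existing directory
config.save_pretrained("./my-model")       # writes ./my-model/config.json (CONFIG_NAME)
reloaded = BertConfig.from_pretrained("./my-model")
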
@@ -153,11 +156,11 @@ class PretrainedConfig(object):
assert unused_kwargs == {'foo': False}
"""
        cache_dir = kwargs.pop('cache_dir', None)
        force_download = kwargs.pop('force_download', False)
        resume_download = kwargs.pop('resume_download', False)
        proxies = kwargs.pop('proxies', None)
        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
...
...
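The popped kwargs above control caching and downloading; return_unused_kwargs changes the return value into a (config, unused) pair, as the docstring's assert suggests. A sketch:

from transformers import BertConfig

config, unused = BertConfig.from_pretrained(
    "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
)
assert config.output_attentions is True
assert unused == {"foo": False}   # mirrors the docstring example above
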
@@ -170,37 +173,48 @@ class PretrainedConfig(object):
        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
                                               proxies=proxies, resume_download=resume_download)
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
            )
            # Load config
            config = cls.from_json_file(resolved_config_file)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                    config_file)
                    config_file
                )
            else:
                msg = "Model name '{}' was not found in model name list ({}). " \
                      "We assumed '{}' was a path or url to a configuration file named {} or " \
                msg = (
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url to a configuration file named {} or "
                    "a directory containing such a file but couldn't find any such file at this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_config_archive_map.keys()),
                        config_file, CONFIG_NAME)
                        ", ".join(cls.pretrained_config_archive_map.keys()),
                        config_file,
                        CONFIG_NAME,
                    )
                )
            raise EnvironmentError(msg)
        except json.JSONDecodeError:
            msg = "Couldn't reach server at '{}' to download configuration file or " \
                  "configuration file is not a valid JSON file. " \
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
            logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))

        if hasattr(config, 'pruned_heads'):
        if hasattr(config, "pruned_heads"):
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
...
...
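The cached_path call above is where cache_dir, force_download, proxies and resume_download take effect. A hedged sketch (the cache path and proxy URL are placeholders):

from transformers import BertConfig

config = BertConfig.from_pretrained(
    "bert-base-uncased",
    cache_dir="/tmp/hf-cache",                        # placeholder cache location
    resume_download=True,                             # resume a partially downloaded file
    proxies={"https": "http://proxy.example:3128"},   # placeholder proxy
)
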
@@ -226,7 +240,7 @@ class PretrainedConfig(object):
    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `Config` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        dict_obj = json.loads(text)
        return cls(**dict_obj)
...
...
@@ -248,5 +262,5 @@ class PretrainedConfig(object):
    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
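Since from_json_file and to_json_file are symmetric, a config should survive a JSON round trip; a quick sketch (the file name is arbitrary):

from transformers import BertConfig

config = BertConfig(num_labels=3)
config.to_json_file("config.json")
restored = BertConfig.from_json_file("config.json")
assert config.to_json_string() == restored.to_json_string()
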
transformers/configuration_xlm.py
View file @ 54abc67a
...
...
@@ -15,26 +15,24 @@
""" XLM configuration """
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import sys
from io import open

from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
    'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
    'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
    'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
    'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
    'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
    'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
    'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
    'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
    'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
    "xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
    "xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
    "xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
    "xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
    "xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
    "xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
    "xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
    "xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
    "xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
    "xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
}
...
...
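Each key in XLM_PRETRAINED_CONFIG_ARCHIVE_MAP is a shortcut name that from_pretrained resolves to the S3 URL above before downloading, roughly:

from transformers import XLMConfig

# "xlm-mlm-en-2048" is looked up in cls.pretrained_config_archive_map,
# then the resolved config.json is downloaded (or served from cache)
config = XLMConfig.from_pretrained("xlm-mlm-en-2048")
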
@@ -78,9 +76,11 @@ class XLMConfig(PretrainedConfig):
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
"""
    pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
    def __init__(
        self,
        vocab_size=30145,
        emb_dim=2048,
        n_layers=12,
...
...
@@ -103,7 +103,7 @@ class XLMConfig(PretrainedConfig):
        unk_index=3,
        mask_index=5,
        is_encoder=True,
        summary_type='first',
        summary_type="first",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
...
...
@@ -112,7 +112,8 @@ class XLMConfig(PretrainedConfig):
        end_n_top=5,
        mask_token_id=0,
        lang_id=0,
        **kwargs):
        **kwargs
    ):
        """Constructs XLMConfig.
        """
        super(XLMConfig, self).__init__(**kwargs)
...
...
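As with the other config classes in this commit, keyword defaults such as summary_type='first' can be overridden at construction time; a brief sketch with illustrative values:

from transformers import XLMConfig

config = XLMConfig(n_layers=6, summary_type="mean")   # override two of the defaults above
assert config.is_encoder is True                      # untouched defaults remain
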
transformers/configuration_xlm_roberta.py
View file @ 54abc67a
...
...
@@ -15,22 +15,22 @@
# limitations under the License.
""" XLM-RoBERTa configuration """
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .configuration_roberta import RobertaConfig

logger = logging.getLogger(__name__)

XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'xlm-roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
    'xlm-roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
    'xlm-roberta-large-finetuned-conll02-dutch': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
    'xlm-roberta-large-finetuned-conll02-spanish': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
    'xlm-roberta-large-finetuned-conll03-english': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
    'xlm-roberta-large-finetuned-conll03-german': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
    "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
    "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
    "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
    "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
    "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
    "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
}
...
...
transformers/configuration_xlnet.py
View file @ 54abc67a
...
...
@@ -16,18 +16,16 @@
""" XLNet configuration """
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import sys
from io import open

from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
    'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
    "xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
    "xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}
...
...
@@ -69,9 +67,11 @@ class XLNetConfig(PretrainedConfig):
same_length: bool, whether to use the same attention length for each token.
finetuning_task: name of the glue task on which the model was fine-tuned if any
"""
    pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
    def __init__(
        self,
        vocab_size=32000,
        d_model=1024,
        n_layer=24,
...
...
@@ -88,13 +88,14 @@ class XLNetConfig(PretrainedConfig):
        bi_data=False,
        clamp_len=-1,
        same_length=False,
        summary_type='last',
        summary_type="last",
        summary_use_proj=True,
        summary_activation='tanh',
        summary_activation="tanh",
        summary_last_dropout=0.1,
        start_n_top=5,
        end_n_top=5,
        **kwargs):
        **kwargs
    ):
        """Constructs XLNetConfig.
        """
        super(XLNetConfig, self).__init__(**kwargs)
...
...
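A short sketch of the XLNet constructor in use; note that extra kwargs are consumed by the super().__init__ call shown above (values here are illustrative):

from transformers import XLNetConfig

# output_hidden_states is not an XLNetConfig parameter; it is popped
# by PretrainedConfig.__init__ via **kwargs
config = XLNetConfig(d_model=512, output_hidden_states=True)
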
transformers/convert_albert_original_tf_checkpoint_to_pytorch.py
View file @ 54abc67a
...
...
@@ -14,16 +14,16 @@
# limitations under the License.
"""Convert ALBERT checkpoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function

import argparse
import logging

import torch

from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert

import logging

logging.basicConfig(level=logging.INFO)
...
...
@@ -43,25 +43,20 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--tf_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the TensorFlow checkpoint path.")
    parser.add_argument("--albert_config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The config json file corresponding to the pre-trained ALBERT model. \n"
                             "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--albert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained ALBERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
\ No newline at end of file
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
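The script is meant to be run from the command line, but the underlying function can also be called directly from Python. A sketch with placeholder paths (TensorFlow must be installed so the checkpoint can be read):

from transformers.convert_albert_original_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)

convert_tf_checkpoint_to_pytorch(
    "/path/to/albert/model.ckpt-best",      # --tf_checkpoint_path (placeholder)
    "/path/to/albert/albert_config.json",   # --albert_config_file (placeholder)
    "/path/to/output/pytorch_model.bin",    # --pytorch_dump_path (placeholder)
)
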