Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ce50305e
Unverified
Commit
ce50305e
authored
Dec 22, 2019
by
Aymeric Augustin
Committed by
GitHub
Dec 22, 2019
Browse files
Merge pull request #2270 from aaugustin/remove-python-2
Remove support for Python 2
parents
b6ea0f43
1a948d70
Changes
155
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
36 additions
and
138 deletions
+36
-138
src/transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py
...convert_roberta_original_pytorch_checkpoint_to_pytorch.py
+0
-1
src/transformers/convert_t5_original_tf_checkpoint_to_pytorch.py
...nsformers/convert_t5_original_tf_checkpoint_to_pytorch.py
+0
-1
src/transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
...s/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
+1
-8
src/transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py
...ers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py
+0
-2
src/transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py
...ormers/convert_xlnet_original_tf_checkpoint_to_pytorch.py
+0
-1
src/transformers/data/metrics/squad_metrics.py
src/transformers/data/metrics/squad_metrics.py
+0
-1
src/transformers/data/processors/utils.py
src/transformers/data/processors/utils.py
+1
-8
src/transformers/data/processors/xnli.py
src/transformers/data/processors/xnli.py
+0
-1
src/transformers/file_utils.py
src/transformers/file_utils.py
+18
-46
src/transformers/hf_api.py
src/transformers/hf_api.py
+13
-50
src/transformers/modelcard.py
src/transformers/modelcard.py
+0
-2
src/transformers/modeling_auto.py
src/transformers/modeling_auto.py
+0
-1
src/transformers/modeling_bert.py
src/transformers/modeling_bert.py
+2
-8
src/transformers/modeling_camembert.py
src/transformers/modeling_camembert.py
+0
-1
src/transformers/modeling_ctrl.py
src/transformers/modeling_ctrl.py
+0
-1
src/transformers/modeling_distilbert.py
src/transformers/modeling_distilbert.py
+1
-1
src/transformers/modeling_encoder_decoder.py
src/transformers/modeling_encoder_decoder.py
+0
-1
src/transformers/modeling_gpt2.py
src/transformers/modeling_gpt2.py
+0
-1
src/transformers/modeling_mmbt.py
src/transformers/modeling_mmbt.py
+0
-1
src/transformers/modeling_openai.py
src/transformers/modeling_openai.py
+0
-2
No files found.
src/transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py
View file @
ce50305e
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
# limitations under the License.
# limitations under the License.
"""Convert RoBERTa checkpoint."""
"""Convert RoBERTa checkpoint."""
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
import
argparse
import
logging
import
logging
...
...
src/transformers/convert_t5_original_tf_checkpoint_to_pytorch.py
View file @
ce50305e
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
# limitations under the License.
# limitations under the License.
"""Convert T5 checkpoint."""
"""Convert T5 checkpoint."""
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
import
argparse
import
logging
import
logging
...
...
src/transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
View file @
ce50305e
...
@@ -14,13 +14,12 @@
...
@@ -14,13 +14,12 @@
# limitations under the License.
# limitations under the License.
"""Convert Transformer XL checkpoint and datasets."""
"""Convert Transformer XL checkpoint and datasets."""
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
import
argparse
import
logging
import
logging
import
os
import
os
import
pickle
import
sys
import
sys
from
io
import
open
import
torch
import
torch
...
@@ -35,12 +34,6 @@ from transformers import (
...
@@ -35,12 +34,6 @@ from transformers import (
from
transformers.tokenization_transfo_xl
import
CORPUS_NAME
,
VOCAB_FILES_NAMES
from
transformers.tokenization_transfo_xl
import
CORPUS_NAME
,
VOCAB_FILES_NAMES
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
else
:
import
pickle
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
# We do this to be able to load python 2 datasets pickles
# We do this to be able to load python 2 datasets pickles
...
...
src/transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py
View file @
ce50305e
...
@@ -14,12 +14,10 @@
...
@@ -14,12 +14,10 @@
# limitations under the License.
# limitations under the License.
"""Convert OpenAI GPT checkpoint."""
"""Convert OpenAI GPT checkpoint."""
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
import
argparse
import
json
import
json
import
logging
import
logging
from
io
import
open
import
numpy
import
numpy
import
torch
import
torch
...
...
src/transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py
View file @
ce50305e
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
# limitations under the License.
# limitations under the License.
"""Convert BERT checkpoint."""
"""Convert BERT checkpoint."""
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
import
argparse
import
logging
import
logging
...
...
src/transformers/data/metrics/squad_metrics.py
View file @
ce50305e
...
@@ -14,7 +14,6 @@ import logging
...
@@ -14,7 +14,6 @@ import logging
import
math
import
math
import
re
import
re
import
string
import
string
from
io
import
open
from
transformers.tokenization_bert
import
BasicTokenizer
from
transformers.tokenization_bert
import
BasicTokenizer
...
...
src/transformers/data/processors/utils.py
View file @
ce50305e
...
@@ -18,7 +18,6 @@ import copy
...
@@ -18,7 +18,6 @@ import copy
import
csv
import
csv
import
json
import
json
import
logging
import
logging
import
sys
from
...file_utils
import
is_tf_available
,
is_torch_available
from
...file_utils
import
is_tf_available
,
is_torch_available
...
@@ -98,13 +97,7 @@ class DataProcessor(object):
...
@@ -98,13 +97,7 @@ class DataProcessor(object):
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
"""Reads a tab separated value file."""
with
open
(
input_file
,
"r"
,
encoding
=
"utf-8-sig"
)
as
f
:
with
open
(
input_file
,
"r"
,
encoding
=
"utf-8-sig"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
)
return
list
(
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
))
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
"utf-8"
)
for
cell
in
line
)
# noqa: F821
lines
.
append
(
line
)
return
lines
class
SingleSentenceClassificationProcessor
(
DataProcessor
):
class
SingleSentenceClassificationProcessor
(
DataProcessor
):
...
...
src/transformers/data/processors/xnli.py
View file @
ce50305e
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
# limitations under the License.
# limitations under the License.
""" XNLI utils (dataset loading and evaluation) """
""" XNLI utils (dataset loading and evaluation) """
from
__future__
import
absolute_import
,
division
,
print_function
import
logging
import
logging
import
os
import
os
...
...
src/transformers/file_utils.py
View file @
ce50305e
...
@@ -3,7 +3,7 @@ Utilities for working with the local dataset cache.
...
@@ -3,7 +3,7 @@ Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
Copyright by the AllenNLP authors.
"""
"""
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
fnmatch
import
fnmatch
import
json
import
json
...
@@ -14,11 +14,10 @@ import tempfile
...
@@ -14,11 +14,10 @@ import tempfile
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
functools
import
partial
,
wraps
from
functools
import
partial
,
wraps
from
hashlib
import
sha256
from
hashlib
import
sha256
from
io
import
open
from
urllib.parse
import
urlparse
import
boto3
import
boto3
import
requests
import
requests
import
six
from
botocore.config
import
Config
from
botocore.config
import
Config
from
botocore.exceptions
import
ClientError
from
botocore.exceptions
import
ClientError
from
filelock
import
FileLock
from
filelock
import
FileLock
...
@@ -66,10 +65,6 @@ except ImportError:
...
@@ -66,10 +65,6 @@ except ImportError:
)
)
default_cache_path
=
os
.
path
.
join
(
torch_cache_home
,
"transformers"
)
default_cache_path
=
os
.
path
.
join
(
torch_cache_home
,
"transformers"
)
try
:
from
urllib.parse
import
urlparse
except
ImportError
:
from
urlparse
import
urlparse
try
:
try
:
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -107,36 +102,20 @@ def is_tf_available():
...
@@ -107,36 +102,20 @@ def is_tf_available():
return
_tf_available
return
_tf_available
if
not
six
.
PY2
:
def
add_start_docstrings
(
*
docstr
):
def
docstring_decorator
(
fn
):
def
add_start_docstrings
(
*
docstr
):
fn
.
__doc__
=
""
.
join
(
docstr
)
+
fn
.
__doc__
def
docstring_decorator
(
fn
):
return
fn
fn
.
__doc__
=
""
.
join
(
docstr
)
+
fn
.
__doc__
return
fn
return
docstring_decorator
def
add_end_docstrings
(
*
docstr
):
return
docstring_decorator
def
docstring_decorator
(
fn
):
fn
.
__doc__
=
fn
.
__doc__
+
""
.
join
(
docstr
)
return
fn
return
docstring_decorator
def
add_end_docstrings
(
*
docstr
):
def
docstring_decorator
(
fn
):
fn
.
__doc__
=
fn
.
__doc__
+
""
.
join
(
docstr
)
return
fn
else
:
return
docstring_decorator
# Not possible to update class docstrings on python2
def
add_start_docstrings
(
*
docstr
):
def
docstring_decorator
(
fn
):
return
fn
return
docstring_decorator
def
add_end_docstrings
(
*
docstr
):
def
docstring_decorator
(
fn
):
return
fn
return
docstring_decorator
def
is_remote_url
(
url_or_filename
):
def
is_remote_url
(
url_or_filename
):
...
@@ -183,7 +162,7 @@ def filename_to_url(filename, cache_dir=None):
...
@@ -183,7 +162,7 @@ def filename_to_url(filename, cache_dir=None):
"""
"""
if
cache_dir
is
None
:
if
cache_dir
is
None
:
cache_dir
=
TRANSFORMERS_CACHE
cache_dir
=
TRANSFORMERS_CACHE
if
sys
.
version_info
[
0
]
==
3
and
isinstance
(
cache_dir
,
Path
):
if
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
cache_dir
=
str
(
cache_dir
)
cache_path
=
os
.
path
.
join
(
cache_dir
,
filename
)
cache_path
=
os
.
path
.
join
(
cache_dir
,
filename
)
...
@@ -218,9 +197,9 @@ def cached_path(
...
@@ -218,9 +197,9 @@ def cached_path(
"""
"""
if
cache_dir
is
None
:
if
cache_dir
is
None
:
cache_dir
=
TRANSFORMERS_CACHE
cache_dir
=
TRANSFORMERS_CACHE
if
sys
.
version_info
[
0
]
==
3
and
isinstance
(
url_or_filename
,
Path
):
if
isinstance
(
url_or_filename
,
Path
):
url_or_filename
=
str
(
url_or_filename
)
url_or_filename
=
str
(
url_or_filename
)
if
sys
.
version_info
[
0
]
==
3
and
isinstance
(
cache_dir
,
Path
):
if
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
cache_dir
=
str
(
cache_dir
)
if
is_remote_url
(
url_or_filename
):
if
is_remote_url
(
url_or_filename
):
...
@@ -297,7 +276,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
...
@@ -297,7 +276,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
ua
=
"transformers/{}; python/{}"
.
format
(
__version__
,
sys
.
version
.
split
()[
0
])
ua
=
"transformers/{}; python/{}"
.
format
(
__version__
,
sys
.
version
.
split
()[
0
])
if
isinstance
(
user_agent
,
dict
):
if
isinstance
(
user_agent
,
dict
):
ua
+=
"; "
+
"; "
.
join
(
"{}/{}"
.
format
(
k
,
v
)
for
k
,
v
in
user_agent
.
items
())
ua
+=
"; "
+
"; "
.
join
(
"{}/{}"
.
format
(
k
,
v
)
for
k
,
v
in
user_agent
.
items
())
elif
isinstance
(
user_agent
,
s
ix
.
string_types
):
elif
isinstance
(
user_agent
,
s
tr
):
ua
+=
"; "
+
user_agent
ua
+=
"; "
+
user_agent
headers
=
{
"user-agent"
:
ua
}
headers
=
{
"user-agent"
:
ua
}
if
resume_size
>
0
:
if
resume_size
>
0
:
...
@@ -331,9 +310,7 @@ def get_from_cache(
...
@@ -331,9 +310,7 @@ def get_from_cache(
"""
"""
if
cache_dir
is
None
:
if
cache_dir
is
None
:
cache_dir
=
TRANSFORMERS_CACHE
cache_dir
=
TRANSFORMERS_CACHE
if
sys
.
version_info
[
0
]
==
3
and
isinstance
(
cache_dir
,
Path
):
if
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
if
sys
.
version_info
[
0
]
==
2
and
not
isinstance
(
cache_dir
,
str
):
cache_dir
=
str
(
cache_dir
)
cache_dir
=
str
(
cache_dir
)
if
not
os
.
path
.
exists
(
cache_dir
):
if
not
os
.
path
.
exists
(
cache_dir
):
...
@@ -352,8 +329,6 @@ def get_from_cache(
...
@@ -352,8 +329,6 @@ def get_from_cache(
except
(
EnvironmentError
,
requests
.
exceptions
.
Timeout
):
except
(
EnvironmentError
,
requests
.
exceptions
.
Timeout
):
etag
=
None
etag
=
None
if
sys
.
version_info
[
0
]
==
2
and
etag
is
not
None
:
etag
=
etag
.
decode
(
"utf-8"
)
filename
=
url_to_filename
(
url
,
etag
)
filename
=
url_to_filename
(
url
,
etag
)
# get cache path to put the file
# get cache path to put the file
...
@@ -417,9 +392,6 @@ def get_from_cache(
...
@@ -417,9 +392,6 @@ def get_from_cache(
meta
=
{
"url"
:
url
,
"etag"
:
etag
}
meta
=
{
"url"
:
url
,
"etag"
:
etag
}
meta_path
=
cache_path
+
".json"
meta_path
=
cache_path
+
".json"
with
open
(
meta_path
,
"w"
)
as
meta_file
:
with
open
(
meta_path
,
"w"
)
as
meta_file
:
output_string
=
json
.
dumps
(
meta
)
json
.
dump
(
meta
,
meta_file
)
if
sys
.
version_info
[
0
]
==
2
and
isinstance
(
output_string
,
str
):
output_string
=
unicode
(
output_string
,
"utf-8"
)
# noqa: F821
meta_file
.
write
(
output_string
)
return
cache_path
return
cache_path
src/transformers/hf_api.py
View file @
ce50305e
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
import
io
import
io
import
os
import
os
...
@@ -20,7 +20,6 @@ from os.path import expanduser
...
@@ -20,7 +20,6 @@ from os.path import expanduser
from
typing
import
List
from
typing
import
List
import
requests
import
requests
import
six
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -28,14 +27,7 @@ ENDPOINT = "https://huggingface.co"
...
@@ -28,14 +27,7 @@ ENDPOINT = "https://huggingface.co"
class
S3Obj
:
class
S3Obj
:
def
__init__
(
def
__init__
(
self
,
filename
:
str
,
LastModified
:
str
,
ETag
:
str
,
Size
:
int
,
**
kwargs
):
self
,
filename
,
# type: str
LastModified
,
# type: str
ETag
,
# type: str
Size
,
# type: int
**
kwargs
):
self
.
filename
=
filename
self
.
filename
=
filename
self
.
LastModified
=
LastModified
self
.
LastModified
=
LastModified
self
.
ETag
=
ETag
self
.
ETag
=
ETag
...
@@ -43,13 +35,7 @@ class S3Obj:
...
@@ -43,13 +35,7 @@ class S3Obj:
class
PresignedUrl
:
class
PresignedUrl
:
def
__init__
(
def
__init__
(
self
,
write
:
str
,
access
:
str
,
type
:
str
,
**
kwargs
):
self
,
write
,
# type: str
access
,
# type: str
type
,
# type: str
**
kwargs
):
self
.
write
=
write
self
.
write
=
write
self
.
access
=
access
self
.
access
=
access
self
.
type
=
type
# mime-type to send to S3.
self
.
type
=
type
# mime-type to send to S3.
...
@@ -59,12 +45,7 @@ class HfApi:
...
@@ -59,12 +45,7 @@ class HfApi:
def
__init__
(
self
,
endpoint
=
None
):
def
__init__
(
self
,
endpoint
=
None
):
self
.
endpoint
=
endpoint
if
endpoint
is
not
None
else
ENDPOINT
self
.
endpoint
=
endpoint
if
endpoint
is
not
None
else
ENDPOINT
def
login
(
def
login
(
self
,
username
:
str
,
password
:
str
)
->
str
:
self
,
username
,
# type: str
password
,
# type: str
):
# type: (...) -> str
"""
"""
Call HF API to sign in a user and get a token if credentials are valid.
Call HF API to sign in a user and get a token if credentials are valid.
...
@@ -80,10 +61,7 @@ class HfApi:
...
@@ -80,10 +61,7 @@ class HfApi:
d
=
r
.
json
()
d
=
r
.
json
()
return
d
[
"token"
]
return
d
[
"token"
]
def
whoami
(
def
whoami
(
self
,
token
:
str
)
->
str
:
self
,
token
,
# type: str
):
# type: (...) -> str
"""
"""
Call HF API to know "whoami"
Call HF API to know "whoami"
"""
"""
...
@@ -93,8 +71,7 @@ class HfApi:
...
@@ -93,8 +71,7 @@ class HfApi:
d
=
r
.
json
()
d
=
r
.
json
()
return
d
[
"user"
]
return
d
[
"user"
]
def
logout
(
self
,
token
):
def
logout
(
self
,
token
:
str
)
->
None
:
# type: (...) -> None
"""
"""
Call HF API to log out.
Call HF API to log out.
"""
"""
...
@@ -102,19 +79,17 @@ class HfApi:
...
@@ -102,19 +79,17 @@ class HfApi:
r
=
requests
.
post
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)})
r
=
requests
.
post
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)})
r
.
raise_for_status
()
r
.
raise_for_status
()
def
presign
(
self
,
token
,
filename
):
def
presign
(
self
,
token
:
str
,
filename
)
->
PresignedUrl
:
# type: (...) -> PresignedUrl
"""
"""
Call HF API to get a presigned url to upload `filename` to S3.
Call HF API to get a presigned url to upload `filename` to S3.
"""
"""
path
=
"{}/api/presign"
.
format
(
self
.
endpoint
)
path
=
"{}/api/presign"
.
format
(
self
.
endpoint
)
r
=
requests
.
post
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
}
,
)
r
=
requests
.
post
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
})
r
.
raise_for_status
()
r
.
raise_for_status
()
d
=
r
.
json
()
d
=
r
.
json
()
return
PresignedUrl
(
**
d
)
return
PresignedUrl
(
**
d
)
def
presign_and_upload
(
self
,
token
,
filename
,
filepath
):
def
presign_and_upload
(
self
,
token
:
str
,
filename
,
filepath
)
->
str
:
# type: (...) -> str
"""
"""
Get a presigned url, then upload file to S3.
Get a presigned url, then upload file to S3.
...
@@ -158,13 +133,10 @@ class TqdmProgressFileReader:
...
@@ -158,13 +133,10 @@ class TqdmProgressFileReader:
def
__init__
(
self
,
f
:
io
.
BufferedReader
):
def
__init__
(
self
,
f
:
io
.
BufferedReader
):
self
.
f
=
f
self
.
f
=
f
self
.
total_size
=
os
.
fstat
(
f
.
fileno
()).
st_size
# type: int
self
.
total_size
=
os
.
fstat
(
f
.
fileno
()).
st_size
self
.
pbar
=
tqdm
(
total
=
self
.
total_size
,
leave
=
False
)
self
.
pbar
=
tqdm
(
total
=
self
.
total_size
,
leave
=
False
)
if
six
.
PY3
:
self
.
read
=
f
.
read
# does not work unless PY3
f
.
read
=
self
.
_read
# no big deal as the CLI does not currently support PY2 anyways.
self
.
read
=
f
.
read
f
.
read
=
self
.
_read
def
_read
(
self
,
n
=-
1
):
def
_read
(
self
,
n
=-
1
):
self
.
pbar
.
update
(
n
)
self
.
pbar
.
update
(
n
)
...
@@ -182,16 +154,7 @@ class HfFolder:
...
@@ -182,16 +154,7 @@ class HfFolder:
"""
"""
Save token, creating folder as needed.
Save token, creating folder as needed.
"""
"""
if
six
.
PY3
:
os
.
makedirs
(
os
.
path
.
dirname
(
cls
.
path_token
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
cls
.
path_token
),
exist_ok
=
True
)
else
:
# Python 2
try
:
os
.
makedirs
(
os
.
path
.
dirname
(
cls
.
path_token
))
except
OSError
as
e
:
if
e
.
errno
!=
os
.
errno
.
EEXIST
:
raise
e
pass
with
open
(
cls
.
path_token
,
"w+"
)
as
f
:
with
open
(
cls
.
path_token
,
"w+"
)
as
f
:
f
.
write
(
token
)
f
.
write
(
token
)
...
...
src/transformers/modelcard.py
View file @
ce50305e
...
@@ -14,13 +14,11 @@
...
@@ -14,13 +14,11 @@
# limitations under the License.
# limitations under the License.
""" Configuration base class and utilities."""
""" Configuration base class and utilities."""
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
copy
import
copy
import
json
import
json
import
logging
import
logging
import
os
import
os
from
io
import
open
from
.configuration_auto
import
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from
.configuration_auto
import
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from
.file_utils
import
(
from
.file_utils
import
(
...
...
src/transformers/modeling_auto.py
View file @
ce50305e
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
# limitations under the License.
# limitations under the License.
""" Auto Model class. """
""" Auto Model class. """
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
logging
import
logging
...
...
src/transformers/modeling_bert.py
View file @
ce50305e
...
@@ -15,12 +15,10 @@
...
@@ -15,12 +15,10 @@
# limitations under the License.
# limitations under the License.
"""PyTorch BERT model. """
"""PyTorch BERT model. """
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
logging
import
logging
import
math
import
math
import
os
import
os
import
sys
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -339,9 +337,7 @@ class BertIntermediate(nn.Module):
...
@@ -339,9 +337,7 @@ class BertIntermediate(nn.Module):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
BertIntermediate
,
self
).
__init__
()
super
(
BertIntermediate
,
self
).
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
if
isinstance
(
config
.
hidden_act
,
str
):
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)
# noqa: F821
):
self
.
intermediate_act_fn
=
ACT2FN
[
config
.
hidden_act
]
self
.
intermediate_act_fn
=
ACT2FN
[
config
.
hidden_act
]
else
:
else
:
self
.
intermediate_act_fn
=
config
.
hidden_act
self
.
intermediate_act_fn
=
config
.
hidden_act
...
@@ -461,9 +457,7 @@ class BertPredictionHeadTransform(nn.Module):
...
@@ -461,9 +457,7 @@ class BertPredictionHeadTransform(nn.Module):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
BertPredictionHeadTransform
,
self
).
__init__
()
super
(
BertPredictionHeadTransform
,
self
).
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
if
isinstance
(
config
.
hidden_act
,
str
):
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)
# noqa: F821
):
self
.
transform_act_fn
=
ACT2FN
[
config
.
hidden_act
]
self
.
transform_act_fn
=
ACT2FN
[
config
.
hidden_act
]
else
:
else
:
self
.
transform_act_fn
=
config
.
hidden_act
self
.
transform_act_fn
=
config
.
hidden_act
...
...
src/transformers/modeling_camembert.py
View file @
ce50305e
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
# limitations under the License.
# limitations under the License.
"""PyTorch CamemBERT model. """
"""PyTorch CamemBERT model. """
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
logging
import
logging
...
...
src/transformers/modeling_ctrl.py
View file @
ce50305e
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
# limitations under the License.
# limitations under the License.
""" PyTorch CTRL model."""
""" PyTorch CTRL model."""
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
logging
import
logging
...
...
src/transformers/modeling_distilbert.py
View file @
ce50305e
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
and in part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
and in part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
"""
"""
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
copy
import
copy
import
logging
import
logging
...
...
src/transformers/modeling_encoder_decoder.py
View file @
ce50305e
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
# limitations under the License.
# limitations under the License.
""" Classes to support Encoder-Decoder architectures """
""" Classes to support Encoder-Decoder architectures """
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
logging
import
logging
import
os
import
os
...
...
src/transformers/modeling_gpt2.py
View file @
ce50305e
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
# limitations under the License.
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""
"""PyTorch OpenAI GPT-2 model."""
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
logging
import
logging
import
math
import
math
...
...
src/transformers/modeling_mmbt.py
View file @
ce50305e
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
# limitations under the License.
# limitations under the License.
"""PyTorch MMBT model. """
"""PyTorch MMBT model. """
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
logging
import
logging
...
...
src/transformers/modeling_openai.py
View file @
ce50305e
...
@@ -15,13 +15,11 @@
...
@@ -15,13 +15,11 @@
# limitations under the License.
# limitations under the License.
"""PyTorch OpenAI GPT model."""
"""PyTorch OpenAI GPT model."""
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
json
import
json
import
logging
import
logging
import
math
import
math
import
os
import
os
from
io
import
open
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
...
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment