chenpangpang/transformers
Commit 269c9638, authored Apr 08, 2021 by Sylvain Gugger
Merge branch 'master' of github.com:huggingface/transformers
Parents: d31c7b10, c2e0fd52
Changes: 46
Showing 20 changed files with 192 additions and 55 deletions (+192 -55)
.circleci/config.yml (+2 -2)
.github/workflows/self-scheduled.yml (+2 -2)
docs/source/main_classes/trainer.rst (+18 -2)
docs/source/testing.rst (+20 -3)
examples/language-modeling/run_mlm.py (+6 -1)
setup.py (+26 -15)
src/transformers/data/data_collator.py (+9 -4)
src/transformers/dependency_versions_check.py (+5 -1)
src/transformers/dependency_versions_table.py (+8 -3)
src/transformers/integrations.py (+7 -5)
src/transformers/modeling_flax_utils.py (+1 -0)
src/transformers/modeling_tf_utils.py (+1 -0)
src/transformers/models/auto/__init__.py (+2 -0)
src/transformers/models/auto/auto_factory.py (+35 -4)
src/transformers/models/auto/configuration_auto.py (+14 -5)
src/transformers/models/auto/modeling_auto.py (+2 -1)
src/transformers/models/auto/modeling_tf_auto.py (+2 -1)
src/transformers/testing_utils.py (+22 -0)
src/transformers/trainer.py (+4 -6)
src/transformers/training_args.py (+6 -0)
.circleci/config.yml
@@ -348,7 +348,7 @@ jobs:
       - v0.4-{{ checksum "setup.py" }}
   - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
   - run: pip install --upgrade pip
-  - run: pip install ."[all, docs]"
+  - run: pip install ."[docs]"
   - save_cache:
       key: v0.4-build_doc-{{ checksum "setup.py" }}
       paths:
@@ -370,7 +370,7 @@ jobs:
       keys:
         - v0.4-deploy_doc-{{ checksum "setup.py" }}
         - v0.4-{{ checksum "setup.py" }}
-  - run: pip install ."[all, docs]"
+  - run: pip install ."[docs]"
   - save_cache:
       key: v0.4-deploy_doc-{{ checksum "setup.py" }}
       paths:
.github/workflows/self-scheduled.yml
@@ -33,7 +33,7 @@ jobs:
       run: |
         apt -y update && apt install -y libsndfile1-dev
         pip install --upgrade pip
-        pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
+        pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed]
     - name: Are GPUs recognized by our DL frameworks
       run: |
@@ -155,7 +155,7 @@ jobs:
       run: |
         apt -y update && apt install -y libsndfile1-dev
         pip install --upgrade pip
-        pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
+        pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed,fairscale]
     - name: Are GPUs recognized by our DL frameworks
       run: |
docs/source/main_classes/trainer.rst
@@ -274,6 +274,14 @@ Install the library via pypi:
     pip install fairscale

+or via ``transformers``' ``extras``:
+
+.. code-block:: bash
+
+    pip install transformers[fairscale]
+
+(will become available starting from ``transformers==4.6.0``)
+
 or find more details on `the FairScale's GitHub page <https://github.com/facebookresearch/fairscale/#installation>`__.

 If you're still struggling with the build, first make sure to read :ref:`zero-install-notes`.
@@ -419,6 +427,14 @@ Install the library via pypi:
     pip install deepspeed

+or via ``transformers``' ``extras``:
+
+.. code-block:: bash
+
+    pip install transformers[deepspeed]
+
+(will become available starting from ``transformers==4.6.0``)
+
 or find more details on `the DeepSpeed's GitHub page <https://github.com/microsoft/deepspeed#installation>`__ and
 `advanced install <https://www.deepspeed.ai/tutorials/advanced-install/>`__.
@@ -525,7 +541,7 @@ Here is an example of running ``run_translation.py`` under DeepSpeed deploying a
 .. code-block:: bash

     deepspeed examples/seq2seq/run_translation.py \
-    --deepspeed examples/tests/deepspeed/ds_config.json \
+    --deepspeed tests/deepspeed/ds_config.json \
     --model_name_or_path t5-small --per_device_train_batch_size 1 \
     --output_dir output_dir --overwrite_output_dir --fp16 \
     --do_train --max_train_samples 500 --num_train_epochs 1 \
@@ -550,7 +566,7 @@ To deploy DeepSpeed with one GPU adjust the :class:`~transformers.Trainer` comma
 .. code-block:: bash

     deepspeed --num_gpus=1 examples/seq2seq/run_translation.py \
-    --deepspeed examples/tests/deepspeed/ds_config.json \
+    --deepspeed tests/deepspeed/ds_config.json \
     --model_name_or_path t5-small --per_device_train_batch_size 1 \
     --output_dir output_dir --overwrite_output_dir --fp16 \
     --do_train --max_train_samples 500 --num_train_epochs 1 \
docs/source/testing.rst
 ..
     Copyright 2020 The HuggingFace Team. All rights reserved.

     Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
@@ -388,7 +388,7 @@ For a single or a group of tests via ``pytest`` (after ``pip install pytest-pspe
 .. code-block:: bash

     pytest --pspec tests/test_optimization.py
@@ -672,7 +672,7 @@ and it will list:
     test_this2.py::test_floor[integer-1-1.0]
     test_this2.py::test_floor[negative--1.5--2.0]
     test_this2.py::test_floor[large fraction-1.6-1]

 So now you can run just the specific test:
@@ -795,6 +795,23 @@ leave any data in there.
     otherwise.

+Temporary sys.path override
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you need to temporary override ``sys.path`` to import from another test for example, you can use the
+``ExtendSysPath`` context manager. Example:
+
+.. code-block:: python
+
+    import os
+    from transformers.testing_utils import ExtendSysPath
+
+    bindir = os.path.abspath(os.path.dirname(__file__))
+    with ExtendSysPath(f"{bindir}/.."):
+        from test_trainer import TrainerIntegrationCommon  # noqa
+
 Skipping tests
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
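The body of ``ExtendSysPath`` lives in src/transformers/testing_utils.py, which is not shown in this part of the diff. For readers who want a feel for the pattern, here is a minimal sketch of a context manager with the behaviour the docs describe. The name `extend_sys_path` and its exact semantics (prepend on entry, remove on exit) are assumptions for illustration, not the library's code.

```python
import contextlib
import os
import sys


@contextlib.contextmanager
def extend_sys_path(path):
    """Hypothetical stand-in for ExtendSysPath: prepend `path` to sys.path, then undo it."""
    path = os.fspath(path)
    sys.path.insert(0, path)
    try:
        yield
    finally:
        # Remove the entry we added, even if the body of the `with` block raised.
        sys.path.remove(path)


# Usage: the extra directory is importable only inside the block.
with extend_sys_path("/tmp/hypothetical_tests_dir"):
    inside = "/tmp/hypothetical_tests_dir" in sys.path
```

The `try`/`finally` is the important part: a failing import inside the block should not leave a stale entry behind for later tests.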
examples/language-modeling/run_mlm.py
@@ -422,7 +422,12 @@ def main():
     # Data collator
     # This one will take care of randomly masking the tokens.
-    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
+    pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm_probability=data_args.mlm_probability,
+        pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
+    )

     # Initialize our Trainer
     trainer = Trainer(
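To make the new condition easier to read, here is a small sketch that isolates it. The helper name `choose_pad_multiple` and the use of `SimpleNamespace` in place of the script's real argument dataclasses are assumptions made only for this illustration.

```python
from types import SimpleNamespace


def choose_pad_multiple(data_args, training_args):
    """Hypothetical helper: same boolean expression as the hunk above, returning 8 or None."""
    pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
    return 8 if pad_to_multiple_of_8 else None


# Stand-ins for the parsed command-line arguments.
line_by_line_fp16 = choose_pad_multiple(
    SimpleNamespace(line_by_line=True, pad_to_max_length=False),
    SimpleNamespace(fp16=True),
)
already_max_length = choose_pad_multiple(
    SimpleNamespace(line_by_line=True, pad_to_max_length=True),
    SimpleNamespace(fp16=True),
)
```

Padding to a multiple of 8 is requested only when all three conditions hold: line-by-line data, fp16 training, and no padding to the maximum length.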
setup.py
@@ -19,7 +19,7 @@ To create the package for pypi.
 1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
    documentation.

 2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid.

 3. Unpin specific versions from setup.py that use a git install.
@@ -85,11 +85,14 @@ if stale_egg_info.exists():
 # 1. all dependencies should be listed here with their version requirements if any
 # 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
 _deps = [
+    "Pillow",
     "black>=20.8b1",
     "cookiecutter==1.7.2",
     "dataclasses",
     "datasets",
+    "deepspeed>0.3.13",
     "docutils==0.16.0",
+    "fairscale>0.3",
     "faiss-cpu",
     "fastapi",
     "filelock",
@@ -102,13 +105,13 @@ _deps = [
     "jax>=0.2.8",
     "jaxlib>=0.1.59",
     "keras2onnx",
+    "nltk",
     "numpy>=1.17",
     "onnxconverter-common",
     "onnxruntime-tools>=1.4.2",
     "onnxruntime>=1.4.0",
     "packaging",
     "parameterized",
-    "Pillow",
     "protobuf",
     "psutil",
     "pydantic",
@@ -119,15 +122,18 @@ _deps = [
     "recommonmark",
     "regex!=2019.12.17",
     "requests",
+    "rouge-score",
+    "sacrebleu>=1.4.12",
     "sacremoses",
+    "sagemaker>=2.31.0",
     "scikit-learn",
     "sentencepiece==0.1.91",
     "soundfile",
     "sphinx-copybutton",
     "sphinx-markdown-tables",
     "sphinx-rtd-theme==0.4.3",  # sphinx-rtd-theme==0.5.0 introduced big changes in the style.
-    "sphinxext-opengraph==0.4.1",
     "sphinx==3.2.1",
+    "sphinxext-opengraph==0.4.1",
     "starlette",
     "tensorflow-cpu>=2.3",
     "tensorflow>=2.3",
@@ -139,7 +145,6 @@ _deps = [
     "unidic>=1.0.2",
     "unidic_lite>=1.0.7",
     "uvicorn",
-    "sagemaker>=2.31.0",
 ]
@@ -230,6 +235,8 @@ extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxr
 extras["modelcreation"] = deps_list("cookiecutter")

 extras["sagemaker"] = deps_list("sagemaker")
+extras["deepspeed"] = deps_list("deepspeed")
+extras["fairscale"] = deps_list("fairscale")

 extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
 extras["speech"] = deps_list("soundfile", "torchaudio")
@@ -238,20 +245,12 @@ extras["vision"] = deps_list("Pillow")
 extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
 extras["testing"] = (
     deps_list(
-        "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black"
+        "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black", "sacrebleu", "rouge-score", "nltk"
     )
     + extras["retrieval"]
     + extras["modelcreation"]
 )
-extras["docs"] = deps_list(
-    "docutils",
-    "recommonmark",
-    "sphinx",
-    "sphinx-markdown-tables",
-    "sphinx-rtd-theme",
-    "sphinx-copybutton",
-    "sphinxext-opengraph",
-)
 extras["quality"] = deps_list("black", "isort", "flake8")
 extras["all"] = (
@@ -264,12 +263,24 @@ extras["all"] = (
     + extras["vision"]
 )

+extras["docs_specific"] = deps_list(
+    "docutils",
+    "recommonmark",
+    "sphinx",
+    "sphinx-markdown-tables",
+    "sphinx-rtd-theme",
+    "sphinx-copybutton",
+    "sphinxext-opengraph",
+)
+# "docs" needs "all" to resolve all the references
+extras["docs"] = extras["all"] + extras["docs_specific"]
+
 extras["dev"] = (
     extras["all"]
     + extras["testing"]
     + extras["quality"]
     + extras["ja"]
-    + extras["docs"]
+    + extras["docs_specific"]
     + extras["sklearn"]
     + extras["modelcreation"]
 )
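The extras rearrangement is easier to follow with a toy version. The sketch below assumes that `deps_list` maps bare package names to their pinned requirement strings (its definition is not part of this diff) and uses a deliberately tiny, hypothetical `deps` table in place of the real one.

```python
# Hypothetical, trimmed-down dependency table; the real one is generated into
# src/transformers/dependency_versions_table.py.
deps = {
    "Pillow": "Pillow",
    "deepspeed": "deepspeed>0.3.13",
    "docutils": "docutils==0.16.0",
    "fairscale": "fairscale>0.3",
    "sphinx": "sphinx==3.2.1",
}


def deps_list(*pkgs):
    """Assumed behaviour: look up each package's pinned requirement string."""
    return [deps[pkg] for pkg in pkgs]


extras = {}
extras["deepspeed"] = deps_list("deepspeed")
extras["fairscale"] = deps_list("fairscale")
extras["all"] = deps_list("Pillow")  # stand-in for the much longer real list
extras["docs_specific"] = deps_list("docutils", "sphinx")
# "docs" now pulls in "all", which is why CI can install ."[docs]" alone.
extras["docs"] = extras["all"] + extras["docs_specific"]
```

Because `dev` already contains `all`, it only needs `docs_specific`, which avoids listing the `all` requirements twice.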
src/transformers/data/data_collator.py
@@ -192,7 +192,7 @@ class DataCollatorForTokenClassification:
         return batch


-def _collate_batch(examples, tokenizer):
+def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
     """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
     # Tensorize if necessary.
     if isinstance(examples[0], (list, tuple)):
@@ -201,7 +201,7 @@ def _collate_batch(examples, tokenizer):
     # Check if padding is necessary.
     length_of_first = examples[0].size(0)
     are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-    if are_tensors_same_length:
+    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
         return torch.stack(examples, dim=0)

     # If yes, check if we have a `pad_token`.
@@ -213,6 +213,8 @@ def _collate_batch(examples, tokenizer):
     # Creating the full tensor and filling it with our data.
     max_length = max(x.size(0) for x in examples)
+    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
     result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
     for i, example in enumerate(examples):
         if tokenizer.padding_side == "right":
@@ -311,6 +313,8 @@ class DataCollatorForLanguageModeling:
             non-masked tokens and the value to predict for the masked token.
         mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
             The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set will pad the sequence to a multiple of the provided value.

     .. note::
@@ -323,6 +327,7 @@ class DataCollatorForLanguageModeling:
     tokenizer: PreTrainedTokenizerBase
     mlm: bool = True
     mlm_probability: float = 0.15
+    pad_to_multiple_of: Optional[int] = None

     def __post_init__(self):
         if self.mlm and self.tokenizer.mask_token is None:
@@ -336,9 +341,9 @@ class DataCollatorForLanguageModeling:
     ) -> Dict[str, torch.Tensor]:
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], (dict, BatchEncoding)):
-            batch = self.tokenizer.pad(examples, return_tensors="pt")
+            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
-            batch = {"input_ids": _collate_batch(examples, self.tokenizer)}
+            batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}

         # If special token mask has been preprocessed, pop it from the dict.
         special_tokens_mask = batch.pop("special_tokens_mask", None)
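The rounding arithmetic added to `_collate_batch` can be checked on its own, without tensors. In the sketch below only the two-line `if` body comes from the diff; the wrapper function `padded_length` is a hypothetical name introduced for illustration.

```python
from typing import Optional


def padded_length(max_length: int, pad_to_multiple_of: Optional[int] = None) -> int:
    """Round `max_length` up to the next multiple of `pad_to_multiple_of`, if one is given."""
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    return max_length


# A longest sequence of 13 tokens is padded out to 16 when the multiple is 8;
# lengths that are already a multiple, or calls without a multiple, are unchanged.
lengths = [padded_length(13, 8), padded_length(16, 8), padded_length(13)]
```

This also explains the changed early-return condition: a batch whose tensors all share a length that is not a multiple can no longer be stacked directly and must go through the padding path.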
src/transformers/dependency_versions_check.py
@@ -14,7 +14,7 @@
 import sys

 from .dependency_versions_table import deps
-from .utils.versions import require_version_core
+from .utils.versions import require_version, require_version_core


 # define which module versions we always want to check at run time
@@ -41,3 +41,7 @@ for pkg in pkgs_to_check_at_runtime:
         require_version_core(deps[pkg])
     else:
         raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+def dep_version_check(pkg, hint=None):
+    require_version(deps[pkg], hint)
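The point of `dep_version_check` is that callers name only the package while the version pin comes from the central table. The sketch below shows that lookup with a stub: the `require_version` here is a stand-in that merely records its arguments (the real helper in `.utils.versions` inspects the installed package), and the two-entry `deps` table is a hypothetical excerpt.

```python
# Hypothetical excerpt of the generated dependency table.
deps = {"deepspeed": "deepspeed>0.3.13", "fairscale": "fairscale>0.3"}

# Records every (requirement, hint) pair the stub is asked to check.
checked = []


def require_version(requirement, hint=None):
    """Stub for illustration only: the real function validates the installed version."""
    checked.append((requirement, hint))


def dep_version_check(pkg, hint=None):
    # Same body as in the diff: resolve the pin from the table, then delegate.
    require_version(deps[pkg], hint)


dep_version_check("deepspeed")
```

With this in place, a pin such as `deepspeed>0.3.13` is written once in setup.py rather than repeated at each call site, which is what the integrations.py hunk relies on.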
src/transformers/dependency_versions_table.py
@@ -2,11 +2,14 @@
 # 1. modify the `_deps` dict in setup.py
 # 2. run `make deps_table_update``
 deps = {
+    "Pillow": "Pillow",
     "black": "black>=20.8b1",
     "cookiecutter": "cookiecutter==1.7.2",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
+    "deepspeed": "deepspeed>0.3.13",
     "docutils": "docutils==0.16.0",
+    "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
     "fastapi": "fastapi",
     "filelock": "filelock",
@@ -19,13 +22,13 @@ deps = {
     "jax": "jax>=0.2.8",
     "jaxlib": "jaxlib>=0.1.59",
     "keras2onnx": "keras2onnx",
+    "nltk": "nltk",
     "numpy": "numpy>=1.17",
     "onnxconverter-common": "onnxconverter-common",
     "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
     "onnxruntime": "onnxruntime>=1.4.0",
     "packaging": "packaging",
     "parameterized": "parameterized",
-    "Pillow": "Pillow",
     "protobuf": "protobuf",
     "psutil": "psutil",
     "pydantic": "pydantic",
@@ -36,15 +39,18 @@ deps = {
     "recommonmark": "recommonmark",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
+    "rouge-score": "rouge-score",
+    "sacrebleu": "sacrebleu>=1.4.12",
     "sacremoses": "sacremoses",
+    "sagemaker": "sagemaker>=2.31.0",
     "scikit-learn": "scikit-learn",
     "sentencepiece": "sentencepiece==0.1.91",
     "soundfile": "soundfile",
     "sphinx-copybutton": "sphinx-copybutton",
     "sphinx-markdown-tables": "sphinx-markdown-tables",
     "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
-    "sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
     "sphinx": "sphinx==3.2.1",
+    "sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
     "starlette": "starlette",
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
     "tensorflow": "tensorflow>=2.3",
@@ -56,5 +62,4 @@ deps = {
     "unidic": "unidic>=1.0.2",
     "unidic_lite": "unidic_lite>=1.0.7",
     "uvicorn": "uvicorn",
-    "sagemaker": "sagemaker>=2.31.0",
 }
src/transformers/integrations.py
@@ -24,8 +24,8 @@ import tempfile
 from copy import deepcopy
 from pathlib import Path

+from .dependency_versions_check import dep_version_check
 from .utils import logging
-from .utils.versions import require_version


 logger = logging.get_logger(__name__)
@@ -324,7 +324,7 @@ def deepspeed_parse_config(ds_config):
     If it's already a dict, return a copy of it, so that we can freely modify it.
     """
-    require_version("deepspeed>0.3.13")
+    dep_version_check("deepspeed")

     if isinstance(ds_config, dict):
         # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
@@ -604,9 +604,11 @@ class TensorBoardCallback(TrainerCallback):
             self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

     def on_log(self, args, state, control, logs=None, **kwargs):
-        if state.is_world_process_zero:
-            if self.tb_writer is None:
-                self._init_summary_writer(args)
+        if not state.is_world_process_zero:
+            return
+
+        if self.tb_writer is None:
+            self._init_summary_writer(args)

         if self.tb_writer is not None:
             logs = rewrite_logs(logs)
src/transformers/modeling_flax_utils.py
@@ -387,6 +387,7 @@ class FlaxPreTrainedModel(ABC):
         # get abs dir
         save_directory = os.path.abspath(save_directory)
         # save config as well
+        self.config.architectures = [self.__class__.__name__[4:]]
         self.config.save_pretrained(save_directory)

         # save model
src/transformers/modeling_tf_utils.py
@@ -1037,6 +1037,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             logger.info(f"Saved model created in {saved_model_dir}")

         # Save configuration file
+        self.config.architectures = [self.__class__.__name__[2:]]
         self.config.save_pretrained(save_directory)

         # If we save using the predefined names, we can load using `from_pretrained`
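Both one-line additions strip the framework prefix from the class name before writing it to `config.architectures`: four characters for `Flax`, two for `TF`. A quick sketch with empty placeholder classes (the class names are chosen only for illustration) shows the slices at work.

```python
# Empty placeholder classes; only their names matter here.
class BertModel:
    pass


class TFBertModel:
    pass


class FlaxBertModel:
    pass


# Same slices as in the two hunks above.
tf_architecture = TFBertModel.__name__[2:]
flax_architecture = FlaxBertModel.__name__[4:]
```

Storing the framework-neutral name appears to be what lets a config saved from one framework be matched against another framework's class list, by re-adding the `TF` or `Flax` prefix at load time.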
src/transformers/models/auto/__init__.py
@@ -22,6 +22,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
 _import_structure = {
+    "auto_factory": ["get_values"],
     "configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
     "feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
     "tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
@@ -104,6 +105,7 @@ if is_flax_available():
 if TYPE_CHECKING:
+    from .auto_factory import get_values
     from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
     from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
     from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
src/transformers/models/auto/auto_factory.py
View file @
269c9638
...
@@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
...
@@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
"""
"""
def
_get_model_class
(
config
,
model_mapping
):
supported_models
=
model_mapping
[
type
(
config
)]
if
not
isinstance
(
supported_models
,
(
list
,
tuple
)):
return
supported_models
name_to_model
=
{
model
.
__name__
:
model
for
model
in
supported_models
}
architectures
=
getattr
(
config
,
"architectures"
,
[])
for
arch
in
architectures
:
if
arch
in
name_to_model
:
return
name_to_model
[
arch
]
elif
f
"TF
{
arch
}
"
in
name_to_model
:
return
name_to_model
[
f
"TF
{
arch
}
"
]
elif
f
"Flax
{
arch
}
"
in
name_to_model
:
return
name_to_model
[
f
"Flax
{
arch
}
"
]
# If not architecture is set in the config or match the supported models, the first element of the tuple is the
# defaults.
return
supported_models
[
0
]
class
_BaseAutoModelClass
:
class
_BaseAutoModelClass
:
# Base class for auto models.
# Base class for auto models.
_model_mapping
=
None
_model_mapping
=
None
@@ -341,7 +361,8 @@ class _BaseAutoModelClass:
 
     def from_config(cls, config, **kwargs):
         if type(config) in cls._model_mapping.keys():
-            return cls._model_mapping[type(config)](config, **kwargs)
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class(config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -356,9 +377,8 @@ class _BaseAutoModelClass:
             )
 
         if type(config) in cls._model_mapping.keys():
-            return cls._model_mapping[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
     from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
     new_class.from_pretrained = classmethod(from_pretrained)
     return new_class
+
+
+def get_values(model_mapping):
+    result = []
+    for model in model_mapping.values():
+        if isinstance(model, (list, tuple)):
+            result += list(model)
+        else:
+            result.append(model)
+
+    return result
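Reviewer note: `get_values` flattens a mapping whose values may be either a single class or a tuple of classes. A minimal sketch of that behaviour follows; strings stand in for the config and model classes, so the mapping contents are illustrative only.

```python
from collections import OrderedDict


def get_values(model_mapping):
    result = []
    for model in model_mapping.values():
        # Tuple or list entries are expanded; single entries are appended as-is.
        if isinstance(model, (list, tuple)):
            result += list(model)
        else:
            result.append(model)
    return result


# Strings are placeholders for classes in this sketch.
mapping = OrderedDict(
    [
        ("CTRLConfig", "CTRLModel"),
        ("FunnelConfig", ("FunnelModel", "FunnelBaseModel")),
    ]
)
flat = get_values(mapping)
print(flat)
```

Callers that previously iterated over `mapping.values()` directly would otherwise receive the tuple itself rather than its members.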
src/transformers/models/auto/configuration_auto.py
@@ -247,29 +247,38 @@ MODEL_NAMES_MAPPING = OrderedDict(
 )
 
 
+def _get_class_name(model_class):
+    if isinstance(model_class, (list, tuple)):
+        return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
+    return f":class:`~transformers.{model_class.__name__}`"
+
+
 def _list_model_options(indent, config_to_class=None, use_model_types=True):
     if config_to_class is None and not use_model_types:
         raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
     if use_model_types:
         if config_to_class is None:
-            model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
+            model_type_to_name = {
+                model_type: f":class:`~transformers.{config.__name__}`" for model_type, config in CONFIG_MAPPING.items()
+            }
         else:
             model_type_to_name = {
-                model_type: config_to_class[config].__name__
+                model_type: _get_class_name(config_to_class[config])
                 for model_type, config in CONFIG_MAPPING.items()
                 if config in config_to_class
             }
         lines = [
-            f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
+            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
             for model_type in sorted(model_type_to_name.keys())
         ]
     else:
-        config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
+        config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
         config_to_model_name = {
             config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
         }
         lines = [
-            f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
+            f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
             for config_name in sorted(config_to_name.keys())
         ]
     return "\n".join(lines)
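Reviewer note: the hunk above moves the `:class:` role formatting into `_get_class_name`, so that a tuple of classes renders as several cross-references joined by "or". A small sketch of the resulting strings, using hypothetical placeholder classes (`FakeModel`, `FakeBaseModel`) rather than real model classes:

```python
def _get_class_name(model_class):
    # A tuple or list of classes becomes "ref or ref"; a single class becomes one ref.
    if isinstance(model_class, (list, tuple)):
        return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
    return f":class:`~transformers.{model_class.__name__}`"


class FakeModel:
    """Hypothetical placeholder class for this sketch."""


class FakeBaseModel:
    """Hypothetical placeholder class for this sketch."""


single = _get_class_name(FakeModel)
both = _get_class_name((FakeModel, FakeBaseModel))
print(single)
print(both)
```

Because the helper already emits the full role markup, the f-strings in `_list_model_options` no longer wrap the name in `:class:` themselves.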
src/transformers/models/auto/modeling_auto.py
@@ -124,6 +124,7 @@ from ..flaubert.modeling_flaubert import (
 )
 from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
 from ..funnel.modeling_funnel import (
+    FunnelBaseModel,
     FunnelForMaskedLM,
     FunnelForMultipleChoice,
     FunnelForPreTraining,
@@ -377,7 +378,7 @@ MODEL_MAPPING = OrderedDict(
         (CTRLConfig, CTRLModel),
         (ElectraConfig, ElectraModel),
         (ReformerConfig, ReformerModel),
-        (FunnelConfig, FunnelModel),
+        (FunnelConfig, (FunnelModel, FunnelBaseModel)),
         (LxmertConfig, LxmertModel),
         (BertGenerationConfig, BertGenerationEncoder),
         (DebertaConfig, DebertaModel),
src/transformers/models/auto/modeling_tf_auto.py
@@ -91,6 +91,7 @@ from ..flaubert.modeling_tf_flaubert import (
     TFFlaubertWithLMHeadModel,
 )
 from ..funnel.modeling_tf_funnel import (
+    TFFunnelBaseModel,
     TFFunnelForMaskedLM,
     TFFunnelForMultipleChoice,
     TFFunnelForPreTraining,
@@ -242,7 +243,7 @@ TF_MODEL_MAPPING = OrderedDict(
         (XLMConfig, TFXLMModel),
         (CTRLConfig, TFCTRLModel),
         (ElectraConfig, TFElectraModel),
-        (FunnelConfig, TFFunnelModel),
+        (FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
         (DPRConfig, TFDPRQuestionEncoder),
         (MPNetConfig, TFMPNetModel),
         (BartConfig, TFBartModel),
src/transformers/testing_utils.py
@@ -24,6 +24,7 @@ import unittest
 from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path
+from typing import Iterator, Union
 
 from .file_utils import (
     is_datasets_available,
@@ -621,6 +622,27 @@ class CaptureLogger:
         return f"captured: {self.out}\n"
 
 
+@contextlib.contextmanager
+# adapted from https://stackoverflow.com/a/64789046/9201239
+def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
+    """
+    Temporarily add the given path to `sys.path`.
+
+    Usage ::
+
+       with ExtendSysPath('/path/to/dir'):
+           mymodule = importlib.import_module('mymodule')
+
+    """
+
+    path = os.fspath(path)
+    try:
+        sys.path.insert(0, path)
+        yield
+    finally:
+        sys.path.remove(path)
+
+
 class TestCasePlus(unittest.TestCase):
     """
     This class extends `unittest.TestCase` with additional features.
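Reviewer note: a self-contained way to see `ExtendSysPath` in action is to import a module from a throwaway directory. The sketch below repeats the context manager from the hunk so it runs on its own; the module name `gr_demo_module` and its `ANSWER` attribute are made up for the example.

```python
import contextlib
import importlib
import os
import sys
import tempfile
from pathlib import Path
from typing import Iterator, Union


@contextlib.contextmanager
def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
    path = os.fspath(path)
    try:
        sys.path.insert(0, path)
        yield
    finally:
        # Runs even if the body raises, so sys.path is always restored.
        sys.path.remove(path)


with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical module written only for this demonstration.
    Path(tmp_dir, "gr_demo_module.py").write_text("ANSWER = 42\n")
    with ExtendSysPath(tmp_dir):
        importlib.invalidate_caches()
        demo = importlib.import_module("gr_demo_module")
        inside = tmp_dir in sys.path
    outside = tmp_dir in sys.path

print(demo.ANSWER, inside, outside)
```

The directory is importable only inside the `with` block; once the block exits, the entry is removed from `sys.path` again.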
src/transformers/trainer.py
@@ -54,6 +54,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler
 
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from .dependency_versions_check import dep_version_check
 from .file_utils import (
     WEIGHTS_NAME,
     is_apex_available,
@@ -139,17 +140,14 @@ if is_torch_tpu_available():
     import torch_xla.distributed.parallel_loader as pl
 
 if is_fairscale_available():
+    dep_version_check("fairscale")
     import fairscale
+    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
     from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
+    from fairscale.nn.wrap import auto_wrap
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
 
-    if version.parse(fairscale.__version__) >= version.parse("0.3"):
-        from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
-        from fairscale.nn.wrap import auto_wrap
-    else:
-        FullyShardedDDP = None
-
 if is_sagemaker_dp_enabled():
     import smdistributed.dataparallel.torch.distributed as dist
     from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
src/transformers/training_args.py
@@ -531,6 +531,12 @@ class TrainingArguments:
     )
 
     def __post_init__(self):
+        # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
+        # This needs to happen before any call to self.device or self.n_gpu.
+        env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+        if env_local_rank != -1 and env_local_rank != self.local_rank:
+            self.local_rank = env_local_rank
+
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home
         # see https://github.com/huggingface/transformers/issues/10628
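Reviewer note: the `LOCAL_RANK` handling added to `__post_init__` amounts to "the environment variable wins whenever it is set". The sketch below isolates that rule in a standalone function; `resolve_local_rank` and its `environ` parameter are hypothetical names introduced here so the logic can be exercised without constructing `TrainingArguments`.

```python
import os


def resolve_local_rank(cli_local_rank: int, environ=os.environ) -> int:
    """Return the rank to use, preferring LOCAL_RANK from the environment when set."""
    env_local_rank = int(environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != cli_local_rank:
        return env_local_rank
    return cli_local_rank


# A plain dict stands in for os.environ in these calls.
from_env = resolve_local_rank(-1, {"LOCAL_RANK": "2"})
from_cli = resolve_local_rank(0, {})
agreeing = resolve_local_rank(1, {"LOCAL_RANK": "1"})
print(from_env, from_cli, agreeing)
```

When the launcher exports `LOCAL_RANK` instead of passing a command-line argument, the value from the environment replaces the default; otherwise the command-line value is kept.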