chenpangpang / transformers · Commit 269c9638

Merge branch 'master' of github.com:huggingface/transformers

Authored Apr 08, 2021 by Sylvain Gugger
Parents: d31c7b10, c2e0fd52

Showing 20 changed files with 192 additions and 55 deletions
Changed files:

  .circleci/config.yml                                 +2   -2
  .github/workflows/self-scheduled.yml                 +2   -2
  docs/source/main_classes/trainer.rst                 +18  -2
  docs/source/testing.rst                              +20  -3
  examples/language-modeling/run_mlm.py                +6   -1
  setup.py                                             +26  -15
  src/transformers/data/data_collator.py               +9   -4
  src/transformers/dependency_versions_check.py        +5   -1
  src/transformers/dependency_versions_table.py        +8   -3
  src/transformers/integrations.py                     +7   -5
  src/transformers/modeling_flax_utils.py              +1   -0
  src/transformers/modeling_tf_utils.py                +1   -0
  src/transformers/models/auto/__init__.py             +2   -0
  src/transformers/models/auto/auto_factory.py         +35  -4
  src/transformers/models/auto/configuration_auto.py   +14  -5
  src/transformers/models/auto/modeling_auto.py        +2   -1
  src/transformers/models/auto/modeling_tf_auto.py     +2   -1
  src/transformers/testing_utils.py                    +22  -0
  src/transformers/trainer.py                          +4   -6
  src/transformers/training_args.py                    +6   -0
.circleci/config.yml

@@ -348,7 +348,7 @@ jobs:
             - v0.4-{{ checksum "setup.py" }}
       - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
       - run: pip install --upgrade pip
-      - run: pip install ."[all, docs]"
+      - run: pip install ."[docs]"
       - save_cache:
           key: v0.4-build_doc-{{ checksum "setup.py" }}
           paths:
@@ -370,7 +370,7 @@ jobs:
           keys:
             - v0.4-deploy_doc-{{ checksum "setup.py" }}
             - v0.4-{{ checksum "setup.py" }}
-      - run: pip install ."[all, docs]"
+      - run: pip install ."[docs]"
       - save_cache:
           key: v0.4-deploy_doc-{{ checksum "setup.py" }}
           paths:
.github/workflows/self-scheduled.yml

@@ -33,7 +33,7 @@ jobs:
         run: |
           apt -y update && apt install -y libsndfile1-dev
           pip install --upgrade pip
-          pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
+          pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed]
       - name: Are GPUs recognized by our DL frameworks
         run: |
@@ -155,7 +155,7 @@ jobs:
         run: |
           apt -y update && apt install -y libsndfile1-dev
           pip install --upgrade pip
-          pip install .[sklearn,testing,onnxruntime,sentencepiece,speech]
+          pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,deepspeed,fairscale]
       - name: Are GPUs recognized by our DL frameworks
         run: |
docs/source/main_classes/trainer.rst

@@ -274,6 +274,14 @@ Install the library via pypi:

     pip install fairscale

+or via ``transformers``' ``extras``:
+
+.. code-block:: bash
+
+    pip install transformers[fairscale]
+
+(will become available starting from ``transformers==4.6.0``)
+
 or find more details on `the FairScale's GitHub page <https://github.com/facebookresearch/fairscale/#installation>`__.

 If you're still struggling with the build, first make sure to read :ref:`zero-install-notes`.
@@ -419,6 +427,14 @@ Install the library via pypi:

     pip install deepspeed

+or via ``transformers``' ``extras``:
+
+.. code-block:: bash
+
+    pip install transformers[deepspeed]
+
+(will become available starting from ``transformers==4.6.0``)
+
 or find more details on `the DeepSpeed's GitHub page <https://github.com/microsoft/deepspeed#installation>`__ and
 `advanced install <https://www.deepspeed.ai/tutorials/advanced-install/>`__.
@@ -525,7 +541,7 @@ Here is an example of running ``run_translation.py`` under DeepSpeed deploying a

     deepspeed examples/seq2seq/run_translation.py \
-    --deepspeed examples/tests/deepspeed/ds_config.json \
+    --deepspeed tests/deepspeed/ds_config.json \
     --model_name_or_path t5-small --per_device_train_batch_size 1 \
     --output_dir output_dir --overwrite_output_dir --fp16 \
     --do_train --max_train_samples 500 --num_train_epochs 1 \
@@ -550,7 +566,7 @@ To deploy DeepSpeed with one GPU adjust the :class:`~transformers.Trainer` comma

     deepspeed --num_gpus=1 examples/seq2seq/run_translation.py \
-    --deepspeed examples/tests/deepspeed/ds_config.json \
+    --deepspeed tests/deepspeed/ds_config.json \
     --model_name_or_path t5-small --per_device_train_batch_size 1 \
     --output_dir output_dir --overwrite_output_dir --fp16 \
     --do_train --max_train_samples 500 --num_train_epochs 1 \
docs/source/testing.rst

..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with

@@ -388,7 +388,7 @@ For a single or a group of tests via ``pytest`` (after ``pip install pytest-pspe

 .. code-block:: bash

    pytest --pspec tests/test_optimization.py

@@ -672,7 +672,7 @@ and it will list:

    test_this2.py::test_floor[integer-1-1.0]
    test_this2.py::test_floor[negative--1.5--2.0]
    test_this2.py::test_floor[large fraction-1.6-1]

 So now you can run just the specific test:

@@ -795,6 +795,23 @@ leave any data in there.
 otherwise.

+
+Temporary sys.path override
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you need to temporary override ``sys.path`` to import from another test for example, you can use the
+``ExtendSysPath`` context manager. Example:
+
+.. code-block:: python
+
+   import os
+   from transformers.testing_utils import ExtendSysPath
+
+   bindir = os.path.abspath(os.path.dirname(__file__))
+   with ExtendSysPath(f"{bindir}/.."):
+       from test_trainer import TrainerIntegrationCommon  # noqa
+
+
 Skipping tests
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
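The ``ExtendSysPath`` helper shown in the added docs can be approximated in a few lines of standard-library Python. This is a hedged sketch of the same idea (a context manager that prepends a path and restores ``sys.path`` on exit), not the exact implementation in ``transformers.testing_utils``:

```python
import sys
from contextlib import contextmanager


@contextmanager
def extend_sys_path(path: str):
    """Temporarily prepend `path` to sys.path, removing it again on exit."""
    sys.path.insert(0, path)
    try:
        yield
    finally:
        sys.path.remove(path)
```

Because the removal happens in a ``finally`` block, ``sys.path`` is restored even if the import inside the ``with`` block raises.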
examples/language-modeling/run_mlm.py

@@ -422,7 +422,12 @@ def main():
     # Data collator
     # This one will take care of randomly masking the tokens.
-    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
+    pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm_probability=data_args.mlm_probability,
+        pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
+    )

     # Initialize our Trainer
     trainer = Trainer(
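The ``pad_to_multiple_of_8`` flag above enables padding to multiples of 8 only when batches have dynamic lengths (``line_by_line`` without ``pad_to_max_length``) and fp16 is on, which tends to give better-shaped matmuls on tensor cores. A minimal sketch of that decision (the function name is illustrative, not part of the example script):

```python
from typing import Optional


def choose_pad_to_multiple_of(line_by_line: bool, fp16: bool, pad_to_max_length: bool) -> Optional[int]:
    """Mirror of the run_mlm.py logic: pad to multiples of 8 only when the
    batch length varies (line_by_line, no fixed max padding) and fp16 is on."""
    if line_by_line and fp16 and not pad_to_max_length:
        return 8
    return None
```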
setup.py

@@ -19,7 +19,7 @@ To create the package for pypi.

 1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
    documentation.

 2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid.

 3. Unpin specific versions from setup.py that use a git install.
@@ -85,11 +85,14 @@ if stale_egg_info.exists():
 # 1. all dependencies should be listed here with their version requirements if any
 # 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
 _deps = [
+    "Pillow",
     "black>=20.8b1",
     "cookiecutter==1.7.2",
     "dataclasses",
     "datasets",
+    "deepspeed>0.3.13",
     "docutils==0.16.0",
+    "fairscale>0.3",
     "faiss-cpu",
     "fastapi",
     "filelock",
@@ -102,13 +105,13 @@ _deps = [
     "jax>=0.2.8",
     "jaxlib>=0.1.59",
     "keras2onnx",
+    "nltk",
     "numpy>=1.17",
     "onnxconverter-common",
     "onnxruntime-tools>=1.4.2",
     "onnxruntime>=1.4.0",
     "packaging",
     "parameterized",
-    "Pillow",
     "protobuf",
     "psutil",
     "pydantic",
@@ -119,15 +122,18 @@ _deps = [
     "recommonmark",
     "regex!=2019.12.17",
     "requests",
+    "rouge-score",
+    "sacrebleu>=1.4.12",
     "sacremoses",
+    "sagemaker>=2.31.0",
     "scikit-learn",
     "sentencepiece==0.1.91",
     "soundfile",
     "sphinx-copybutton",
     "sphinx-markdown-tables",
     "sphinx-rtd-theme==0.4.3",  # sphinx-rtd-theme==0.5.0 introduced big changes in the style.
+    "sphinxext-opengraph==0.4.1",
     "sphinx==3.2.1",
-    "sphinxext-opengraph==0.4.1",
     "starlette",
     "tensorflow-cpu>=2.3",
     "tensorflow>=2.3",
@@ -139,7 +145,6 @@ _deps = [
     "unidic>=1.0.2",
     "unidic_lite>=1.0.7",
     "uvicorn",
-    "sagemaker>=2.31.0",
 ]
@@ -230,6 +235,8 @@ extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxr
 extras["modelcreation"] = deps_list("cookiecutter")

 extras["sagemaker"] = deps_list("sagemaker")
+extras["deepspeed"] = deps_list("deepspeed")
+extras["fairscale"] = deps_list("fairscale")

 extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
 extras["speech"] = deps_list("soundfile", "torchaudio")
@@ -238,20 +245,12 @@ extras["vision"] = deps_list("Pillow")
 extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
 extras["testing"] = (
     deps_list(
-        "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black"
+        "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black", "sacrebleu", "rouge-score", "nltk"
     )
     + extras["retrieval"]
     + extras["modelcreation"]
 )
-extras["docs"] = deps_list(
-    "docutils",
-    "recommonmark",
-    "sphinx",
-    "sphinx-markdown-tables",
-    "sphinx-rtd-theme",
-    "sphinx-copybutton",
-    "sphinxext-opengraph",
-)
 extras["quality"] = deps_list("black", "isort", "flake8")

 extras["all"] = (
@@ -264,12 +263,24 @@ extras["all"] = (
     + extras["vision"]
 )

+extras["docs_specific"] = deps_list(
+    "docutils",
+    "recommonmark",
+    "sphinx",
+    "sphinx-markdown-tables",
+    "sphinx-rtd-theme",
+    "sphinx-copybutton",
+    "sphinxext-opengraph",
+)
+
+# "docs" needs "all" to resolve all the references
+extras["docs"] = extras["all"] + extras["docs_specific"]
+
 extras["dev"] = (
     extras["all"]
     + extras["testing"]
     + extras["quality"]
     + extras["ja"]
-    + extras["docs"]
+    + extras["docs_specific"]
     + extras["sklearn"]
     + extras["modelcreation"]
 )
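The setup.py changes above all follow one pattern: every version pin lives exactly once in ``_deps``, a derived ``deps`` table maps bare names to full specs, and ``deps_list`` builds each extra from names. A small self-contained sketch of that pattern (the name-splitting regex here is an assumption, simplified from the real helper):

```python
import re

# One canonical list of pinned requirements, as in setup.py's _deps.
_deps = ["deepspeed>0.3.13", "fairscale>0.3", "numpy>=1.17", "packaging"]

# Derive a name -> full-spec lookup table; splitting on the first
# comparison character is a simplification of the real parsing.
deps = {re.split(r"[<>=!]", spec)[0]: spec for spec in _deps}


def deps_list(*pkgs):
    """Look up the pinned spec for each bare package name."""
    return [deps[pkg] for pkg in pkgs]


extras = {}
extras["deepspeed"] = deps_list("deepspeed")
extras["fairscale"] = deps_list("fairscale")
```

Keeping a single source of truth is also what makes `make deps_table_update` possible: the generated `dependency_versions_table.py` is just this lookup table serialized.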
src/transformers/data/data_collator.py

@@ -192,7 +192,7 @@ class DataCollatorForTokenClassification:
     return batch


-def _collate_batch(examples, tokenizer):
+def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
     """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
     # Tensorize if necessary.
     if isinstance(examples[0], (list, tuple)):
@@ -201,7 +201,7 @@ def _collate_batch(examples, tokenizer):
     # Check if padding is necessary.
     length_of_first = examples[0].size(0)
     are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-    if are_tensors_same_length:
+    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
         return torch.stack(examples, dim=0)

     # If yes, check if we have a `pad_token`.
@@ -213,6 +213,8 @@ def _collate_batch(examples, tokenizer):
     # Creating the full tensor and filling it with our data.
     max_length = max(x.size(0) for x in examples)
+    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
     result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
     for i, example in enumerate(examples):
         if tokenizer.padding_side == "right":
@@ -311,6 +313,8 @@ class DataCollatorForLanguageModeling:
             non-masked tokens and the value to predict for the masked token.
         mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
             The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set will pad the sequence to a multiple of the provided value.

     .. note::
@@ -323,6 +327,7 @@ class DataCollatorForLanguageModeling:
     tokenizer: PreTrainedTokenizerBase
     mlm: bool = True
     mlm_probability: float = 0.15
+    pad_to_multiple_of: Optional[int] = None

     def __post_init__(self):
         if self.mlm and self.tokenizer.mask_token is None:
@@ -336,9 +341,9 @@ class DataCollatorForLanguageModeling:
     ) -> Dict[str, torch.Tensor]:
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], (dict, BatchEncoding)):
-            batch = self.tokenizer.pad(examples, return_tensors="pt")
+            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
-            batch = {"input_ids": _collate_batch(examples, self.tokenizer)}
+            batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}

         # If special token mask has been preprocessed, pop it from the dict.
         special_tokens_mask = batch.pop("special_tokens_mask", None)
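The core of the ``_collate_batch`` change is rounding the batch's max length up to the next multiple of ``pad_to_multiple_of`` before allocating the padded tensor. The same arithmetic, sketched with plain lists instead of torch tensors (right-side padding assumed):

```python
from typing import List, Optional


def collate_with_padding(
    examples: List[List[int]], pad_token_id: int, pad_to_multiple_of: Optional[int] = None
) -> List[List[int]]:
    """List-based sketch of _collate_batch's padding logic."""
    max_length = max(len(x) for x in examples)
    # Round max_length up to the next multiple, exactly as in the diff above.
    if pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    # Right-pad every example to max_length.
    return [x + [pad_token_id] * (max_length - len(x)) for x in examples]
```

For example, sequences of lengths 3 and 2 with ``pad_to_multiple_of=4`` both come out length 4, so fp16 matmuls see dimensions divisible by the requested multiple.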
src/transformers/dependency_versions_check.py

@@ -14,7 +14,7 @@
 import sys

 from .dependency_versions_table import deps
-from .utils.versions import require_version_core
+from .utils.versions import require_version, require_version_core

 # define which module versions we always want to check at run time
@@ -41,3 +41,7 @@ for pkg in pkgs_to_check_at_runtime:
         require_version_core(deps[pkg])
     else:
         raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+def dep_version_check(pkg, hint=None):
+    require_version(deps[pkg], hint)
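``dep_version_check`` looks up the one canonical pin in the ``deps`` table and delegates to ``require_version``, so callers never hard-code a version string. A simplified stand-in for that comparison (numeric-only version parts; the real ``require_version`` handles far more cases and raises on failure):

```python
import operator

# Supported comparison operators, longest first so ">=" wins over ">".
OPS = {">": operator.gt, ">=": operator.ge, "==": operator.eq}


def version_satisfies(installed: str, spec: str) -> bool:
    """Check `installed` against a spec like 'deepspeed>0.3.13' (sketch)."""
    for op_str in (">=", "==", ">"):
        if op_str in spec:
            wanted = spec.split(op_str, 1)[1]
            key = lambda v: tuple(int(p) for p in v.split("."))
            return OPS[op_str](key(installed), key(wanted))
    return True  # unpinned spec: any installed version is acceptable
```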
src/transformers/dependency_versions_table.py

@@ -2,11 +2,14 @@
 # 1. modify the `_deps` dict in setup.py
 # 2. run `make deps_table_update``
 deps = {
+    "Pillow": "Pillow",
     "black": "black>=20.8b1",
     "cookiecutter": "cookiecutter==1.7.2",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
+    "deepspeed": "deepspeed>0.3.13",
     "docutils": "docutils==0.16.0",
+    "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
     "fastapi": "fastapi",
     "filelock": "filelock",
@@ -19,13 +22,13 @@ deps = {
     "jax": "jax>=0.2.8",
     "jaxlib": "jaxlib>=0.1.59",
     "keras2onnx": "keras2onnx",
+    "nltk": "nltk",
     "numpy": "numpy>=1.17",
     "onnxconverter-common": "onnxconverter-common",
     "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
     "onnxruntime": "onnxruntime>=1.4.0",
     "packaging": "packaging",
     "parameterized": "parameterized",
-    "Pillow": "Pillow",
     "protobuf": "protobuf",
     "psutil": "psutil",
     "pydantic": "pydantic",
@@ -36,15 +39,18 @@ deps = {
     "recommonmark": "recommonmark",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
+    "rouge-score": "rouge-score",
+    "sacrebleu": "sacrebleu>=1.4.12",
     "sacremoses": "sacremoses",
+    "sagemaker": "sagemaker>=2.31.0",
     "scikit-learn": "scikit-learn",
     "sentencepiece": "sentencepiece==0.1.91",
     "soundfile": "soundfile",
     "sphinx-copybutton": "sphinx-copybutton",
     "sphinx-markdown-tables": "sphinx-markdown-tables",
     "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
+    "sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
     "sphinx": "sphinx==3.2.1",
-    "sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
     "starlette": "starlette",
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
     "tensorflow": "tensorflow>=2.3",
@@ -56,5 +62,4 @@ deps = {
     "unidic": "unidic>=1.0.2",
     "unidic_lite": "unidic_lite>=1.0.7",
     "uvicorn": "uvicorn",
-    "sagemaker": "sagemaker>=2.31.0",
 }
src/transformers/integrations.py

@@ -24,8 +24,8 @@ import tempfile
 from copy import deepcopy
 from pathlib import Path

+from .dependency_versions_check import dep_version_check
 from .utils import logging
-from .utils.versions import require_version

 logger = logging.get_logger(__name__)
@@ -324,7 +324,7 @@ def deepspeed_parse_config(ds_config):
     If it's already a dict, return a copy of it, so that we can freely modify it.
     """
-    require_version("deepspeed>0.3.13")
+    dep_version_check("deepspeed")

     if isinstance(ds_config, dict):
         # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
@@ -604,9 +604,11 @@ class TensorBoardCallback(TrainerCallback):
             self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

     def on_log(self, args, state, control, logs=None, **kwargs):
-        if state.is_world_process_zero:
-            if self.tb_writer is None:
-                self._init_summary_writer(args)
+        if not state.is_world_process_zero:
+            return
+
+        if self.tb_writer is None:
+            self._init_summary_writer(args)

         if self.tb_writer is not None:
             logs = rewrite_logs(logs)
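The ``on_log`` change replaces a nested conditional with a guard clause: non-main processes return immediately, and only then is the writer lazily created. A runnable sketch of that control flow, with illustrative names rather than the transformers API:

```python
class TensorBoardLikeCallback:
    """Guard-clause version of on_log, as in the diff above (names are dummies)."""

    def __init__(self):
        self.writer = None
        self.logged = []

    def _init_writer(self):
        self.writer = object()  # stand-in for SummaryWriter(log_dir=...)

    def on_log(self, is_world_process_zero, logs):
        # Guard clause: only the main process writes TensorBoard events.
        if not is_world_process_zero:
            return
        # Lazily create the writer on first use.
        if self.writer is None:
            self._init_writer()
        if self.writer is not None:
            self.logged.append(logs)
```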
src/transformers/modeling_flax_utils.py

@@ -387,6 +387,7 @@ class FlaxPreTrainedModel(ABC):
         # get abs dir
         save_directory = os.path.abspath(save_directory)
         # save config as well
+        self.config.architectures = [self.__class__.__name__[4:]]
         self.config.save_pretrained(save_directory)

         # save model
...
src/transformers/modeling_tf_utils.py
View file @
269c9638
...
@@ -1037,6 +1037,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
...
@@ -1037,6 +1037,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
logger
.
info
(
f
"Saved model created in
{
saved_model_dir
}
"
)
logger
.
info
(
f
"Saved model created in
{
saved_model_dir
}
"
)
# Save configuration file
# Save configuration file
self
.
config
.
architectures
=
[
self
.
__class__
.
__name__
[
2
:]]
self
.
config
.
save_pretrained
(
save_directory
)
self
.
config
.
save_pretrained
(
save_directory
)
# If we save using the predefined names, we can load using `from_pretrained`
# If we save using the predefined names, we can load using `from_pretrained`
...
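Both slices above strip the framework prefix from the class name so ``config.architectures`` records the framework-neutral architecture name: ``[2:]`` drops the two-character ``TF`` prefix, ``[4:]`` the four-character ``Flax`` prefix. The intent, expressed as a small helper (hypothetical name, for illustration only):

```python
def framework_neutral_name(cls_name: str) -> str:
    """Strip a Flax/TF framework prefix from a model class name, if present."""
    for prefix in ("Flax", "TF"):
        if cls_name.startswith(prefix):
            return cls_name[len(prefix):]
    return cls_name
```

Storing the neutral name is what lets the auto-factory change below match ``arch``, ``TF{arch}``, or ``Flax{arch}`` regardless of which framework saved the checkpoint.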
...
src/transformers/models/auto/__init__.py
View file @
269c9638
...
@@ -22,6 +22,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
...
@@ -22,6 +22,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
_import_structure
=
{
_import_structure
=
{
"auto_factory"
:
[
"get_values"
],
"configuration_auto"
:
[
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP"
,
"CONFIG_MAPPING"
,
"MODEL_NAMES_MAPPING"
,
"AutoConfig"
],
"configuration_auto"
:
[
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP"
,
"CONFIG_MAPPING"
,
"MODEL_NAMES_MAPPING"
,
"AutoConfig"
],
"feature_extraction_auto"
:
[
"FEATURE_EXTRACTOR_MAPPING"
,
"AutoFeatureExtractor"
],
"feature_extraction_auto"
:
[
"FEATURE_EXTRACTOR_MAPPING"
,
"AutoFeatureExtractor"
],
"tokenization_auto"
:
[
"TOKENIZER_MAPPING"
,
"AutoTokenizer"
],
"tokenization_auto"
:
[
"TOKENIZER_MAPPING"
,
"AutoTokenizer"
],
...
@@ -104,6 +105,7 @@ if is_flax_available():
...
@@ -104,6 +105,7 @@ if is_flax_available():
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
.auto_factory
import
get_values
from
.configuration_auto
import
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
CONFIG_MAPPING
,
MODEL_NAMES_MAPPING
,
AutoConfig
from
.configuration_auto
import
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
CONFIG_MAPPING
,
MODEL_NAMES_MAPPING
,
AutoConfig
from
.feature_extraction_auto
import
FEATURE_EXTRACTOR_MAPPING
,
AutoFeatureExtractor
from
.feature_extraction_auto
import
FEATURE_EXTRACTOR_MAPPING
,
AutoFeatureExtractor
from
.tokenization_auto
import
TOKENIZER_MAPPING
,
AutoTokenizer
from
.tokenization_auto
import
TOKENIZER_MAPPING
,
AutoTokenizer
...
...
src/transformers/models/auto/auto_factory.py
View file @
269c9638
...
@@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
...
@@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
"""
"""
def
_get_model_class
(
config
,
model_mapping
):
supported_models
=
model_mapping
[
type
(
config
)]
if
not
isinstance
(
supported_models
,
(
list
,
tuple
)):
return
supported_models
name_to_model
=
{
model
.
__name__
:
model
for
model
in
supported_models
}
architectures
=
getattr
(
config
,
"architectures"
,
[])
for
arch
in
architectures
:
if
arch
in
name_to_model
:
return
name_to_model
[
arch
]
elif
f
"TF
{
arch
}
"
in
name_to_model
:
return
name_to_model
[
f
"TF
{
arch
}
"
]
elif
f
"Flax
{
arch
}
"
in
name_to_model
:
return
name_to_model
[
f
"Flax
{
arch
}
"
]
# If not architecture is set in the config or match the supported models, the first element of the tuple is the
# defaults.
return
supported_models
[
0
]
class
_BaseAutoModelClass
:
class
_BaseAutoModelClass
:
# Base class for auto models.
# Base class for auto models.
_model_mapping
=
None
_model_mapping
=
None
...
@@ -341,7 +361,8 @@ class _BaseAutoModelClass:
     def from_config(cls, config, **kwargs):
         if type(config) in cls._model_mapping.keys():
-            return cls._model_mapping[type(config)](config, **kwargs)
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class(config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
...
@@ -356,9 +377,8 @@ class _BaseAutoModelClass:
         )
         if type(config) in cls._model_mapping.keys():
-            return cls._model_mapping[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
...
@@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
     from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
     new_class.from_pretrained = classmethod(from_pretrained)
     return new_class
+
+
+def get_values(model_mapping):
+    result = []
+    for model in model_mapping.values():
+        if isinstance(model, (list, tuple)):
+            result += list(model)
+        else:
+            result.append(model)
+    return result
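The new `get_values` helper can be exercised in isolation. A minimal sketch with placeholder string values (the mapping contents here are illustrative, not the real `MODEL_MAPPING`):

```python
# Standalone copy of the new get_values helper: mapping values that are
# tuples, such as the (FunnelModel, FunnelBaseModel) pair introduced in this
# commit, are flattened into a single list.
def get_values(model_mapping):
    result = []
    for model in model_mapping.values():
        if isinstance(model, (list, tuple)):
            result += list(model)
        else:
            result.append(model)
    return result


mapping = {"electra": "ElectraModel", "funnel": ("FunnelModel", "FunnelBaseModel")}
print(get_values(mapping))  # ['ElectraModel', 'FunnelModel', 'FunnelBaseModel']
```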
src/transformers/models/auto/configuration_auto.py
...
@@ -247,29 +247,38 @@ MODEL_NAMES_MAPPING = OrderedDict(
 )
+def _get_class_name(model_class):
+    if isinstance(model_class, (list, tuple)):
+        return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
+    return f":class:`~transformers.{model_class.__name__}`"
+
+
 def _list_model_options(indent, config_to_class=None, use_model_types=True):
     if config_to_class is None and not use_model_types:
         raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
     if use_model_types:
         if config_to_class is None:
-            model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
+            model_type_to_name = {
+                model_type: f":class:`~transformers.{config.__name__}`"
+                for model_type, config in CONFIG_MAPPING.items()
+            }
         else:
             model_type_to_name = {
-                model_type: config_to_class[config].__name__
+                model_type: _get_class_name(config_to_class[config])
                 for model_type, config in CONFIG_MAPPING.items()
                 if config in config_to_class
             }
         lines = [
-            f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
+            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
             for model_type in sorted(model_type_to_name.keys())
         ]
     else:
-        config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
+        config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
         config_to_model_name = {
             config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
         }
         lines = [
-            f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
+            f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
             for config_name in sorted(config_to_name.keys())
         ]
     return "\n".join(lines)
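The new `_get_class_name` helper renders either a single class or a tuple of classes as Sphinx cross-references. A toy sketch (the stand-in classes below are not the real transformers models, and the function name drops the leading underscore):

```python
# Sketch of _get_class_name: a tuple of classes is joined with " or ",
# a single class becomes one :class: cross-reference.
def get_class_name(model_class):
    if isinstance(model_class, (list, tuple)):
        return " or ".join(f":class:`~transformers.{c.__name__}`" for c in model_class)
    return f":class:`~transformers.{model_class.__name__}`"


class FunnelModel:
    pass


class FunnelBaseModel:
    pass


print(get_class_name(FunnelModel))
# :class:`~transformers.FunnelModel`
print(get_class_name((FunnelModel, FunnelBaseModel)))
# :class:`~transformers.FunnelModel` or :class:`~transformers.FunnelBaseModel`
```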
...
src/transformers/models/auto/modeling_auto.py
...
@@ -124,6 +124,7 @@ from ..flaubert.modeling_flaubert import (
 )
 from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
 from ..funnel.modeling_funnel import (
+    FunnelBaseModel,
     FunnelForMaskedLM,
     FunnelForMultipleChoice,
     FunnelForPreTraining,
...
@@ -377,7 +378,7 @@ MODEL_MAPPING = OrderedDict(
         (CTRLConfig, CTRLModel),
         (ElectraConfig, ElectraModel),
         (ReformerConfig, ReformerModel),
-        (FunnelConfig, FunnelModel),
+        (FunnelConfig, (FunnelModel, FunnelBaseModel)),
         (LxmertConfig, LxmertModel),
         (BertGenerationConfig, BertGenerationEncoder),
         (DebertaConfig, DebertaModel),
...
src/transformers/models/auto/modeling_tf_auto.py
...
@@ -91,6 +91,7 @@ from ..flaubert.modeling_tf_flaubert import (
     TFFlaubertWithLMHeadModel,
 )
 from ..funnel.modeling_tf_funnel import (
+    TFFunnelBaseModel,
     TFFunnelForMaskedLM,
     TFFunnelForMultipleChoice,
     TFFunnelForPreTraining,
@@ -242,7 +243,7 @@ TF_MODEL_MAPPING = OrderedDict(
...
@@ -242,7 +243,7 @@ TF_MODEL_MAPPING = OrderedDict(
(
XLMConfig
,
TFXLMModel
),
(
XLMConfig
,
TFXLMModel
),
(
CTRLConfig
,
TFCTRLModel
),
(
CTRLConfig
,
TFCTRLModel
),
(
ElectraConfig
,
TFElectraModel
),
(
ElectraConfig
,
TFElectraModel
),
(
FunnelConfig
,
TFFunnelModel
),
(
FunnelConfig
,
(
TFFunnelModel
,
TFFunnelBaseModel
)
),
(
DPRConfig
,
TFDPRQuestionEncoder
),
(
DPRConfig
,
TFDPRQuestionEncoder
),
(
MPNetConfig
,
TFMPNetModel
),
(
MPNetConfig
,
TFMPNetModel
),
(
BartConfig
,
TFBartModel
),
(
BartConfig
,
TFBartModel
),
...
...
src/transformers/testing_utils.py
...
@@ -24,6 +24,7 @@ import unittest
 from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path
+from typing import Iterator, Union

 from .file_utils import (
     is_datasets_available,
...
@@ -621,6 +622,27 @@ class CaptureLogger:
         return f"captured: {self.out}\n"


+@contextlib.contextmanager
+# adapted from https://stackoverflow.com/a/64789046/9201239
+def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
+    """
+    Temporarily add a given path to `sys.path`.
+
+    Usage ::
+
+       with ExtendSysPath('/path/to/dir'):
+           mymodule = importlib.import_module('mymodule')
+
+    """
+    path = os.fspath(path)
+    try:
+        sys.path.insert(0, path)
+        yield
+    finally:
+        sys.path.remove(path)
+
+
 class TestCasePlus(unittest.TestCase):
     """
     This class extends `unittest.TestCase` with additional features.
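The `ExtendSysPath` context manager added in this hunk can be demonstrated end to end. This is a standalone sketch that runs without transformers; `extend_sys_path` and `scratch_mod` are illustrative names, not part of the library.

```python
import contextlib
import importlib
import os
import sys
import tempfile


@contextlib.contextmanager
def extend_sys_path(path):
    # Prepend the path for the duration of the block, then remove it again,
    # mirroring the ExtendSysPath helper.
    path = os.fspath(path)
    sys.path.insert(0, path)
    try:
        yield
    finally:
        sys.path.remove(path)


# Create a throwaway module and import it only while the path is extended.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "scratch_mod.py"), "w") as f:
        f.write("ANSWER = 42\n")
    importlib.invalidate_caches()  # the file was created after interpreter startup
    with extend_sys_path(tmp):
        mod = importlib.import_module("scratch_mod")
    assert tmp not in sys.path  # sys.path is restored on exit
print(mod.ANSWER)  # 42
```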
...
src/transformers/trainer.py
...
@@ -54,6 +54,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler

 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from .dependency_versions_check import dep_version_check
 from .file_utils import (
     WEIGHTS_NAME,
     is_apex_available,
...
@@ -139,17 +140,14 @@ if is_torch_tpu_available():
     import torch_xla.distributed.parallel_loader as pl

 if is_fairscale_available():
+    dep_version_check("fairscale")
     import fairscale
-    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
     from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
-    from fairscale.nn.wrap import auto_wrap
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler

+    if version.parse(fairscale.__version__) >= version.parse("0.3"):
+        from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
+        from fairscale.nn.wrap import auto_wrap
+    else:
+        FullyShardedDDP = None

 if is_sagemaker_dp_enabled():
     import smdistributed.dataparallel.torch.distributed as dist
     from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
...
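The fairscale hunk gates the `FullyShardedDataParallel`/`auto_wrap` imports on the installed version. A naive sketch of that gate, without the `packaging` dependency the real code uses; `parse_version` and `fairscale_has_fsdp` are illustrative names, and the parser only handles plain numeric versions like "0.2.0" or "0.3.1".

```python
# Naive version-gate sketch: compare numeric version tuples, so the
# FullyShardedDDP / auto_wrap imports only run on fairscale >= 0.3.
def parse_version(v):
    return tuple(int(part) for part in v.split(".")[:3])


def fairscale_has_fsdp(fairscale_version):
    # True when the >= 0.3 import branch would be taken.
    return parse_version(fairscale_version) >= parse_version("0.3")


print(fairscale_has_fsdp("0.3.1"))  # True
print(fairscale_has_fsdp("0.2.0"))  # False
```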
src/transformers/training_args.py
...
@@ -531,6 +531,12 @@ class TrainingArguments:
     )

     def __post_init__(self):
+        # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
+        # This needs to happen before any call to self.device or self.n_gpu.
+        env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+        if env_local_rank != -1 and env_local_rank != self.local_rank:
+            self.local_rank = env_local_rank
+
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home
         # see https://github.com/huggingface/transformers/issues/10628
...
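The `LOCAL_RANK` handling above can be isolated into a small testable function. A sketch; `resolve_local_rank` is an illustrative name, not a transformers API, and the `environ` parameter stands in for `os.environ` so the logic can be exercised without touching the real environment.

```python
import os


def resolve_local_rank(arg_local_rank, environ=None):
    # Sketch of the --use_env handling in __post_init__: a LOCAL_RANK value
    # set by torch.distributed.launch overrides the local_rank argument.
    environ = os.environ if environ is None else environ
    env_local_rank = int(environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != arg_local_rank:
        return env_local_rank
    return arg_local_rank


print(resolve_local_rank(-1, {"LOCAL_RANK": "2"}))  # 2
print(resolve_local_rank(3, {}))                    # 3
```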