chenpangpang / transformers / Commits / 68f50f34

Unverified commit 68f50f34, authored Oct 03, 2022 by Steven Liu, committed by GitHub on Oct 03, 2022

Breakup export guide (#19271)

* split onnx and torchscript docs
* make style
* apply reviews

parent 18c06208

Showing 3 changed files with 347 additions and 327 deletions
* docs/source/en/_toctree.yml (+3, -1)
* docs/source/en/serialization.mdx (+119, -326)
* docs/source/en/torchscript.mdx (+225, -0)
docs/source/en/_toctree.yml

@@ -33,7 +33,9 @@
   - local: converting_tensorflow_models
     title: Converting from TensorFlow checkpoints
   - local: serialization
-    title: Export 🤗 Transformers models
+    title: Export to ONNX
+  - local: torchscript
+    title: Export to TorchScript
   - local: troubleshooting
     title: Troubleshoot
   title: General usage
docs/source/en/serialization.mdx

@@ -10,36 +10,36 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->

-# Export 🤗 Transformers Models
+# Export to ONNX

-If you need to deploy 🤗 Transformers models in production environments, we recommend exporting them to a serialized format that can be loaded and executed on specialized runtimes and hardware. In this guide, we'll show you how to export 🤗 Transformers models in two widely used formats: ONNX and TorchScript.
+If you need to deploy 🤗 Transformers models in production environments, we recommend exporting them to a serialized format that can be loaded and executed on specialized runtimes and hardware. In this guide, we'll show you how to export 🤗 Transformers models to [ONNX (Open Neural Network eXchange)](http://onnx.ai).

-Once exported, a model can optimized for inference via techniques such as quantization and pruning. If you are interested in optimizing your models to run with maximum efficiency, check out the [🤗 Optimum library](https://github.com/huggingface/optimum).
+<Tip>
+
+Once exported, a model can be optimized for inference via techniques such as quantization and pruning. If you are interested in optimizing your models to run with maximum efficiency, check out the [🤗 Optimum library](https://github.com/huggingface/optimum).
+
+</Tip>

-## ONNX
-
-The [ONNX (Open Neural Network eXchange)](http://onnx.ai) project is an open standard that defines a common set of operators and a common file format to represent deep learning models in a wide variety of frameworks, including PyTorch and TensorFlow. When a model is exported to the ONNX format, these operators are used to construct a computational graph (often called an _intermediate representation_) which represents the flow of data through the neural network.
+ONNX is an open standard that defines a common set of operators and a common file format to represent deep learning models in a wide variety of frameworks, including PyTorch and TensorFlow. When a model is exported to the ONNX format, these operators are used to construct a computational graph (often called an _intermediate representation_) which represents the flow of data through the neural network.

By exposing a graph with standardized operators and data types, ONNX makes it easy to switch between frameworks. For example, a model trained in PyTorch can be exported to ONNX format and then imported in TensorFlow (and vice versa).

-🤗 Transformers provides a `transformers.onnx` package that enables you to convert model checkpoints to an ONNX graph by leveraging configuration objects. These configuration objects come ready made for a number of model architectures, and are designed to be easily extendable to other architectures.
+🤗 Transformers provides a [`transformers.onnx`](main_classes/onnx) package that enables you to convert model checkpoints to an ONNX graph by leveraging configuration objects. These configuration objects come ready made for a number of model architectures, and are designed to be easily extendable to other architectures.

Ready-made configurations include the following architectures:
@@ -106,10 +106,10 @@ In the next two sections, we'll show you how to:

* Export a supported model using the `transformers.onnx` package.
* Export a custom model for an unsupported architecture.

-### Exporting a model to ONNX
+## Exporting a model to ONNX

To export a 🤗 Transformers model to ONNX, you'll first need to install some extra dependencies:

```bash
pip install transformers[onnx]

@@ -141,7 +141,7 @@ Exporting a checkpoint using a ready-made configuration can be done as follows:
python -m transformers.onnx --model=distilbert-base-uncased onnx/
```

-which should show the following logs:
+You should see the following logs:

```bash
Validating ONNX model...

@@ -152,13 +152,13 @@ Validating ONNX model...
All good, model saved at: onnx/model.onnx
```

-This exports an ONNX graph of the checkpoint defined by the `--model` argument. In this example it is `distilbert-base-uncased`, but it can be any checkpoint on the Hugging Face Hub or one that's stored locally.
+This exports an ONNX graph of the checkpoint defined by the `--model` argument. In this example, it is `distilbert-base-uncased`, but it can be any checkpoint on the Hugging Face Hub or one that's stored locally.

The resulting `model.onnx` file can then be run on one of the [many accelerators](https://onnx.ai/supported-tools.html#deployModel) that support the ONNX standard. For example, we can load and run the model with [ONNX Runtime](https://onnxruntime.ai/) as follows:

```python

@@ -172,9 +172,8 @@ Runtime](https://onnxruntime.ai/) as follows:
>>> outputs = session.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
```
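For context, here is a fuller, self-contained sketch of the ONNX Runtime call shown in the hunk above. The checkpoint name, the `onnx/model.onnx` path, and the sample sentence are illustrative assumptions rather than part of the diff:

```python
>>> from transformers import AutoTokenizer
>>> from onnxruntime import InferenceSession  # requires the onnxruntime package

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint
>>> session = InferenceSession("onnx/model.onnx")  # path produced by the export command above
>>> # ONNX Runtime expects NumPy arrays as input
>>> inputs = tokenizer("Exporting 🤗 Transformers models to ONNX.", return_tensors="np")
>>> outputs = session.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
```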
The
required
output
names
(
i
.
e
.
`[
"last_hidden_state"
]`)
can
be
obtained
by
taking
a
look
at
the
ONNX
configuration
of
each
model
.
For
example
,
for
DistilBERT
we
have
:
The
required
output
names
(
like
`[
"last_hidden_state"
]`)
can
be
obtained
by
taking
a
look
at
the
ONNX
configuration
of
each
model
.
For
example
,
for
DistilBERT
we
have
:
```
python
>>>
from
transformers
.
models
.
distilbert
import
DistilBertConfig
,
DistilBertOnnxConfig
...
...
@@ -185,20 +184,19 @@ DistilBERT we have:
[
"last_hidden_state"
]
```
The
process
is
identical
for
TensorFlow
checkpoints
on
the
Hub
.
For
example
,
we
can
export
a
pure
TensorFlow
checkpoint
from
the
[
Keras
The
process
is
identical
for
TensorFlow
checkpoints
on
the
Hub
.
For
example
,
we
can
export
a
pure
TensorFlow
checkpoint
from
the
[
Keras
organization
](
https
://
huggingface
.
co
/
keras
-
io
)
as
follows
:
```
bash
python
-
m
transformers
.
onnx
--
model
=
keras
-
io
/
transformers
-
qa
onnx
/
```
To
export
a
model
that
's stored locally, you'
ll
need
to
have
the
model
's weights
and
tokenizer files stored in a directory. For example, we can load and save a
checkpoint as
follows:
To
export
a
model
that
's stored locally, you'
ll
need
to
have
the
model
's weights
and
tokenizer files stored in a directory. For example, we can load and save a
checkpoint as
follows:
<frameworkcontent>
<pt>
<frameworkcontent> <pt>
```python
>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
...
...
@@ -216,8 +214,7 @@ argument of the `transformers.onnx` package to the desired directory:
```bash
python -m transformers.onnx --model=local-pt-checkpoint onnx/
```
</pt>
<tf>
</pt> <tf>
```python
>>> from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
...
...
@@ -235,14 +232,13 @@ argument of the `transformers.onnx` package to the desired directory:
```bash
python -m transformers.onnx --model=local-tf-checkpoint onnx/
```
</tf>
</frameworkcontent>
</tf> </frameworkcontent>
##
#
Selecting features for different model t
opologie
s
## Selecting features for different model t
ask
s
Each ready-made configuration comes with a set of _features_ that enable you to
export
models for different types of
topologies or
tasks. As shown in the table
below, each feature is
associated with a different
a
uto
c
lass:
Each ready-made configuration comes with a set of _features_ that enable you to
export
models for different types of tasks. As shown in the table
below, each feature is
associated with a different
`A
uto
C
lass
`
:
| Feature | Auto Class |
| ------------------------------------ | ------------------------------------ |
...
...
@@ -255,7 +251,7 @@ below, each feature is associated with a different auto class:
| `token-classification` | `AutoModelForTokenClassification` |
For each configuration, you can find the list of supported features via the
`
FeaturesManager`. For example, for DistilBERT we have:
[`~transformers.onnx.
FeaturesManager`
]
. For example, for DistilBERT we have:
```python
>>> from transformers.onnx.features import FeaturesManager
...
...
@@ -266,15 +262,15 @@ For each configuration, you can find the list of supported features via the
```
You can then pass one of these features to the `--feature` argument in the
`transformers.onnx` package. For example, to export a text-classification model
we can
pick a fine-tuned model from the Hub and run:
`transformers.onnx` package. For example, to export a text-classification model
we can
pick a fine-tuned model from the Hub and run:
```bash
python -m transformers.onnx --model=distilbert-base-uncased-finetuned-sst-2-english \
--feature=sequence-classification onnx/
```
w
hi
ch will
display the following logs:
T
hi
s
display
s
the following logs:
```bash
Validating ONNX model...
...
...
@@ -285,37 +281,35 @@ Validating ONNX model...
All good, model saved at: onnx/model.onnx
```
Notice that in this case, the output names from the fine-tuned model are
`logits` instead of the `last_hidden_state` we saw with the
`distilbert-base-uncased` checkpoint earlier. This is expected since the
fine-tuned model has a sequence classification head.
Notice that in this case, the output names from the fine-tuned model are `logits`
instead of the `last_hidden_state` we saw with the `distilbert-base-uncased` checkpoint
earlier. This is expected since the fine-tuned model has a sequence classification head.
<Tip>
The features that have a `with-past` suffix (
e.g.
`causal-lm-with-past`)
correspond to model topologi
es with precomputed hidden states (key and values
in the attention blocks)
that can be used for fast autoregressive decoding.
The features that have a `with-past` suffix (
like
`causal-lm-with-past`)
correspond to
model class
es with precomputed hidden states (key and values
in the attention blocks)
that can be used for fast autoregressive decoding.
</Tip>
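As a quick illustration of this tip, the sketch below checks that a `with-past` feature is registered for a decoder-style model. The choice of GPT-2 and the exact helper call are assumptions for illustration, not part of this diff:

```python
>>> from transformers.onnx.features import FeaturesManager

>>> # List the features registered for GPT-2 and confirm the `with-past` variant is among them
>>> gpt2_features = list(FeaturesManager.get_supported_features_for_model_type("gpt2").keys())
>>> "causal-lm-with-past" in gpt2_features
True
```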
-### Exporting a model for an unsupported architecture
+## Exporting a model for an unsupported architecture

If you wish to export a model whose architecture is not natively supported by the library, there are three main steps to follow:

1. Implement a custom ONNX configuration.
2. Export the model to ONNX.
3. Validate the outputs of the PyTorch and exported models.

In this section, we'll look at how DistilBERT was implemented to show what's involved with each step.

-#### Implementing a custom ONNX configuration
+### Implementing a custom ONNX configuration

Let's start with the ONNX configuration object. We provide three abstract classes that you should inherit from, depending on the type of model architecture you wish to export:

* Encoder-based models inherit from [`~onnx.config.OnnxConfig`]
* Decoder-based models inherit from [`~onnx.config.OnnxConfigWithPast`]

@@ -347,25 +341,24 @@ Since DistilBERT is an encoder-based model, its configuration inherits from
...     )
```

Every configuration object must implement the `inputs` property and return a mapping, where each key corresponds to an expected input, and each value indicates the axis of that input. For DistilBERT, we can see that two inputs are required: `input_ids` and `attention_mask`. These inputs have the same shape of `(batch_size, sequence_length)`, which is why we see the same axes used in the configuration.

<Tip>

Notice that the `inputs` property for `DistilBertOnnxConfig` returns an `OrderedDict`. This ensures that the inputs are matched with their relative position within the `PreTrainedModel.forward()` method when tracing the graph. We recommend using an `OrderedDict` for the `inputs` and `outputs` properties when implementing custom ONNX configurations.

</Tip>
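To make the `inputs` contract concrete, here is a minimal sketch of an encoder-style configuration. The class name `MyEncoderOnnxConfig` is hypothetical; the two entries mirror the DistilBERT example described above:

```python
from collections import OrderedDict
from typing import Mapping

from transformers.onnx import OnnxConfig


class MyEncoderOnnxConfig(OnnxConfig):  # hypothetical name, for illustration only
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Each key is an expected model input; each value maps an axis index
        # to a symbolic name so that axis stays dynamic in the exported graph.
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )
```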
Once you have implemented an ONNX configuration, you can instantiate it by providing the base model's configuration as follows:

```python
>>> from transformers import AutoConfig

@@ -374,8 +367,8 @@ providing the base model's configuration as follows:
>>> onnx_config = DistilBertOnnxConfig(config)
```

-The resulting object has several useful properties. For example you can view the ONNX operator set that will be used during the export:
+The resulting object has several useful properties. For example, you can view the ONNX operator set that will be used during the export:

```python
>>> print(onnx_config.default_onnx_opset)

@@ -389,15 +382,14 @@ You can also view the outputs associated with the model as follows:
OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
```

-Notice that the outputs property follows the same structure as the inputs; it returns an `OrderedDict` of named outputs and their shapes. The output structure is linked to the choice of feature that the configuration is initialised with. By default, the ONNX configuration is initialized with the `default` feature that corresponds to exporting a model loaded with the `AutoModel` class. If you want to export a different model topology, just provide a different feature to the `task` argument when you initialize the ONNX configuration. For example, if we wished to export DistilBERT with a sequence classification head, we could use:
+Notice that the outputs property follows the same structure as the inputs; it returns an `OrderedDict` of named outputs and their shapes. The output structure is linked to the choice of feature that the configuration is initialised with. By default, the ONNX configuration is initialized with the `default` feature that corresponds to exporting a model loaded with the `AutoModel` class. If you want to export a model for another task, just provide a different feature to the `task` argument when you initialize the ONNX configuration. For example, if we wished to export DistilBERT with a sequence classification head, we could use:

```python
>>> from transformers import AutoConfig

@@ -410,18 +402,18 @@ OrderedDict([('logits', {0: 'batch'})])

<Tip>

All of the base properties and methods associated with [`~onnx.config.OnnxConfig`] and the other configuration classes can be overridden if needed. Check out [`BartOnnxConfig`] for an advanced example.

</Tip>
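The hunk above shows the `logits` output of a sequence-classification configuration. As a hedged sketch of how that output is obtained, one could pass the feature name to the `task` argument when building the configuration (the checkpoint name is illustrative):

```python
>>> from transformers import AutoConfig
>>> from transformers.models.distilbert import DistilBertOnnxConfig

>>> config = AutoConfig.from_pretrained("distilbert-base-uncased")
>>> onnx_config_for_seq_clf = DistilBertOnnxConfig(config, task="sequence-classification")
>>> print(onnx_config_for_seq_clf.outputs)
OrderedDict([('logits', {0: 'batch'})])
```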
-#### Exporting the model
+### Exporting the model

Once you have implemented the ONNX configuration, the next step is to export the model. Here we can use the `export()` function provided by the `transformers.onnx` package. This function expects the ONNX configuration, along with the base model and tokenizer, and the path to save the exported file:

```python
>>> from pathlib import Path

@@ -436,10 +428,9 @@ with the base model and tokenizer, and the path to save the exported file:
>>> onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
```

The `onnx_inputs` and `onnx_outputs` returned by the `export()` function are lists of the keys defined in the `inputs` and `outputs` properties of the configuration. Once the model is exported, you can test that the model is well formed as follows:

```python
>>> import onnx

@@ -450,21 +441,20 @@ formed as follows:

<Tip>

If your model is larger than 2GB, you will see that many additional files are created during the export. This is _expected_ because ONNX uses [Protocol Buffers](https://developers.google.com/protocol-buffers/) to store the model and these have a size limit of 2GB. See the [ONNX documentation](https://github.com/onnx/onnx/blob/master/docs/ExternalData.md) for instructions on how to load models with external data.

</Tip>

-#### Validating the model outputs
+### Validating the model outputs

The final step is to validate that the outputs from the base and exported model agree within some absolute tolerance. Here we can use the `validate_model_outputs()` function provided by the `transformers.onnx` package as follows:

```python
>>> from transformers.onnx import validate_model_outputs

@@ -474,220 +464,23 @@ as follows:
... )
```

-This function uses the `OnnxConfig.generate_dummy_inputs()` method to generate inputs for the base and exported model, and the absolute tolerance can be defined in the configuration. We generally find numerical agreement in the 1e-6 to 1e-4 range, although anything smaller than 1e-3 is likely to be OK.
+This function uses the [`~transformers.onnx.OnnxConfig.generate_dummy_inputs`] method to generate inputs for the base and exported model, and the absolute tolerance can be defined in the configuration. We generally find numerical agreement in the 1e-6 to 1e-4 range, although anything smaller than 1e-3 is likely to be OK.
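As a toy illustration of what "agreement within an absolute tolerance" means here (not part of the diff), two output arrays pass the check when every elementwise difference stays below `atol`:

```python
import numpy as np

# Stand-ins for the base model's outputs and the exported model's outputs
reference = np.array([0.123456, -0.654321, 1.000000])
exported = reference + np.array([5e-6, -2e-6, 8e-6])  # tiny numerical drift

# Raises an AssertionError if any element differs by more than atol
np.testing.assert_allclose(reference, exported, rtol=0, atol=1e-4)
print("Outputs agree within atol=1e-4")
```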
-### Contributing a new configuration to 🤗 Transformers
+## Contributing a new configuration to 🤗 Transformers

We are looking to expand the set of ready-made configurations and welcome contributions from the community! If you would like to contribute your addition to the library, you will need to:

* Implement the ONNX configuration in the corresponding `configuration_<model_name>.py` file
* Include the model architecture and corresponding features in [`~onnx.features.FeatureManager`]
* Add your model architecture to the tests in `test_onnx_v2.py`

Check out how the configuration for [IBERT was contributed](https://github.com/huggingface/transformers/pull/14868/files) to get an idea of what's involved.
-## TorchScript

<Tip>

This is the very beginning of our experiments with TorchScript and we are still exploring its capabilities with variable-input-size models. It is a focus of interest to us and we will deepen our analysis in upcoming releases, with more code examples, a more flexible implementation, and benchmarks comparing python-based codes with compiled TorchScript.

</Tip>

According to Pytorch's documentation: "TorchScript is a way to create serializable and optimizable models from PyTorch code". Pytorch's two modules [JIT and TRACE](https://pytorch.org/docs/stable/jit.html) allow the developer to export their model to be re-used in other programs, such as efficiency-oriented C++ programs.

We have provided an interface that allows the export of 🤗 Transformers models to TorchScript so that they can be reused in a different environment than a Pytorch-based python program. Here we explain how to export and use our models using TorchScript.

Exporting a model requires two things:

- a forward pass with dummy inputs.
- model instantiation with the `torchscript` flag.

These necessities imply several things developers should be careful about. These are detailed below.

### TorchScript flag and tied weights

This flag is necessary because most of the language models in this repository have tied weights between their `Embedding` layer and their `Decoding` layer. TorchScript does not allow the export of models that have tied weights, therefore it is necessary to untie and clone the weights beforehand.

This implies that models instantiated with the `torchscript` flag have their `Embedding` layer and `Decoding` layer separate, which means that they should not be trained down the line. Training would de-synchronize the two layers, leading to unexpected results.

This is not the case for models that do not have a Language Model head, as those do not have tied weights. These models can be safely exported without the `torchscript` flag.

### Dummy inputs and standard lengths

The dummy inputs are used to do a model forward pass. While the inputs' values are propagating through the layers, Pytorch keeps track of the different operations executed on each tensor. These recorded operations are then used to create the "trace" of the model.

The trace is created relatively to the inputs' dimensions. It is therefore constrained by the dimensions of the dummy input, and will not work for any other sequence length or batch size. When trying with a different size, an error such as: `The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2` will be raised.

It is therefore recommended to trace the model with a dummy input size at least as large as the largest input that will be fed to the model during inference. Padding can be performed to fill the missing values. As the model will have been traced with a large input size however, the dimensions of the different matrix will be large as well, resulting in more calculations.

It is recommended to be careful of the total number of operations done on each input and to follow performance closely when exporting varying sequence-length models.

### Using TorchScript in Python

Below is an example, showing how to save, load models as well as how to use the trace for inference.

#### Saving a model

This snippet shows how to use TorchScript to export a `BertModel`. Here the `BertModel` is instantiated according to a `BertConfig` class and then saved to disk under the filename `traced_bert.pt`

```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(
    vocab_size_or_config_json_file=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    torchscript=True,
)

# Instantiating the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with *from_pretrained* you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
```

#### Loading a model

This snippet shows how to load the `BertModel` that was previously saved to disk under the name `traced_bert.pt`. We are re-using the previously initialised `dummy_input`.

```python
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

all_encoder_layers, pooled_output = loaded_model(*dummy_input)
```

#### Using a traced model for inference

Using the traced model for inference is as simple as using its `__call__` dunder method:

```python
traced_model(tokens_tensor, segments_tensors)
```

### Deploying HuggingFace TorchScript models on AWS using the Neuron SDK

AWS introduced the [Amazon EC2 Inf1](https://aws.amazon.com/ec2/instance-types/inf1/) instance family for low cost, high performance machine learning inference in the cloud. The Inf1 instances are powered by the AWS Inferentia chip, a custom-built hardware accelerator, specializing in deep learning inferencing workloads. [AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/#) is the SDK for Inferentia that supports tracing and optimizing transformers models for deployment on Inf1. The Neuron SDK provides:

1. Easy-to-use API with one line of code change to trace and optimize a TorchScript model for inference in the cloud.
2. Out of the box performance optimizations for [improved cost-performance](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/benchmark/>)
3. Support for HuggingFace transformers models built with either [PyTorch](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/bert_tutorial/tutorial_pretrained_bert.html) or [TensorFlow](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/tensorflow/huggingface_bert/huggingface_bert.html).

#### Implications

Transformers Models based on the [BERT (Bidirectional Encoder Representations from Transformers)](https://huggingface.co/docs/transformers/main/model_doc/bert) architecture, or its variants such as [distilBERT](https://huggingface.co/docs/transformers/main/model_doc/distilbert) and [roBERTa](https://huggingface.co/docs/transformers/main/model_doc/roberta) will run best on Inf1 for non-generative tasks such as Extractive Question Answering, Sequence Classification, Token Classification. Alternatively, text generation tasks can be adapted to run on Inf1, according to this [AWS Neuron MarianMT tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/transformers-marianmt.html). More information about models that can be converted out of the box on Inferentia can be found in the [Model Architecture Fit section of the Neuron documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/models/models-inferentia.html#models-inferentia).

#### Dependencies

Using AWS Neuron to convert models requires the following dependencies and environment:

* A [Neuron SDK environment](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-frameworks/pytorch-neuron/index.html#installation-guide), which comes pre-configured on [AWS Deep Learning AMI](https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-inferentia-launching.html).

#### Converting a Model for AWS Neuron

Using the same script as in [Using TorchScript in Python](https://huggingface.co/docs/transformers/main/en/serialization#using-torchscript-in-python) to trace a "BertModel", you import `torch.neuron` framework extension to access the components of the Neuron SDK through a Python API.

```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron
```

And only modify the tracing line of code from:

```python
torch.jit.trace(model, [tokens_tensor, segments_tensors])
```

to:

```python
torch.neuron.trace(model, [token_tensor, segments_tensors])
```

This change enables Neuron SDK to trace the model and optimize it to run in Inf1 instances.

To learn more about AWS Neuron SDK features, tools, example tutorials and latest updates, please see the [AWS NeuronSDK documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html).
docs/source/en/torchscript.mdx
0 → 100644
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
-->

# Export to TorchScript

<Tip>

This is the very beginning of our experiments with TorchScript and we are still exploring its capabilities with variable-input-size models. It is a focus of interest to us and we will deepen our analysis in upcoming releases, with more code examples, a more flexible implementation, and benchmarks comparing Python-based codes with compiled TorchScript.

</Tip>

According to the [TorchScript documentation](https://pytorch.org/docs/stable/jit.html):

> TorchScript is a way to create serializable and optimizable models from PyTorch code.

There are two PyTorch modules, [JIT and TRACE](https://pytorch.org/docs/stable/jit.html), that allow developers to export their models to be reused in other programs like efficiency-oriented C++ programs.

We provide an interface that allows you to export 🤗 Transformers models to TorchScript so they can be reused in a different environment than PyTorch-based Python programs. Here, we explain how to export and use our models using TorchScript.

Exporting a model requires two things:

- model instantiation with the `torchscript` flag
- a forward pass with dummy inputs

These necessities imply several things developers should be careful about, as detailed below.

## TorchScript flag and tied weights

The `torchscript` flag is necessary because most of the 🤗 Transformers language models have tied weights between their `Embedding` layer and their `Decoding` layer. TorchScript does not allow you to export models that have tied weights, so it is necessary to untie and clone the weights beforehand.

Models instantiated with the `torchscript` flag have their `Embedding` layer and `Decoding` layer separated, which means that they should not be trained down the line. Training would desynchronize the two layers, leading to unexpected results.

This is not the case for models that do not have a language model head, as those do not have tied weights. These models can be safely exported without the `torchscript` flag.

## Dummy inputs and standard lengths

The dummy inputs are used for a model's forward pass. While the inputs' values are propagated through the layers, PyTorch keeps track of the different operations executed on each tensor. These recorded operations are then used to create the *trace* of the model.

The trace is created relative to the inputs' dimensions. It is therefore constrained by the dimensions of the dummy input, and will not work for any other sequence length or batch size. When trying with a different size, the following error is raised:

```
`The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2`
```

We recommend you trace the model with a dummy input size at least as large as the largest input that will be fed to the model during inference. Padding can help fill the missing values. However, since the model is traced with a larger input size, the dimensions of the matrix will also be large, resulting in more calculations.

Be careful of the total number of operations done on each input and follow the performance closely when exporting varying sequence-length models.
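One way to follow this advice, sketched here as an illustration rather than taken from the guide, is to let the tokenizer pad every example to a fixed maximum length so the traced model always sees the same tensor shape:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Pad (and truncate) every input to the largest length expected at inference time,
# so the dummy input used for tracing matches the shapes seen in production.
encoded = tokenizer(
    "Who was Jim Henson? Jim Henson was a puppeteer.",
    padding="max_length",
    truncation=True,
    max_length=128,
    return_tensors="pt",
)
print(encoded["input_ids"].shape)  # torch.Size([1, 128])
```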
## Using TorchScript in Python

This section demonstrates how to save and load models as well as how to use the trace for inference.

### Saving a model

To export a `BertModel` with TorchScript, instantiate `BertModel` from the `BertConfig` class and then save it to disk under the filename `traced_bert.pt`:

```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(
    vocab_size_or_config_json_file=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    torchscript=True,
)

# Instantiating the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with *from_pretrained* you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
```

### Loading a model

Now you can load the previously saved `BertModel`, `traced_bert.pt`, from disk and use it on the previously initialised `dummy_input`:

```python
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

all_encoder_layers, pooled_output = loaded_model(*dummy_input)
```

### Using a traced model for inference

Use the traced model for inference by using its `__call__` dunder method:

```python
traced_model(tokens_tensor, segments_tensors)
```
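Not part of the original snippet, but a quick sanity check one could add here: confirm that the traced module reproduces the eager model's output on the same dummy inputs (reusing `model`, `traced_model`, `tokens_tensor`, and `segments_tensors` from the saving example above):

```python
import torch

with torch.no_grad():
    eager_output = model(tokens_tensor, segments_tensors)[0]
    traced_output = traced_model(tokens_tensor, segments_tensors)[0]

# The traced graph should reproduce the eager forward pass numerically
print(torch.allclose(eager_output, traced_output, atol=1e-5))
```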
## Deploy Hugging Face TorchScript models to AWS with the Neuron SDK

AWS introduced the [Amazon EC2 Inf1](https://aws.amazon.com/ec2/instance-types/inf1/) instance family for low cost, high performance machine learning inference in the cloud. The Inf1 instances are powered by the AWS Inferentia chip, a custom-built hardware accelerator, specializing in deep learning inferencing workloads. [AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/#) is the SDK for Inferentia that supports tracing and optimizing transformers models for deployment on Inf1. The Neuron SDK provides:

1. Easy-to-use API with one line of code change to trace and optimize a TorchScript model for inference in the cloud.
2. Out of the box performance optimizations for [improved cost-performance](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/benchmark/>).
3. Support for Hugging Face transformers models built with either [PyTorch](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/bert_tutorial/tutorial_pretrained_bert.html) or [TensorFlow](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/tensorflow/huggingface_bert/huggingface_bert.html).

### Implications

Transformers models based on the [BERT (Bidirectional Encoder Representations from Transformers)](https://huggingface.co/docs/transformers/main/model_doc/bert) architecture, or its variants such as [distilBERT](https://huggingface.co/docs/transformers/main/model_doc/distilbert) and [roBERTa](https://huggingface.co/docs/transformers/main/model_doc/roberta), run best on Inf1 for non-generative tasks such as extractive question answering, sequence classification, and token classification. However, text generation tasks can still be adapted to run on Inf1 according to this [AWS Neuron MarianMT tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/transformers-marianmt.html). More information about models that can be converted out of the box on Inferentia can be found in the [Model Architecture Fit](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/models/models-inferentia.html#models-inferentia) section of the Neuron documentation.

### Dependencies

Using AWS Neuron to convert models requires a [Neuron SDK environment](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-frameworks/pytorch-neuron/index.html#installation-guide) which comes preconfigured on [AWS Deep Learning AMI](https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-inferentia-launching.html).

### Converting a model for AWS Neuron

Convert a model for AWS NEURON using the same code from [Using TorchScript in Python](serialization#using-torchscript-in-python) to trace a `BertModel`. Import the `torch.neuron` framework extension to access the components of the Neuron SDK through a Python API:

```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron
```

You only need to modify the following line:

```diff
- torch.jit.trace(model, [tokens_tensor, segments_tensors])
+ torch.neuron.trace(model, [tokens_tensor, segments_tensors])
```

This enables the Neuron SDK to trace the model and optimize it for Inf1 instances.

To learn more about AWS Neuron SDK features, tools, example tutorials and latest updates, please see the [AWS NeuronSDK documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html).