chenpangpang / transformers / Commits

Commit 84c0aa18
num_parameters helper
Authored Jan 10, 2020 by Julien Chaumond
Parent: 331065e6
Showing 4 changed files with 35 additions and 2 deletions:

src/transformers/modeling_tf_utils.py  (+17, -1)
src/transformers/modeling_utils.py     (+14, -1)
tests/test_modeling_auto.py            (+2, -0)
tests/test_modeling_tf_auto.py         (+2, -0)
src/transformers/modeling_tf_utils.py

```diff
@@ -20,6 +20,7 @@ import logging
 import os

 import h5py
+import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras.saving import hdf5_format
@@ -31,7 +32,22 @@ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)


-class TFPreTrainedModel(tf.keras.Model):
+class TFModelUtils:
+    """
+    A few utilities for `tf.keras.Model`s, to be used as a mixin.
+    """
+
+    def num_parameters(self, only_trainable: bool = False) -> int:
+        """
+        Get number of (optionally, trainable) parameters in the model.
+        """
+        if only_trainable:
+            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
+        else:
+            return self.count_params()
+
+
+class TFPreTrainedModel(tf.keras.Model, TFModelUtils):
     r""" Base class for all TF models.

     :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
     ...
```
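On the TF side the helper defers to Keras for the total (`count_params()`) and multiplies out the shapes of `trainable_variables` for the trainable count. A minimal sketch of the behavior, with the mixin reproduced from the diff above; `TinyTFModel` and its layer sizes are illustrative, not from the commit. Note that `count_params()` only works once the model has been built, hence the dummy forward pass:

```python
import numpy as np
import tensorflow as tf


class TFModelUtils:
    """The mixin from the diff above, reproduced so the demo is standalone."""

    def num_parameters(self, only_trainable: bool = False) -> int:
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        return self.count_params()


class TinyTFModel(tf.keras.Model, TFModelUtils):
    """Hypothetical model for illustration only."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)  # 3x4 kernel + 4 biases = 16 params once built

    def call(self, inputs):
        return self.dense(inputs)


model = TinyTFModel()
model(tf.zeros((1, 3)))                            # build the variables so count_params() works
print(model.num_parameters())                      # 16
model.dense.trainable = False                      # freeze the only layer
print(model.num_parameters(only_trainable=True))   # 0
```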
src/transformers/modeling_utils.py

```diff
@@ -53,7 +53,20 @@ except ImportError:
             return input


-class PreTrainedModel(nn.Module):
+class ModuleUtils:
+    """
+    A few utilities for torch.nn.Modules, to be used as a mixin.
+    """
+
+    def num_parameters(self, only_trainable: bool = False) -> int:
+        """
+        Get number of (optionally, trainable) parameters in the module.
+        """
+        params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
+        return sum(p.numel() for p in params)
+
+
+class PreTrainedModel(nn.Module, ModuleUtils):
     r""" Base class for all models.

     :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
     ...
```
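The PyTorch version sums `numel()` over `parameters()`, optionally filtering on `requires_grad`, so freezing a parameter changes only the trainable count. A minimal sketch under the same caveat as above (the mixin is copied from the diff; `TinyModule` and its sizes are illustrative):

```python
import torch.nn as nn


class ModuleUtils:
    """The mixin from the diff above, reproduced so the demo is standalone."""

    def num_parameters(self, only_trainable: bool = False) -> int:
        params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
        return sum(p.numel() for p in params)


class TinyModule(nn.Module, ModuleUtils):
    """Hypothetical module for illustration only."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 4)  # 4x3 weight + 4 biases = 16 params


module = TinyModule()
print(module.num_parameters())                      # 16
module.linear.weight.requires_grad = False          # freeze just the weight matrix
print(module.num_parameters(only_trainable=True))   # 4 (the bias is still trainable)
```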
tests/test_modeling_auto.py

```diff
@@ -100,3 +100,5 @@ class AutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(model, BertForMaskedLM)
+        self.assertEqual(model.num_parameters(), 14830)
+        self.assertEqual(model.num_parameters(only_trainable=True), 14830)
```
tests/test_modeling_tf_auto.py

```diff
@@ -99,3 +99,5 @@ class TFAutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(model, TFBertForMaskedLM)
+        self.assertEqual(model.num_parameters(), 14830)
+        self.assertEqual(model.num_parameters(only_trainable=True), 14830)
```
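Both assertions in each test expect the same count because a freshly loaded checkpoint has no frozen parameters, so the trainable count equals the total. A minimal offline variant of the same check, assuming only that transformers is installed; the BertConfig values here are illustrative, not those behind SMALL_MODEL_IDENTIFIER:

```python
import unittest

from transformers import BertConfig, BertForMaskedLM


class NumParametersTest(unittest.TestCase):
    def test_total_equals_trainable_count(self):
        # Tiny, randomly initialized BERT so the test runs without downloading anything.
        config = BertConfig(
            vocab_size=99,
            hidden_size=32,
            num_hidden_layers=2,
            num_attention_heads=4,
            intermediate_size=37,
        )
        model = BertForMaskedLM(config)
        # A freshly built module has requires_grad=True on every parameter,
        # so the two counts coincide, just as they do for the hub checkpoint.
        self.assertEqual(model.num_parameters(), model.num_parameters(only_trainable=True))


if __name__ == "__main__":
    unittest.main()
```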