chenpangpang/transformers, commit 60d51ef5 (unverified)
Authored Mar 17, 2023 by Stas Bekman; committed by GitHub on Mar 17, 2023
[trainer] param count for deepspeed zero3 (#22193)
[trainer] param count for zero3
Parent: cf601b90
Showing 2 changed files, with 20 additions and 3 deletions:

    src/transformers/trainer.py           +2   -3
    src/transformers/trainer_pt_utils.py  +18  -0
src/transformers/trainer.py
@@ -97,6 +97,7 @@ from .trainer_pt_utils import (
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
+    get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
     nested_concat,
@@ -1744,9 +1745,7 @@ class Trainer:
         logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
         logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
         logger.info(f"  Total optimization steps = {max_steps}")
-        logger.info(
-            f"  Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
-        )
+        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}")

         self.state.epoch = 0
         start_time = time.time()
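Context for the trainer.py change: the replaced log line counted parameters with a plain sum of p.numel(), but under DeepSpeed ZeRO-3 each rank holds only a placeholder for its partitioned parameters, so that sum under-reports the model size; DeepSpeed keeps the real element count on the ds_numel attribute, which the new get_model_param_count() helper reads instead. Below is a minimal sketch of the plain counting path, with the toy nn.Linear model being an illustrative assumption and the ZeRO-3 behaviour described only in the comments:

    from torch import nn

    model = nn.Linear(1024, 1024, bias=False)  # 1,048,576 parameters

    # Plain counting, as the old log line did; correct when ZeRO-3 is not in use.
    naive = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(naive)  # 1048576

    # Under DeepSpeed ZeRO-3, parameters are partitioned across ranks and p.numel()
    # on the local placeholder no longer reflects the full model, while DeepSpeed
    # records the true count on p.ds_numel; get_model_param_count() switches to that
    # attribute when is_deepspeed_zero3_enabled() is true.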
src/transformers/trainer_pt_utils.py
@@ -35,6 +35,7 @@ from torch import nn
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler

+from .deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
 from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
@@ -1032,6 +1033,23 @@ def save_state(self):
     self.state.save_to_json(path)


+def get_model_param_count(model, trainable_only=False):
+    """
+    Calculate model's total param count. If trainable_only is True then count only those requiring grads
+    """
+    if is_deepspeed_zero3_enabled():
+
+        def numel(p):
+            return p.ds_numel
+
+    else:
+
+        def numel(p):
+            return p.numel()
+
+    return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+
+
 def get_parameter_names(model, forbidden_layer_types):
     """
     Returns the names of the model parameters that are not inside a forbidden layer.
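For reference, the new helper can also be called directly from trainer_pt_utils once this commit is in place. A small usage sketch with a made-up model (the frozen layer exists only to show the trainable_only flag; outside a ZeRO-3 run the helper simply sums p.numel()):

    from torch import nn
    from transformers.trainer_pt_utils import get_model_param_count

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    model[0].weight.requires_grad_(False)  # freeze one weight matrix (512 elements)

    print(get_model_param_count(model))                       # 676: every parameter
    print(get_model_param_count(model, trainable_only=True))  # 164: excludes the frozen weight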