Unverified Commit 60d51ef5 authored by Stas Bekman, committed by GitHub

[trainer] param count for deepspeed zero3 (#22193)

[trainer] param count for zero3
parent cf601b90
@@ -97,6 +97,7 @@ from .trainer_pt_utils import (
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
+    get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
     nested_concat,
@@ -1744,9 +1745,7 @@ class Trainer:
         logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
         logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
         logger.info(f" Total optimization steps = {max_steps}")
-        logger.info(
-            f" Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
-        )
+        logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}")

         self.state.epoch = 0
         start_time = time.time()
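Why the inline sum was replaced: under DeepSpeed ZeRO-3 every parameter is partitioned across ranks, so `p.numel()` on a partitioned parameter only reflects the local placeholder (typically 0) and the startup banner would report a misleadingly small count. DeepSpeed records the full, unpartitioned size on the `ds_numel` attribute it attaches to each parameter. A minimal sketch of the discrepancy; the hand-built parameter below is purely illustrative and only mimics what a ZeRO-3 run does:

```python
import torch
from torch import nn

# Illustrative stand-in for a ZeRO-3 partitioned parameter: the tensor payload
# is freed on this rank, but DeepSpeed keeps the full size on `ds_numel`.
p = nn.Parameter(torch.empty(0))
p.ds_numel = 4096 * 4096

print(p.numel())                          # 0 -> what the old inline sum saw
print(getattr(p, "ds_numel", p.numel()))  # 16777216 -> the true size
```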
@@ -35,6 +35,7 @@ from torch import nn
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler

+from .deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
 from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
@@ -1032,6 +1033,23 @@ def save_state(self):
     self.state.save_to_json(path)


+def get_model_param_count(model, trainable_only=False):
+    """
+    Calculate model's total param count. If trainable_only is True then count only those requiring grads
+    """
+    if is_deepspeed_zero3_enabled():
+
+        def numel(p):
+            return p.ds_numel
+
+    else:
+
+        def numel(p):
+            return p.numel()
+
+    return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+
+
 def get_parameter_names(model, forbidden_layer_types):
     """
     Returns the names of the model parameters that are not inside a forbidden layer.
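Since the new helper lives in `trainer_pt_utils`, it can also be reused outside the Trainer's startup banner once this change is available. A hedged usage sketch, assuming a `transformers` install that includes this commit; the checkpoint name is illustrative:

```python
from transformers import AutoModelForCausalLM
from transformers.trainer_pt_utils import get_model_param_count

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint

total = get_model_param_count(model)
trainable = get_model_param_count(model, trainable_only=True)
print(f"total={total:,} trainable={trainable:,}")
```

Outside a DeepSpeed ZeRO-3 context, `is_deepspeed_zero3_enabled()` returns False and the helper falls back to plain `p.numel()`, so it behaves exactly like the old inline sum.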