Unverified Commit 60d51ef5 authored by Stas Bekman, committed by GitHub

[trainer] param count for deepspeed zero3 (#22193)

[trainer] param count for zero3
parent cf601b90
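
Under DeepSpeed ZeRO-3 each parameter is partitioned across ranks, and the tensor PyTorch sees locally is a placeholder, so `p.numel()` no longer reflects the real parameter count; DeepSpeed records the full size on the parameter as `ds_numel` instead. A minimal sketch of the counting rule this patch adopts (standalone illustration, not part of the commit; it uses a `hasattr` check in place of the commit's `is_deepspeed_zero3_enabled()` gate):

```python
import torch.nn as nn

def param_count(model: nn.Module, trainable_only: bool = False) -> int:
    """Count parameters, trusting ds_numel when DeepSpeed has partitioned them."""

    def numel(p):
        # Under ZeRO-3 the local tensor is a shard placeholder, so p.numel()
        # under-reports; ds_numel stores the full, unpartitioned size.
        return p.ds_numel if hasattr(p, "ds_numel") else p.numel()

    return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)

# Plain PyTorch module: falls through to p.numel().
print(param_count(nn.Linear(10, 10)))  # 110 (100 weights + 10 biases)
```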
src/transformers/trainer.py
@@ -97,6 +97,7 @@ from .trainer_pt_utils import (
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
+    get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
     nested_concat,
@@ -1744,9 +1745,7 @@ class Trainer:
         logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
         logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
         logger.info(f"  Total optimization steps = {max_steps}")
-        logger.info(
-            f"  Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
-        )
+        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}")
         self.state.epoch = 0
         start_time = time.time()
src/transformers/trainer_pt_utils.py
@@ -35,6 +35,7 @@ from torch import nn
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler

+from .deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
 from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
@@ -1032,6 +1033,23 @@ def save_state(self):
     self.state.save_to_json(path)


+def get_model_param_count(model, trainable_only=False):
+    """
+    Calculate model's total param count. If trainable_only is True then count only those requiring grads
+    """
+    if is_deepspeed_zero3_enabled():
+
+        def numel(p):
+            return p.ds_numel
+
+    else:
+
+        def numel(p):
+            return p.numel()
+
+    return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+
+
 def get_parameter_names(model, forbidden_layer_types):
     """
     Returns the names of the model parameters that are not inside a forbidden layer.
......
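
A minimal usage sketch, assuming a transformers build that includes this patch (`get_model_param_count` is defined in `transformers.trainer_pt_utils`; the model name is just an example):

```python
from transformers import AutoModel
from transformers.trainer_pt_utils import get_model_param_count

model = AutoModel.from_pretrained("bert-base-uncased")

# Outside a DeepSpeed ZeRO-3 run this falls back to p.numel(); under ZeRO-3
# it reads p.ds_numel, so the count reflects the full model, not the local shard.
print(f"total     = {get_model_param_count(model)}")
print(f"trainable = {get_model_param_count(model, trainable_only=True)}")
```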