Unverified Commit 1cd1181d authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

check `model_parallel` is initialized in `build_model` (#1248)

parent 7ec8ed67
...@@ -23,7 +23,7 @@ from . import pyprof ...@@ -23,7 +23,7 @@ from . import pyprof
from . import transformer from . import transformer
# Logging utilities mainly for apex.transformer module # Logging utilities for apex.transformer module
class RankInfoFormatter(logging.Formatter): class RankInfoFormatter(logging.Formatter):
def format(self, record): def format(self, record):
......
...@@ -82,11 +82,11 @@ def build_model( ...@@ -82,11 +82,11 @@ def build_model(
set_defaults_if_not_set_tensor_model_parallel_attributes(param) set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters. # Print number of parameters.
if parallel_state.get_data_parallel_rank() == 0: if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_pipeline_model_parallel_rank(), parallel_state.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]) _calc_number_of_params(model),
) )
print(msg, flush=True) print(msg, flush=True)
...@@ -108,6 +108,11 @@ def build_model( ...@@ -108,6 +108,11 @@ def build_model(
return model return model
def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
assert isinstance(model, list)
return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
def _get_params_for_weight_decay_optimization( def _get_params_for_weight_decay_optimization(
model: Union[torch.nn.Module, List[torch.nn.Module]], model: Union[torch.nn.Module, List[torch.nn.Module]],
) -> Dict[str, torch.nn.Parameter]: ) -> Dict[str, torch.nn.Parameter]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment