Unverified commit 1cd1181d authored by Masaki Kozuki, committed by GitHub

check `model_parallel` is initialized in `build_model` (#1248)

parent 7ec8ed67
@@ -23,7 +23,7 @@ from . import pyprof
 from . import transformer
-# Logging utilities mainly for apex.transformer module
+# Logging utilities for apex.transformer module
 class RankInfoFormatter(logging.Formatter):
     def format(self, record):
......
@@ -82,11 +82,11 @@ def build_model(
             set_defaults_if_not_set_tensor_model_parallel_attributes(param)
     # Print number of parameters.
-    if parallel_state.get_data_parallel_rank() == 0:
+    if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
         msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
             parallel_state.get_tensor_model_parallel_rank(),
             parallel_state.get_pipeline_model_parallel_rank(),
-            sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
+            _calc_number_of_params(model),
         )
         print(msg, flush=True)
@@ -108,6 +108,11 @@ def build_model(
     return model

+def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
+    assert isinstance(model, list)
+    return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
+
 def _get_params_for_weight_decay_optimization(
     model: Union[torch.nn.Module, List[torch.nn.Module]],
 ) -> Dict[str, torch.nn.Parameter]:
......
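
For context, a minimal sketch of the guarded parameter report as it reads after this commit follows. It assumes `parallel_state` is imported from `apex.transformer` (as in the file being patched); `report_num_parameters` is a hypothetical wrapper name used only to make the snippet self-contained and is not part of the commit.

from typing import List

import torch

from apex.transformer import parallel_state


def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
    # New helper added by this commit: total parameter count across the
    # (possibly multi-chunk) list of model modules.
    assert isinstance(model, list)
    return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])


def report_num_parameters(model: List[torch.nn.Module]) -> None:
    # Hypothetical wrapper around the print inside build_model. With the new
    # model_parallel_is_initialized() check, no parallel-state query is made
    # (and nothing is printed) when model parallel state was never set up.
    if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
        print(
            " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
                parallel_state.get_tensor_model_parallel_rank(),
                parallel_state.get_pipeline_model_parallel_rank(),
                _calc_number_of_params(model),
            ),
            flush=True,
        )

This mirrors the diff above; the only behavioral change is that the per-rank report is skipped when model parallel has not been initialized.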