"model/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "3b96a93672377129f2a2aafc447e79ef1ca48c5f"
Commit 6e9d5cb0 authored by mohammad, committed by Deepak Narayanan

Move args writer to the beginning of training

parent 8a6e56b8
@@ -79,8 +79,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Autoresume.
     _init_autoresume()
 
-    # Write arguments to tensorboard.
-    _write_args_to_tensorboard()
 
     # No continuation function
     return None
@@ -154,13 +152,14 @@ def _set_random_seed(seed_):
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
 
 
-def _write_args_to_tensorboard():
+def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
     args = get_args()
     writer = get_tensorboard_writer()
     if writer:
         for arg in vars(args):
-            writer.add_text(arg, str(getattr(args, arg)))
+            writer.add_text(arg, str(getattr(args, arg)),
+                            global_step=args.iteration)
 
 
 def _initialize_mem_buffs():
@@ -41,6 +41,7 @@ from megatron.checkpointing import save_checkpoint
 from megatron.fp16 import FP16_Module
 from megatron.fp16 import FP16_Optimizer
 from megatron.initialize import initialize_megatron
+from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
@@ -811,6 +812,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     args = get_args()
     timers = get_timers()
 
+    # Write args to tensorboard
+    write_args_to_tensorboard()
+
     # Turn on training mode which enables dropout.
     model.train()
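
For context, a minimal, self-contained sketch of the pattern this commit moves into train(): logging every parsed argument as TensorBoard text, tagged with the current training iteration via add_text's global_step parameter. The SimpleNamespace args, the example argument names, and the log directory below are illustrative stand-ins, not Megatron's actual get_args()/get_tensorboard_writer() plumbing from the megatron.initialize module.

from types import SimpleNamespace

from torch.utils.tensorboard import SummaryWriter


def write_args_to_tensorboard(args, writer):
    """Log each argument name/value pair as TensorBoard text at the current iteration."""
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


if __name__ == '__main__':
    # Illustrative stand-in for Megatron's parsed args; 'iteration' plays the
    # role of the training step counter used as global_step above.
    args = SimpleNamespace(lr=1.5e-4, micro_batch_size=8, iteration=0)
    writer = SummaryWriter(log_dir='runs/args_demo')
    write_args_to_tensorboard(args, writer)  # called once at the start of train()
    writer.close()

Presumably, calling the writer at the beginning of train() rather than during initialization means args.iteration has already been restored from any loaded checkpoint, so the logged arguments land under the correct step instead of an uninitialized one.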