Commit adbba962 authored by Vijay Korthikanti

Merge branch 'main' into vision_transformer

parents a75f1783 be473a5b
@@ -28,11 +28,11 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 8 \
+       --micro-batch-size 8 \
        --checkpoint-activations \
        --lr 5.0e-5 \
        --lr-decay-style linear \
-       --warmup 0.065 \
+       --lr-warmup-fraction 0.065 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --save-interval 500000 \
...
@@ -28,11 +28,11 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 4 \
+       --micro-batch-size 4 \
        --checkpoint-activations \
        --lr 1.0e-5 \
        --lr-decay-style linear \
-       --warmup 0.06 \
+       --lr-warmup-fraction 0.06 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --save-interval 100000 \
...
@@ -9,24 +9,24 @@ python pretrain_bert.py \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 4 \
+       --micro-batch-size 4 \
+       --global-batch-size 8 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --train-iters 2000000 \
+       --lr-decay-iters 990000 \
        --save $CHECKPOINT_PATH \
        --load $CHECKPOINT_PATH \
        --data-path $DATA_PATH \
        --vocab-file bert-vocab.txt \
        --data-impl mmap \
        --split 949,50,1 \
-       --distributed-backend nccl \
        --lr 0.0001 \
        --min-lr 0.00001 \
        --lr-decay-style linear \
-       --lr-decay-iters 990000 \
+       --lr-warmup-fraction .01 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \
...
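Note (not part of the diff): the old single --batch-size argument is split throughout these scripts into --micro-batch-size and --global-batch-size, and --warmup becomes --lr-warmup-fraction. A minimal sketch of how the batch-size flags relate; the function name is illustrative, not Megatron code.

# Illustration only: the number of accumulated micro-batches implied by the
# new batch-size flags (names here are descriptive, not Megatron internals).
def micro_batches_per_step(global_batch_size, micro_batch_size, data_parallel_size):
    samples_per_pass = micro_batch_size * data_parallel_size
    assert global_batch_size % samples_per_pass == 0, \
        "global batch must divide evenly by micro-batch * data-parallel size"
    return global_batch_size // samples_per_pass

# The single-GPU pretrain_bert.sh above (data-parallel size 1) uses
# --micro-batch-size 4 and --global-batch-size 8, i.e. 2 micro-batches
# accumulated per optimizer step.
print(micro_batches_per_step(8, 4, 1))  # 2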
@@ -15,11 +15,11 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
 python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        pretrain_bert.py \
-       --tensor-model-parallel-size 1 \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 4 \
+       --micro-batch-size 4 \
+       --global-batch-size 32 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
@@ -36,7 +36,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --lr-decay-iters 990000 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \
...
@@ -20,8 +20,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 2 \
-       --num-microbatches-in-minibatch 2 \
+       --micro-batch-size 2 \
+       --global-batch-size 16 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
@@ -38,7 +38,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --lr-decay-iters 990000 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \
...
@@ -9,11 +9,12 @@ DATA_PATH=<Specify path and file prefix>_text_document
 CHECKPOINT_PATH=<Specify path>
-python pretrain_gpt2.py \
+python pretrain_gpt.py \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 8 \
+       --micro-batch-size 4 \
+       --global-batch-size 8 \
        --seq-length 1024 \
        --max-position-embeddings 1024 \
        --train-iters 500000 \
@@ -31,7 +32,7 @@ python pretrain_gpt2.py \
        --lr-decay-style cosine \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --checkpoint-activations \
        --log-interval 100 \
        --save-interval 10000 \
...
#!/bin/bash

#SBATCH <SLURM OPTIONS> --nodes=128 --exclusive --ntasks-per-node=8 --job-name=megatron_gpt3_175b

DIR=`pwd`
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
mkdir -p $DIR/logs

DATASET_1="<PATH TO THE FIRST DATASET>"
DATASET_2="<PATH TO THE SECOND DATASET>"
DATASET_3="<PATH TO THE THIRD DATASET>"
DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}"

options=" \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 16 \
    --num-layers 96 \
    --hidden-size 12288 \
    --num-attention-heads 96 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --micro-batch-size 1 \
    --global-batch-size 1536 \
    --rampup-batch-size 16 16 5859375 \
    --train-samples 146484375 \
    --lr-decay-samples 126953125 \
    --lr-warmup-samples 183105 \
    --lr 6.0e-5 \
    --min-lr 6.0e-6 \
    --lr-decay-style cosine \
    --log-interval 10 \
    --eval-iters 40 \
    --eval-interval 1000 \
    --data-path ${DATASET} \
    --vocab-file <PATH TO gpt-vocab.json> \
    --merge-file <PATH TO gpt-merges.txt> \
    --save-interval 1000 \
    --save <PATH TO CHECKPOINTS DIRECTORY> \
    --load <PATH TO CHECKPOINTS DIRECTORY> \
    --split 98,2,0 \
    --clip-grad 1.0 \
    --weight-decay 0.1 \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --tensorboard-dir <TENSORBOARD DIRECTORY> \
    --fp16 \
    --checkpoint-activations "

run_cmd="python -u ${DIR}/pretrain_gpt.py $@ ${options}"

srun -l \
     --container-image "nvcr.io/nvidia/pytorch:20.12-py3" \
     --container-mounts "<DIRECTORIES TO MOUNT>" \
     --output=$DIR/logs/%x_%j_$DATETIME.log sh -c "${run_cmd}"

set +x
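Note (not part of the commit): this new GPT-3 175B example drives training by samples rather than iterations. A back-of-envelope check of the schedule under my reading of the flags; only the flag values come from the script.

# Back-of-envelope check of the sample-based schedule above (2048-token
# sequences); the interpretation of the flags is an assumption, not text
# from the commit.
seq_length = 2048
train_samples = 146_484_375
lr_decay_samples = 126_953_125
print(train_samples * seq_length)     # 300000000000 tokens trained in total
print(lr_decay_samples * seq_length)  # 260000000000 tokens over which LR decays
# --rampup-batch-size 16 16 5859375 is read here as: start the global batch
# at 16 and grow it in increments of 16 over the first 5,859,375 samples,
# until it reaches --global-batch-size 1536.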
@@ -16,12 +16,12 @@ CHECKPOINT_PATH=<Specify path>
 DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
 python -m torch.distributed.launch $DISTRIBUTED_ARGS \
-       pretrain_gpt2.py \
-       --tensor-model-parallel-size 1 \
+       pretrain_gpt.py \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 8 \
+       --micro-batch-size 8 \
+       --global-batch-size 64 \
        --seq-length 1024 \
        --max-position-embeddings 1024 \
        --train-iters 500000 \
@@ -39,7 +39,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --min-lr 1.0e-5 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --checkpoint-activations \
        --log-interval 100 \
        --save-interval 10000 \
...
@@ -16,14 +16,14 @@ CHECKPOINT_PATH=<Specify path>
 DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
 python -m torch.distributed.launch $DISTRIBUTED_ARGS \
-       pretrain_gpt2.py \
+       pretrain_gpt.py \
        --tensor-model-parallel-size 2 \
        --pipeline-model-parallel-size 2 \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 4 \
-       --num-microbatches-in-minibatch 2 \
+       --micro-batch-size 4 \
+       --global-batch-size 16 \
        --seq-length 1024 \
        --max-position-embeddings 1024 \
        --train-iters 500000 \
@@ -41,7 +41,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --min-lr 1.0e-5 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --checkpoint-activations \
        --log-interval 100 \
        --save-interval 10000 \
...
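Note (not part of the diff): in the model-parallel examples the removed --num-microbatches-in-minibatch flag is superseded by stating the global batch directly, and the micro-batch count is implied. A sketch of the arithmetic for the script above, assuming the usual 8-GPU single-node launch (the launcher variables are outside the excerpt).

# Sketch under an assumed 8-GPU, single-node launch; only the flag values
# below come from the script above.
world_size = 8                          # assumption: GPUS_PER_NODE=8, NNODES=1
tensor_mp, pipeline_mp = 2, 2           # --tensor/--pipeline-model-parallel-size
data_parallel = world_size // (tensor_mp * pipeline_mp)  # 2 data-parallel replicas
micro, global_batch = 4, 16             # --micro-batch-size / --global-batch-size
accumulation = global_batch // (micro * data_parallel)   # 2 micro-batches per step
print(data_parallel, accumulation)      # 2 2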
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""GPT2 style dataset."""
+"""GPT style dataset."""
 
 import os
 import time
@@ -107,7 +107,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         if splits[index + 1] > splits[index]:
             documents = np.arange(start=splits[index], stop=splits[index + 1],
                                   step=1, dtype=np.int32)
-            dataset = GPT2Dataset(name, data_prefix,
+            dataset = GPTDataset(name, data_prefix,
                                  documents, indexed_dataset,
                                  train_valid_test_num_samples[index],
                                  seq_length, seed)
@@ -136,7 +136,7 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
     return indexed_dataset
 
 
-class GPT2Dataset(torch.utils.data.Dataset):
+class GPTDataset(torch.utils.data.Dataset):
 
     def __init__(self, name, data_prefix, documents, indexed_dataset,
                  num_samples, seq_length, seed):
...
@@ -35,10 +35,11 @@ from .bert_model import (BertModel,
                          BertModelFirstStage,
                          BertModelIntermediateStage,
                          BertModelLastStage)
-from .gpt2_model import (GPT2Model,
-                         GPT2ModelFirstStage,
-                         GPT2ModelIntermediateStage,
-                         GPT2ModelLastStage)
+from .realm_model import ICTBertModel
+from .gpt_model import (GPTModel,
+                        GPTModelFirstStage,
+                        GPTModelIntermediateStage,
+                        GPTModelLastStage)
 from .language_model import get_language_model
 from .module import FP16Module
 from .realm_model import ICTBertModel
...
@@ -56,11 +56,11 @@ def post_language_model_processing(lm_output, labels, logit_weights,
     return loss


-class GPT2ModelBase(MegatronModule):
+class GPTModelBase(MegatronModule):
     """GPT-2 Language model."""

     def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPT2ModelBase, self).__init__()
+        super(GPTModelBase, self).__init__()
         args = get_args()

         self.parallel_output = parallel_output
@@ -75,17 +75,17 @@ class GPT2ModelBase(MegatronModule):
         self.initialize_word_embeddings(init_method_normal)

-    def forward(self, gpt2_model_input, attention_mask, labels=None,
+    def forward(self, gpt_model_input, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

         kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
         if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = gpt2_model_input
+            (input_ids, position_ids) = gpt_model_input
             args = [input_ids, position_ids, attention_mask]
             kwargs['tokentype_ids'] = tokentype_ids
         else:
-            args = [gpt2_model_input, attention_mask]
+            args = [gpt_model_input, attention_mask]
         lm_output = self.language_model(*args, **kwargs)

         if mpu.is_pipeline_last_stage():
@@ -124,17 +124,17 @@ class GPT2ModelBase(MegatronModule):
         self.language_model.load_state_dict(state_dict, strict=strict)


-class GPT2Model(GPT2ModelBase):
+class GPTModel(GPTModelBase):

     def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPT2Model, self).__init__(
+        super(GPTModel, self).__init__(
             num_tokentypes=num_tokentypes,
             parallel_output=parallel_output)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
-        return super(GPT2Model, self).forward(
+        return super(GPTModel, self).forward(
             (input_ids, position_ids),
             attention_mask,
             labels=labels,
@@ -144,15 +144,15 @@ class GPT2Model(GPT2ModelBase):
             forward_method_parallel_output=forward_method_parallel_output)


-class GPT2ModelFirstStage(GPT2ModelBase):
+class GPTModelFirstStage(GPTModelBase):

     def __init__(self, num_tokentypes=0):
-        super(GPT2ModelFirstStage, self).__init__(
+        super(GPTModelFirstStage, self).__init__(
             num_tokentypes=num_tokentypes)

     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(GPT2ModelFirstStage, self).forward(
+        return super(GPTModelFirstStage, self).forward(
             (input_ids, position_ids),
             attention_mask,
             tokentype_ids=tokentype_ids,
@@ -160,32 +160,32 @@ class GPT2ModelFirstStage(GPT2ModelBase):
             get_key_value=get_key_value)


-class GPT2ModelIntermediateStage(GPT2ModelBase):
+class GPTModelIntermediateStage(GPTModelBase):

     def __init__(self, num_tokentypes=0):
-        super(GPT2ModelIntermediateStage, self).__init__(
+        super(GPTModelIntermediateStage, self).__init__(
             num_tokentypes=num_tokentypes)

     def forward(self, hidden_state, attention_mask,
                 layer_past=None, get_key_value=False):
-        return super(GPT2ModelIntermediateStage, self).forward(
+        return super(GPTModelIntermediateStage, self).forward(
             hidden_state,
             attention_mask,
             layer_past=layer_past,
             get_key_value=get_key_value)


-class GPT2ModelLastStage(GPT2ModelBase):
+class GPTModelLastStage(GPTModelBase):

     def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPT2ModelLastStage, self).__init__(
+        super(GPTModelLastStage, self).__init__(
             num_tokentypes=num_tokentypes,
             parallel_output=parallel_output)

     def forward(self, hidden_state, attention_mask, labels=None,
                 layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
-        return super(GPT2ModelLastStage, self).forward(
+        return super(GPTModelLastStage, self).forward(
             hidden_state,
             attention_mask,
             labels=labels,
...
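Note (not part of the diff): the rename keeps the pipeline-stage split of the model classes. A hedged sketch of how such classes are typically selected in a model provider; this is the usual pattern only, and the imports and mpu helper are assumptions rather than code from this commit.

# Sketch of the common selection pattern for the renamed classes; assumed
# imports, not code from this commit.
from megatron import mpu
from megatron.model import (GPTModel, GPTModelFirstStage,
                            GPTModelIntermediateStage, GPTModelLastStage)

def model_provider():
    # Pick the stage-specific model when pipeline parallelism is enabled.
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage():
            return GPTModelFirstStage(num_tokentypes=0)
        if mpu.is_pipeline_last_stage():
            return GPTModelLastStage(num_tokentypes=0, parallel_output=True)
        return GPTModelIntermediateStage(num_tokentypes=0)
    return GPTModel(num_tokentypes=0, parallel_output=True)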
@@ -929,6 +929,7 @@ def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
                                iteration, verbose=False):
     """Helper function to evaluate and dump results on screen."""
+    args = get_args()
     writer = get_tensorboard_writer()

     total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
@@ -937,11 +938,16 @@ def evaluate_and_print_results(prefix, forward_step_func,
         string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
         ppl = math.exp(min(20, total_loss_dict[key].item()))
         string += '{} PPL: {:.6E} | '.format(key, ppl)
-        if writer and torch.distributed.get_rank() == 0:
-            writer.add_scalar('{} value'.format(key),
-                              total_loss_dict[key].item(),
-                              iteration)
-            writer.add_scalar('{} ppl'.format(key), ppl, iteration)
+        if writer and is_last_rank():
+            writer.add_scalar('{} value-validation'.format(key),
+                              total_loss_dict[key].item(),
+                              iteration)
+            writer.add_scalar('{} ppl-validation'.format(key), ppl, iteration)
+            writer.add_scalar('{} value-validation vs samples'.format(key),
+                              total_loss_dict[key].item(),
+                              args.consumed_train_samples)
+            writer.add_scalar('{} ppl-validation vs samples'.format(key), ppl,
+                              args.consumed_train_samples)

     length = len(string) + 1
     print_rank_last('-' * length)
...
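Note (not part of the diff): the logging change writes each validation scalar twice, once against the training iteration and once against args.consumed_train_samples, so runs with batch-size ramp-up can be compared on a samples axis. A minimal standalone illustration of the same dual-axis pattern with plain TensorBoard; the tag strings mirror the diff, the values and log directory are made up.

# Minimal illustration of dual-axis logging; not Megatron code.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs/example")            # hypothetical log directory
loss, iteration, consumed_samples = 2.31, 1000, 512_000
writer.add_scalar("lm loss value-validation", loss, iteration)
writer.add_scalar("lm loss value-validation vs samples", loss, consumed_samples)
writer.close()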