Commit 74bd02ec authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into next-best-lm/merge-rope-main

parents e1f1aa06 f6d36d03
......@@ -35,6 +35,7 @@ unit_tests:
stage: test
script: &selene-test-launcher-script
- echo "Running selene resume from checkpoint test. "
- echo "In case of error check ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${CI_JOB_NAME}/results directory for result logs."
- pwd
- export BUILD_DIR=`pwd`
- export RUN_NAME=resume_${RUN_MODEL}_tp${TP_SIZE}_pp${PP_SIZE}_${NUM_NODES}nodes
......@@ -66,6 +67,7 @@ unit_tests:
- echo "Finished job"
- source $PYTHON_VIRTUAL_ENV
- pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py
- if [ $? -ne 0 ]; then echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${CI_JOB_NAME}/results directory for result logs."; fi
- echo "Completed the job"
rules:
- if: $TEST_LEVEL =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TEST_REGEX_ON_THIS_COMMIT
......@@ -82,11 +84,13 @@ unit_tests:
stage: test
script: &selene-test-launcher-script
- echo "Running selene test"
- echo "In case of error check ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${CI_JOB_NAME}/results directory for result logs."
- echo "$CI_MERGE_REQUEST_APPROVED"
- pwd
- export BUILD_DIR=`pwd`
- export RUN_NAME=${RUN_MODEL}_tp${TP_SIZE}_pp${PP_SIZE}_${NUM_NODES}nodes_${MAX_STEPS}steps
- export TP_SIZE PP_SIZE NUM_NODES MAX_STEPS VP_SIZE
- export MBS GBS
- export DATA_DIR=$DATA_DIR
- echo "Run name is $RUN_NAME"
- mkdir -p $SELENE_ADLR_CI_PATH/$CI_PIPELINE_ID/$RUN_NAME/checkpoints
......@@ -100,7 +104,7 @@ unit_tests:
- export RESULTS_DIR=$BASE_DIR/results
- export CHECKPOINTS_DIR=$BASE_DIR/checkpoints
- echo "Submitting job"
- sbatch_submission=`sbatch $BUILD_DIR/tests/functional_tests/test_scripts/$RUN_MODEL/sbatch_${RUN_MODEL}_distributed_test.sh --export=BASE_DIR,BUILD_DIR,DATA_DIR,TP_SIZE,PP_SIZE,NUM_NODES,MAX_STEPS,VP_SIZE`
- sbatch_submission=`sbatch $BUILD_DIR/tests/functional_tests/test_scripts/$RUN_MODEL/sbatch_${RUN_MODEL}_distributed_test.sh --export=BASE_DIR,BUILD_DIR,DATA_DIR,TP_SIZE,PP_SIZE,NUM_NODES,MAX_STEPS,VP_SIZE,MBS,GBS`
- export SLURM_JOBID=$(echo $sbatch_submission| grep 'Submitted batch job' | awk '{ print $4 }');
- bash $BUILD_DIR/tests/functional_tests/shell_test_utils/jobwait.sh $SLURM_JOBID
- \[ ! -z ${SLURM_JOBID} \] && echo -e " --------------------------------------------------\n"
......@@ -167,6 +171,19 @@ train.gpt3.345m_tp1_pp2_1node_50steps:
TIME_LIMIT: "20:00"
TEST_LEVEL: L0
train.gpt3.345m_tp1_pp4_1node_50steps:
<<: *selene-test-launcher
variables:
<<: [*VARS]
RUN_MODEL: gpt3
TP_SIZE: 1
PP_SIZE: 4
VP_SIZE: 1
NUM_NODES: 1
MAX_STEPS: 50
TIME_LIMIT: "20:00"
TEST_LEVEL: L0
resume.checkpoint.gpt3.345m_tp1_pp2_1node:
<<: *selene-test-resume-checkpoint-launcher
variables:
......@@ -245,8 +262,8 @@ cleanup.selene:
variables:
<<: [*VARS]
script:
- NUM_CLEANUP=`find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | wc -l`
- find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | xargs rm -rf
- NUM_CLEANUP=`find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | grep -v data | wc -l`
- find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | grep -v data | xargs rm -rf
- echo "Finished cleaning $NUM_CLEANUP directories older than 20 days everything in Selene"
allow_failure: true
rules:
......
......@@ -250,6 +250,14 @@ def validate_args(args, defaults={}):
if args.ffn_hidden_size is None:
args.ffn_hidden_size = 4 * args.hidden_size
if args.swiglu:
# reduce the dimnesion for MLP since projections happens on
# two linear layers. this keeps the number of paramters in
# the same ballpark as the counterpart with 4*h size
# we keep it a multiple of 64, which means the actual tensor size
# will be a multiple of 64 / tp_size
args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
if args.kv_channels is None:
assert args.hidden_size % args.num_attention_heads == 0
args.kv_channels = args.hidden_size // args.num_attention_heads
......@@ -349,6 +357,10 @@ def validate_args(args, defaults={}):
"Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
# Disable bias gelu fusion if we are disabling bias altogether
if not args.add_bias_linear:
args.bias_gelu_fusion = False
# Load retro args.
if args.retro_workdir:
retro_args_path = get_retro_args_path(args.retro_workdir)
......@@ -530,6 +542,10 @@ def _add_network_size_args(parser):
help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility'
'reasons.')
group.add_argument('--squared-relu', action='store_true',
help='Use squared relu activation instead of default gelu')
group.add_argument('--swiglu', action='store_true',
help='Use gated linear units and SiLU activation instead of default gelu')
group.add_argument('--onnx-safe', type=bool, required=False,
help='Use workarounds for known problems with '
'Torch ONNX exporter')
......@@ -736,6 +752,9 @@ def _add_training_args(parser):
group.add_argument('--use-flash-attn', action='store_true',
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--disable-bias-linear', action='store_false',
help='Disable bias in the linear layers',
dest='add_bias_linear')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
......
......@@ -593,7 +593,7 @@ def forward_backward_pipelining_with_interleaving(*,
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(tensor_shape, timers=timers))
p2p_communication.recv_backward(tensor_shape, dtype=dtype, timers=timers))
for k in range(num_microbatches_remaining, total_num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
......
......@@ -88,28 +88,45 @@ class ParallelMLP(MegatronModule):
super(ParallelMLP, self).__init__()
args = get_args()
self.add_bias = args.add_bias_linear
# Project to 4h.
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.ffn_hidden_size,
args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size,
bias=self.add_bias,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
self.bias_gelu_fusion = False
self.activation_func = None
self.swiglu = args.swiglu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
elif args.swiglu:
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.activation_func = swiglu
elif args.squared_relu:
def squared_relu(x):
return torch.pow(F.relu(x), 2)
self.activation_func = squared_relu
else:
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
bias=self.add_bias,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
......@@ -121,11 +138,13 @@ class ParallelMLP(MegatronModule):
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
intermediate_parallel = \
bias_gelu_impl(intermediate_parallel, bias_parallel)
assert self.add_bias is True
assert self.activation_func == F.gelu
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = \
self.activation_func(intermediate_parallel + bias_parallel)
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
......@@ -403,6 +422,7 @@ class ParallelAttention(MegatronModule):
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
bias=args.add_bias_linear,
gather_output=False,
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
......@@ -412,6 +432,7 @@ class ParallelAttention(MegatronModule):
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
projection_size,
bias=args.add_bias_linear,
gather_output=False,
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
......@@ -421,6 +442,7 @@ class ParallelAttention(MegatronModule):
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
bias=args.add_bias_linear,
gather_output=False,
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
......@@ -439,6 +461,7 @@ class ParallelAttention(MegatronModule):
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
bias=args.add_bias_linear,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
......@@ -632,7 +655,9 @@ class ParallelAttention(MegatronModule):
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
if bias is not None:
x = x + bias
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
......@@ -767,10 +792,12 @@ class ParallelTransformerLayer(MegatronModule):
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
if attention_bias is not None:
attention_bias = attention_bias.expand_as(residual)
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
attention_bias,
residual,
self.hidden_dropout)
else:
......@@ -793,10 +820,13 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = layernorm_input
if attention_bias is not None:
attention_bias = attention_bias.expand_as(residual)
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
attention_bias,
residual,
self.hidden_dropout)
......@@ -813,10 +843,12 @@ class ParallelTransformerLayer(MegatronModule):
residual = layernorm_input
if self.drop_path is None:
if mlp_bias is not None:
mlp_bias = mlp_bias.expand_as(residual)
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
mlp_bias,
residual,
self.hidden_dropout)
......@@ -831,7 +863,9 @@ class ParallelTransformerLayer(MegatronModule):
keep_graph = True)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
if mlp_bias is not None:
mlp_output = mlp_output + mlp_bias
out = torch.nn.functional.dropout(mlp_output,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
......@@ -1085,9 +1119,10 @@ class ParallelTransformer(MegatronModule):
"""Forward method with activation checkpointing."""
def custom(start, end, is_transformer_engine=False):
def custom_forward(*args, **kwargs):
x_, *args = args
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(*args, **kwargs)
x_ = layer(x_, *args, **kwargs)
return x_
def custom_forward_transformer_engine(*args, **kwargs):
return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
......
......@@ -25,6 +25,7 @@ def read_tb_logs_as_list(path, summary_name):
ea.Reload()
summary = ea.Scalars(summary_name)
summary_list = [round(x.value, 5) for x in summary]
print(f'\nObtained the following list for {summary_name} ------------------')
print(summary_list)
return summary_list
raise FileNotFoundError(f"File not found matching: {path}/events*")
......
......@@ -34,6 +34,7 @@ def read_tb_logs_as_list(path, summary_name):
ea.Reload()
summary = ea.Scalars(summary_name)
summary_list = [round(x.value, 5) for x in summary]
print(f'\nObtained the following list for {summary_name} ------------------')
print(summary_list)
return summary_list
raise FileNotFoundError(f"File not found matching: {path}/events*")
......@@ -53,9 +54,11 @@ class TestCIPipeline:
raise FileNotFoundError("Expected data is none")
expected = self.expected[loss_type]
expected_list = expected["values"]
print(expected_list)
actual_list = read_tb_logs_as_list(LOGS_DIR, loss_type)
assert actual_list is not None, f"No TensorBoard events file was found in the logs for {loss_type}."
for i, step in enumerate(range(expected["start_step"], expected["end_step"], expected["step_interval"])):
print(f"Checking step {step} against expected {i}")
if test_type == TypeOfTest.APPROX:
assert actual_list[step] == pytest.approx(expected=expected_list[i], rel=self.margin_loss), f"{self.job_name} : The loss at step {step} should be approximately {expected_list[i]} but it is {actual_list[step]}."
else:
......
{"lm loss": {"start_step": 0, "end_step": 45, "step_interval": 5, "values": [10.7947, 10.85294, 10.87058, 10.83388, 10.83012, 10.78726, 10.56378, 10.57311, 10.48692]}, "num-zeros": {"start_step": 0, "end_step": 29, "step_interval": 5, "values": [2452.0, 2818.0, 2036.0, 2662.0, 2651.0, 2422.0]}, "iteration_timing_avg": 0.1187023333333333}
......@@ -7,7 +7,9 @@ TP_SIZE=$4
PP_SIZE=$5
NNODES=$6
MAX_STEPS=$7
VP_SIZE=$8
MBS=$9
GBS=${10}
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
......@@ -30,8 +32,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--log-validation-ppl-to-tensorboard \
--log-timers-to-tensorboard \
--tensorboard-dir ${TENSORBOARD_DIR} \
--micro-batch-size 4 \
--global-batch-size 32 \
--micro-batch-size ${MBS:-4} \
--global-batch-size ${GBS:-32} \
--seq-length 1024 \
--max-position-embeddings 1024 \
--train-iters $MAX_STEPS \
......@@ -57,5 +59,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--eval-iters 10 \
--tensor-model-parallel-size $TP_SIZE \
--pipeline-model-parallel-size $PP_SIZE \
${VP_SIZE:+--num-layers-per-virtual-pipeline-stage "$VP_SIZE"} \
--no-gradient-accumulation-fusion \
--fp16
......@@ -13,4 +13,4 @@ TENSORBOARD_DIR=/workspace/logs
srun --output $BASE_DIR/results/slurm-%j.out --error $BASE_DIR/results/slurm-%j.out --container-image gitlab-master.nvidia.com/dl/dgx/pytorch:21.12-py3-devel --container-mounts $BASE_DIR/logs:/workspace/logs,$BASE_DIR/checkpoints:/workspace/checkpoints,$BUILD_DIR:/workspace/megatron-lm,$DATA_DIR:/workspace/data --no-container-mount-home bash -c "
ls
cd /workspace/megatron-lm
./tests/functional_tests/test_scripts/gpt3/pretrain_gpt3_distributed_test.sh $DATA_PATH $CHECKPOINT_PATH $TENSORBOARD_DIR $TP_SIZE $PP_SIZE $NUM_NODES $MAX_STEPS"
\ No newline at end of file
./tests/functional_tests/test_scripts/gpt3/pretrain_gpt3_distributed_test.sh $DATA_PATH $CHECKPOINT_PATH $TENSORBOARD_DIR $TP_SIZE $PP_SIZE $NUM_NODES $MAX_STEPS $VP_SIZE $MBS $GBS"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment