Commit c43d19ab authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into untie_embeddings

parents 716da5d8 f6d36d03
...@@ -35,6 +35,7 @@ unit_tests: ...@@ -35,6 +35,7 @@ unit_tests:
stage: test stage: test
script: &selene-test-launcher-script script: &selene-test-launcher-script
- echo "Running selene resume from checkpoint test. " - echo "Running selene resume from checkpoint test. "
- echo "In case of error check ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${CI_JOB_NAME}/results directory for result logs."
- pwd - pwd
- export BUILD_DIR=`pwd` - export BUILD_DIR=`pwd`
- export RUN_NAME=resume_${RUN_MODEL}_tp${TP_SIZE}_pp${PP_SIZE}_${NUM_NODES}nodes - export RUN_NAME=resume_${RUN_MODEL}_tp${TP_SIZE}_pp${PP_SIZE}_${NUM_NODES}nodes
...@@ -66,6 +67,7 @@ unit_tests: ...@@ -66,6 +67,7 @@ unit_tests:
- echo "Finished job" - echo "Finished job"
- source $PYTHON_VIRTUAL_ENV - source $PYTHON_VIRTUAL_ENV
- pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py - pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py
- if [ $? -ne 0 ]; then echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${CI_JOB_NAME}/results directory for result logs."; fi
- echo "Completed the job" - echo "Completed the job"
rules: rules:
- if: $TEST_LEVEL =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TEST_REGEX_ON_THIS_COMMIT - if: $TEST_LEVEL =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TEST_REGEX_ON_THIS_COMMIT
...@@ -82,6 +84,7 @@ unit_tests: ...@@ -82,6 +84,7 @@ unit_tests:
stage: test stage: test
script: &selene-test-launcher-script script: &selene-test-launcher-script
- echo "Running selene test" - echo "Running selene test"
- echo "In case of error check ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${CI_JOB_NAME}/results directory for result logs."
- echo "$CI_MERGE_REQUEST_APPROVED" - echo "$CI_MERGE_REQUEST_APPROVED"
- pwd - pwd
- export BUILD_DIR=`pwd` - export BUILD_DIR=`pwd`
...@@ -259,8 +262,8 @@ cleanup.selene: ...@@ -259,8 +262,8 @@ cleanup.selene:
variables: variables:
<<: [*VARS] <<: [*VARS]
script: script:
- NUM_CLEANUP=`find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | wc -l` - NUM_CLEANUP=`find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | grep -v data | wc -l`
- find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | xargs rm -rf - find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | grep -v data | xargs rm -rf
- echo "Finished cleaning $NUM_CLEANUP directories older than 20 days everything in Selene" - echo "Finished cleaning $NUM_CLEANUP directories older than 20 days everything in Selene"
allow_failure: true allow_failure: true
rules: rules:
......
...@@ -250,6 +250,14 @@ def validate_args(args, defaults={}): ...@@ -250,6 +250,14 @@ def validate_args(args, defaults={}):
if args.ffn_hidden_size is None: if args.ffn_hidden_size is None:
args.ffn_hidden_size = 4 * args.hidden_size args.ffn_hidden_size = 4 * args.hidden_size
if args.swiglu:
# reduce the dimnesion for MLP since projections happens on
# two linear layers. this keeps the number of paramters in
# the same ballpark as the counterpart with 4*h size
# we keep it a multiple of 64, which means the actual tensor size
# will be a multiple of 64 / tp_size
args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
if args.kv_channels is None: if args.kv_channels is None:
assert args.hidden_size % args.num_attention_heads == 0 assert args.hidden_size % args.num_attention_heads == 0
args.kv_channels = args.hidden_size // args.num_attention_heads args.kv_channels = args.hidden_size // args.num_attention_heads
...@@ -349,6 +357,10 @@ def validate_args(args, defaults={}): ...@@ -349,6 +357,10 @@ def validate_args(args, defaults={}):
"Using async gradient all reduce requires setting the environment " "Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1") "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
# Disable bias gelu fusion if we are disabling bias altogether
if not args.add_bias_linear:
args.bias_gelu_fusion = False
# Load retro args. # Load retro args.
if args.retro_workdir: if args.retro_workdir:
retro_args_path = get_retro_args_path(args.retro_workdir) retro_args_path = get_retro_args_path(args.retro_workdir)
...@@ -522,6 +534,10 @@ def _add_network_size_args(parser): ...@@ -522,6 +534,10 @@ def _add_network_size_args(parser):
help='Use OpenAIs GeLU implementation. This option' help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility' 'should not be used unless for backward compatibility'
'reasons.') 'reasons.')
group.add_argument('--squared-relu', action='store_true',
help='Use squared relu activation instead of default gelu')
group.add_argument('--swiglu', action='store_true',
help='Use gated linear units and SiLU activation instead of default gelu')
group.add_argument('--onnx-safe', type=bool, required=False, group.add_argument('--onnx-safe', type=bool, required=False,
help='Use workarounds for known problems with ' help='Use workarounds for known problems with '
'Torch ONNX exporter') 'Torch ONNX exporter')
...@@ -730,6 +746,9 @@ def _add_training_args(parser): ...@@ -730,6 +746,9 @@ def _add_training_args(parser):
group.add_argument('--use-flash-attn', action='store_true', group.add_argument('--use-flash-attn', action='store_true',
help='use FlashAttention implementation of attention. ' help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135') 'https://arxiv.org/abs/2205.14135')
group.add_argument('--disable-bias-linear', action='store_false',
help='Disable bias in the linear layers',
dest='add_bias_linear')
group.add_argument('--optimizer', type=str, default='adam', group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'], choices=['adam', 'sgd'],
help='Optimizer function') help='Optimizer function')
......
...@@ -87,28 +87,45 @@ class ParallelMLP(MegatronModule): ...@@ -87,28 +87,45 @@ class ParallelMLP(MegatronModule):
super(ParallelMLP, self).__init__() super(ParallelMLP, self).__init__()
args = get_args() args = get_args()
self.add_bias = args.add_bias_linear
# Project to 4h. # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
args.ffn_hidden_size, args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size,
bias=self.add_bias,
gather_output=False, gather_output=False,
init_method=init_method, init_method=init_method,
skip_bias_add=True, skip_bias_add=True,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs()) **_args_to_kwargs())
self.bias_gelu_fusion = args.bias_gelu_fusion self.bias_gelu_fusion = False
self.activation_func = F.gelu self.activation_func = None
self.swiglu = args.swiglu
if args.openai_gelu: if args.openai_gelu:
self.activation_func = openai_gelu self.activation_func = openai_gelu
elif args.onnx_safe: elif args.onnx_safe:
self.activation_func = erf_gelu self.activation_func = erf_gelu
elif args.swiglu:
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.activation_func = swiglu
elif args.squared_relu:
def squared_relu(x):
return torch.pow(F.relu(x), 2)
self.activation_func = squared_relu
else:
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
# Project back to h. # Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear( self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
args.ffn_hidden_size, args.ffn_hidden_size,
args.hidden_size, args.hidden_size,
bias=self.add_bias,
input_is_parallel=True, input_is_parallel=True,
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True, skip_bias_add=True,
...@@ -120,11 +137,13 @@ class ParallelMLP(MegatronModule): ...@@ -120,11 +137,13 @@ class ParallelMLP(MegatronModule):
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion: if self.bias_gelu_fusion:
intermediate_parallel = \ assert self.add_bias is True
bias_gelu_impl(intermediate_parallel, bias_parallel) assert self.activation_func == F.gelu
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else: else:
intermediate_parallel = \ if bias_parallel is not None:
self.activation_func(intermediate_parallel + bias_parallel) intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h] # [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel) output, output_bias = self.dense_4h_to_h(intermediate_parallel)
...@@ -402,6 +421,7 @@ class ParallelAttention(MegatronModule): ...@@ -402,6 +421,7 @@ class ParallelAttention(MegatronModule):
self.query_key_value = tensor_parallel.ColumnParallelLinear( self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
3 * projection_size, 3 * projection_size,
bias=args.add_bias_linear,
gather_output=False, gather_output=False,
init_method=init_method, init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
...@@ -411,6 +431,7 @@ class ParallelAttention(MegatronModule): ...@@ -411,6 +431,7 @@ class ParallelAttention(MegatronModule):
self.query = tensor_parallel.ColumnParallelLinear( self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
projection_size, projection_size,
bias=args.add_bias_linear,
gather_output=False, gather_output=False,
init_method=init_method, init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
...@@ -420,6 +441,7 @@ class ParallelAttention(MegatronModule): ...@@ -420,6 +441,7 @@ class ParallelAttention(MegatronModule):
self.key_value = tensor_parallel.ColumnParallelLinear( self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
2 * projection_size, 2 * projection_size,
bias=args.add_bias_linear,
gather_output=False, gather_output=False,
init_method=init_method, init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
...@@ -438,6 +460,7 @@ class ParallelAttention(MegatronModule): ...@@ -438,6 +460,7 @@ class ParallelAttention(MegatronModule):
self.dense = tensor_parallel.RowParallelLinear( self.dense = tensor_parallel.RowParallelLinear(
projection_size, projection_size,
args.hidden_size, args.hidden_size,
bias=args.add_bias_linear,
input_is_parallel=True, input_is_parallel=True,
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True, skip_bias_add=True,
...@@ -585,7 +608,9 @@ class ParallelAttention(MegatronModule): ...@@ -585,7 +608,9 @@ class ParallelAttention(MegatronModule):
def bias_dropout_add(x, bias, residual, prob, training): def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training) if bias is not None:
x = x + bias
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out out = residual + out
return out return out
...@@ -719,10 +744,12 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -719,10 +744,12 @@ class ParallelTransformerLayer(MegatronModule):
else: else:
bias_dropout_add_func = get_bias_dropout_add(self.training) bias_dropout_add_func = get_bias_dropout_add(self.training)
if attention_bias is not None:
attention_bias = attention_bias.expand_as(residual)
with self.bias_dropout_add_exec_handler(): with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func( layernorm_input = bias_dropout_add_func(
attention_output, attention_output,
attention_bias.expand_as(residual), attention_bias,
residual, residual,
self.hidden_dropout) self.hidden_dropout)
else: else:
...@@ -745,10 +772,13 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -745,10 +772,13 @@ class ParallelTransformerLayer(MegatronModule):
else: else:
residual = layernorm_input residual = layernorm_input
if attention_bias is not None:
attention_bias = attention_bias.expand_as(residual)
with self.bias_dropout_add_exec_handler(): with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func( layernorm_input = bias_dropout_add_func(
attention_output, attention_output,
attention_bias.expand_as(residual), attention_bias,
residual, residual,
self.hidden_dropout) self.hidden_dropout)
...@@ -765,10 +795,12 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -765,10 +795,12 @@ class ParallelTransformerLayer(MegatronModule):
residual = layernorm_input residual = layernorm_input
if self.drop_path is None: if self.drop_path is None:
if mlp_bias is not None:
mlp_bias = mlp_bias.expand_as(residual)
with self.bias_dropout_add_exec_handler(): with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func( output = bias_dropout_add_func(
mlp_output, mlp_output,
mlp_bias.expand_as(residual), mlp_bias,
residual, residual,
self.hidden_dropout) self.hidden_dropout)
...@@ -783,7 +815,9 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -783,7 +815,9 @@ class ParallelTransformerLayer(MegatronModule):
keep_graph = True) keep_graph = True)
else: else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias, if mlp_bias is not None:
mlp_output = mlp_output + mlp_bias
out = torch.nn.functional.dropout(mlp_output,
p=self.hidden_dropout, p=self.hidden_dropout,
training=self.training) training=self.training)
output = residual + self.drop_path(out) output = residual + self.drop_path(out)
......
...@@ -25,6 +25,7 @@ def read_tb_logs_as_list(path, summary_name): ...@@ -25,6 +25,7 @@ def read_tb_logs_as_list(path, summary_name):
ea.Reload() ea.Reload()
summary = ea.Scalars(summary_name) summary = ea.Scalars(summary_name)
summary_list = [round(x.value, 5) for x in summary] summary_list = [round(x.value, 5) for x in summary]
print(f'\nObtained the following list for {summary_name} ------------------')
print(summary_list) print(summary_list)
return summary_list return summary_list
raise FileNotFoundError(f"File not found matching: {path}/events*") raise FileNotFoundError(f"File not found matching: {path}/events*")
......
...@@ -34,6 +34,7 @@ def read_tb_logs_as_list(path, summary_name): ...@@ -34,6 +34,7 @@ def read_tb_logs_as_list(path, summary_name):
ea.Reload() ea.Reload()
summary = ea.Scalars(summary_name) summary = ea.Scalars(summary_name)
summary_list = [round(x.value, 5) for x in summary] summary_list = [round(x.value, 5) for x in summary]
print(f'\nObtained the following list for {summary_name} ------------------')
print(summary_list) print(summary_list)
return summary_list return summary_list
raise FileNotFoundError(f"File not found matching: {path}/events*") raise FileNotFoundError(f"File not found matching: {path}/events*")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment