Commit 2cc3dac7 authored by Jared Casper

Add swiglu and squared relu activations and ability to disable bias.

parent b7f3c263
@@ -250,6 +250,14 @@ def validate_args(args, defaults={}):
     if args.ffn_hidden_size is None:
         args.ffn_hidden_size = 4 * args.hidden_size
 
+    if args.swiglu:
+        # reduce the dimension for MLP since the projection happens on
+        # two linear layers. this keeps the number of parameters in
+        # the same ballpark as the counterpart with 4*h size
+        # we keep it a multiple of 64, which means the actual tensor size
+        # will be a multiple of 64 / tp_size
+        args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
+
     if args.kv_channels is None:
         assert args.hidden_size % args.num_attention_heads == 0
         args.kv_channels = args.hidden_size // args.num_attention_heads
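A quick sanity check of the swiglu sizing formula above, with a hypothetical hidden_size (the numbers are illustrative, not part of the commit):

```python
hidden_size = 4096                                         # hypothetical model width

default_ffn = 4 * hidden_size                              # 16384, the usual 4*h MLP width
swiglu_ffn = int((4 * hidden_size * 2 / 3) / 64) * 64      # 10880, rounded down to a multiple of 64

# SwiGLU projects to two ffn-sized tensors (value and gate), so its MLP weights total
# hidden_size * 2 * swiglu_ffn + swiglu_ffn * hidden_size ~= 133.7M parameters,
# versus hidden_size * default_ffn * 2 ~= 134.2M for the plain 4*h GeLU MLP -- the "same ballpark".
print(default_ffn, swiglu_ffn)
```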
@@ -349,6 +357,10 @@ def validate_args(args, defaults={}):
             "Using async gradient all reduce requires setting the environment "
             "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
 
+    # Disable bias gelu fusion if we are disabling bias altogether
+    if not args.add_bias_linear:
+        args.bias_gelu_fusion = False
+
     # Load retro args.
     if args.retro_workdir:
         retro_args_path = get_retro_args_path(args.retro_workdir)
@@ -522,6 +534,10 @@ def _add_network_size_args(parser):
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'
                        'reasons.')
+    group.add_argument('--squared-relu', action='store_true',
+                       help='Use squared relu activation instead of default gelu')
+    group.add_argument('--swiglu', action='store_true',
+                       help='Use gated linear units and SiLU activation instead of default gelu')
     group.add_argument('--onnx-safe', type=bool, required=False,
                        help='Use workarounds for known problems with '
                        'Torch ONNX exporter')
@@ -728,6 +744,9 @@ def _add_training_args(parser):
     group.add_argument('--use-flash-attn', action='store_true',
                        help='use FlashAttention implementation of attention. '
                        'https://arxiv.org/abs/2205.14135')
+    group.add_argument('--disable-bias-linear', action='store_false',
+                       help='Disable bias in the linear layers',
+                       dest='add_bias_linear')
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
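The `--disable-bias-linear` flag uses argparse's `store_false` with a `dest` override, so `args.add_bias_linear` defaults to True and the flag flips it to False. A minimal standalone sketch of just that behavior (not part of the diff):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--disable-bias-linear', action='store_false',
                    help='Disable bias in the linear layers',
                    dest='add_bias_linear')

print(parser.parse_args([]).add_bias_linear)                          # True: biases on by default
print(parser.parse_args(['--disable-bias-linear']).add_bias_linear)   # False: biases disabled
```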
...
@@ -86,28 +86,45 @@ class ParallelMLP(MegatronModule):
         super(ParallelMLP, self).__init__()
         args = get_args()
 
-        # Project to 4h.
+        self.add_bias = args.add_bias_linear
+
+        # Project to 4h. If using swiglu, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
         self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
             args.hidden_size,
-            args.ffn_hidden_size,
+            args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size,
+            bias=self.add_bias,
             gather_output=False,
             init_method=init_method,
             skip_bias_add=True,
             async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
             **_args_to_kwargs())
 
-        self.bias_gelu_fusion = args.bias_gelu_fusion
-        self.activation_func = F.gelu
+        self.bias_gelu_fusion = False
+        self.activation_func = None
+        self.swiglu = args.swiglu
+
         if args.openai_gelu:
             self.activation_func = openai_gelu
         elif args.onnx_safe:
             self.activation_func = erf_gelu
+        elif args.swiglu:
+            def swiglu(x):
+                x = torch.chunk(x, 2, dim=-1)
+                return F.silu(x[0]) * x[1]
+            self.activation_func = swiglu
+        elif args.squared_relu:
+            def squared_relu(x):
+                return torch.pow(F.relu(x), 2)
+            self.activation_func = squared_relu
+        else:
+            self.bias_gelu_fusion = args.bias_gelu_fusion
+            self.activation_func = F.gelu
 
         # Project back to h.
         self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
             args.ffn_hidden_size,
             args.hidden_size,
+            bias=self.add_bias,
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True,
@@ -119,11 +136,13 @@ class ParallelMLP(MegatronModule):
         intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
 
         if self.bias_gelu_fusion:
-            intermediate_parallel = \
-                    bias_gelu_impl(intermediate_parallel, bias_parallel)
+            assert self.add_bias is True
+            assert self.activation_func == F.gelu
+            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
         else:
-            intermediate_parallel = \
-                self.activation_func(intermediate_parallel + bias_parallel)
+            if self.add_bias:
+                intermediate_parallel = intermediate_parallel + bias_parallel
+            intermediate_parallel = self.activation_func(intermediate_parallel)
 
         # [s, b, h]
         output, output_bias = self.dense_4h_to_h(intermediate_parallel)
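A self-contained sketch of the two new activation functions defined in the hunk above, run on hypothetical tensor shapes; it shows why dense_h_to_4h doubles its output width under --swiglu (the chunk inside swiglu halves the last dimension again):

```python
import torch
import torch.nn.functional as F

def swiglu(x):
    # split the doubled projection into value and gate halves, then gate with SiLU
    x = torch.chunk(x, 2, dim=-1)
    return F.silu(x[0]) * x[1]

def squared_relu(x):
    # relu followed by an elementwise square
    return torch.pow(F.relu(x), 2)

# hypothetical sizes: [sequence=8, batch=2, ffn_hidden_size=1024]
doubled = torch.randn(8, 2, 2 * 1024)                 # what dense_h_to_4h emits with --swiglu
print(swiglu(doubled).shape)                          # torch.Size([8, 2, 1024]), back to ffn_hidden_size
print(squared_relu(torch.randn(8, 2, 1024)).shape)    # torch.Size([8, 2, 1024])
```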
@@ -401,6 +420,7 @@ class ParallelAttention(MegatronModule):
             self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 3 * projection_size,
+                bias=args.add_bias_linear,
                 gather_output=False,
                 init_method=init_method,
                 async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
@@ -410,6 +430,7 @@ class ParallelAttention(MegatronModule):
             self.query = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 projection_size,
+                bias=args.add_bias_linear,
                 gather_output=False,
                 init_method=init_method,
                 async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
@@ -419,6 +440,7 @@ class ParallelAttention(MegatronModule):
             self.key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 2 * projection_size,
+                bias=args.add_bias_linear,
                 gather_output=False,
                 init_method=init_method,
                 async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
@@ -437,6 +459,7 @@ class ParallelAttention(MegatronModule):
         self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
             args.hidden_size,
+            bias=args.add_bias_linear,
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True,
@@ -584,7 +607,9 @@ class ParallelAttention(MegatronModule):
 def bias_dropout_add(x, bias, residual, prob, training):
     # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
-    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
+    if bias is not None:
+        x = x + bias
+    out = torch.nn.functional.dropout(x, p=prob, training=training)
     out = residual + out
     return out
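A small check of the relaxed bias_dropout_add (the function body is copied from the hunk above; the tensors are hypothetical): with bias=None the add is skipped, and with dropout disabled a zero bias gives the same result. The jit-fused bias-dropout variants presumably still need a real bias tensor, which fits the next hunk only keeping bias_dropout_fusion when add_bias_linear is set.

```python
import torch

def bias_dropout_add(x, bias, residual, prob, training):
    # copied from the hunk above: bias is now optional
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out

x, residual = torch.randn(4, 8), torch.randn(4, 8)
no_bias = bias_dropout_add(x, None, residual, prob=0.1, training=False)
zero_bias = bias_dropout_add(x, torch.zeros(8), residual, prob=0.1, training=False)
print(torch.allclose(no_bias, zero_bias))   # True
```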
@@ -649,7 +674,7 @@ class ParallelTransformerLayer(MegatronModule):
             attention_type=AttnType.self_attn,
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
-        self.bias_dropout_fusion = args.bias_dropout_fusion
+        self.bias_dropout_fusion = args.bias_dropout_fusion and args.add_bias_linear
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
 
         # Layernorm on the attention output
@@ -718,10 +743,12 @@ class ParallelTransformerLayer(MegatronModule):
             else:
                 bias_dropout_add_func = get_bias_dropout_add(self.training)
 
+            if attention_bias is not None:
+                attention_bias = attention_bias.expand_as(residual)
             with self.bias_dropout_add_exec_handler():
                 layernorm_input = bias_dropout_add_func(
                     attention_output,
-                    attention_bias.expand_as(residual),
+                    attention_bias,
                     residual,
                     self.hidden_dropout)
         else:
@@ -744,10 +771,13 @@ class ParallelTransformerLayer(MegatronModule):
             else:
                 residual = layernorm_input
 
+            if attention_bias is not None:
+                attention_bias = attention_bias.expand_as(residual)
+
             with self.bias_dropout_add_exec_handler():
                 layernorm_input = bias_dropout_add_func(
                     attention_output,
-                    attention_bias.expand_as(residual),
+                    attention_bias,
                     residual,
                     self.hidden_dropout)
@@ -764,10 +794,12 @@ class ParallelTransformerLayer(MegatronModule):
             residual = layernorm_input
 
         if self.drop_path is None:
+            if mlp_bias is not None:
+                mlp_bias = mlp_bias.expand_as(residual)
             with self.bias_dropout_add_exec_handler():
                 output = bias_dropout_add_func(
                     mlp_output,
-                    mlp_bias.expand_as(residual),
+                    mlp_bias,
                     residual,
                     self.hidden_dropout)
@@ -782,7 +814,9 @@ class ParallelTransformerLayer(MegatronModule):
                                                      keep_graph = True)
         else:
-            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
+            if mlp_bias is not None:
+                mlp_output = mlp_output + mlp_bias
+            out = torch.nn.functional.dropout(mlp_output,
                                               p=self.hidden_dropout,
                                               training=self.training)
             output = residual + self.drop_path(out)