OpenDAS / Megatron-LM / Commits

Commit 2cc3dac7
Authored Apr 01, 2023 by Jared Casper
Parent: b7f3c263

Add swiglu and squared relu activations and ability to disable bias.
Showing 2 changed files with 67 additions and 14 deletions:

  megatron/arguments.py          +19   -0
  megatron/model/transformer.py  +48  -14
megatron/arguments.py
@@ -250,6 +250,14 @@ def validate_args(args, defaults={}):
     if args.ffn_hidden_size is None:
         args.ffn_hidden_size = 4 * args.hidden_size
+        if args.swiglu:
+            # reduce the dimension of the MLP since the projection happens on
+            # two linear layers. this keeps the number of parameters in
+            # the same ballpark as the counterpart with 4*h size
+            # we keep it a multiple of 64, which means the actual tensor size
+            # will be a multiple of 64 / tp_size
+            args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
 
     if args.kv_channels is None:
         assert args.hidden_size % args.num_attention_heads == 0
         args.kv_channels = args.hidden_size // args.num_attention_heads
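For intuition on the sizing rule just added, a worked example; the hidden_size value below is hypothetical, not taken from the commit:

# Worked example of the swiglu FFN sizing rule above; hidden_size is an assumed value.
hidden_size = 4096
default_ffn = 4 * hidden_size                          # 16384 without --swiglu
swiglu_ffn = int((4 * hidden_size * 2 / 3) / 64) * 64  # 10880, rounded down to a multiple of 64
print(default_ffn, swiglu_ffn)

Since the gated MLP uses a third weight matrix, shrinking the width to roughly 2/3 of 4*h keeps the parameter count in the same ballpark as the ungated layout.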
@@ -349,6 +357,10 @@ def validate_args(args, defaults={}):
             "Using async gradient all reduce requires setting the environment "
             "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
 
+    # Disable bias gelu fusion if we are disabling bias altogether
+    if not args.add_bias_linear:
+        args.bias_gelu_fusion = False
+
     # Load retro args.
     if args.retro_workdir:
         retro_args_path = get_retro_args_path(args.retro_workdir)
@@ -522,6 +534,10 @@ def _add_network_size_args(parser):
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'
                        'reasons.')
+    group.add_argument('--squared-relu', action='store_true',
+                       help='Use squared relu activation instead of default gelu')
+    group.add_argument('--swiglu', action='store_true',
+                       help='Use gated linear units and SiLU activation instead of default gelu')
     group.add_argument('--onnx-safe', type=bool, required=False,
                        help='Use workarounds for known problems with '
                        'Torch ONNX exporter')
@@ -728,6 +744,9 @@ def _add_training_args(parser):
     group.add_argument('--use-flash-attn', action='store_true',
                        help='use FlashAttention implementation of attention. '
                        'https://arxiv.org/abs/2205.14135')
+    group.add_argument('--disable-bias-linear', action='store_false',
+                       help='Disable bias in the linear layers',
+                       dest='add_bias_linear')
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
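As a side note on the new --disable-bias-linear option, a minimal standalone sketch of the store_false/dest pattern it relies on; the snippet is illustrative and not part of the commit:

# Bias stays enabled unless the flag is passed; the flag flips args.add_bias_linear to False.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--disable-bias-linear', action='store_false',
                    help='Disable bias in the linear layers',
                    dest='add_bias_linear')

print(parser.parse_args([]).add_bias_linear)                        # True
print(parser.parse_args(['--disable-bias-linear']).add_bias_linear) # False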
megatron/model/transformer.py
@@ -86,28 +86,45 @@ class ParallelMLP(MegatronModule):
         super(ParallelMLP, self).__init__()
         args = get_args()
 
-        # Project to 4h.
+        self.add_bias = args.add_bias_linear
+
+        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
         self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
             args.hidden_size,
-            args.ffn_hidden_size,
+            args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size,
+            bias=self.add_bias,
             gather_output=False,
             init_method=init_method,
             skip_bias_add=True,
             async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
             **_args_to_kwargs())
 
-        self.bias_gelu_fusion = args.bias_gelu_fusion
-        self.activation_func = F.gelu
+        self.bias_gelu_fusion = False
+        self.activation_func = None
+        self.swiglu = args.swiglu
+
         if args.openai_gelu:
             self.activation_func = openai_gelu
         elif args.onnx_safe:
             self.activation_func = erf_gelu
+        elif args.swiglu:
+            def swiglu(x):
+                x = torch.chunk(x, 2, dim=-1)
+                return F.silu(x[0]) * x[1]
+            self.activation_func = swiglu
+        elif args.squared_relu:
+            def squared_relu(x):
+                return torch.pow(F.relu(x), 2)
+            self.activation_func = squared_relu
+        else:
+            self.bias_gelu_fusion = args.bias_gelu_fusion
+            self.activation_func = F.gelu
 
         # Project back to h.
         self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
             args.ffn_hidden_size,
             args.hidden_size,
+            bias=self.add_bias,
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True,
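As a quick check on the two new activation functions, a standalone sketch; the tensor shapes and plain nn.Linear layers below are illustrative stand-ins, not part of the commit:

# swiglu splits the doubled-width projection in two and gates one half with SiLU;
# squared_relu simply squares a ReLU. Shapes are made up for illustration.
import torch
import torch.nn.functional as F

hidden, ffn = 8, 32
x = torch.randn(3, hidden)

gated = torch.nn.Linear(hidden, 2 * ffn, bias=False)(x)   # h -> 2*ffn, as with --swiglu
a, b = torch.chunk(gated, 2, dim=-1)
swiglu_out = F.silu(a) * b                                 # -> (3, ffn)

plain = torch.nn.Linear(hidden, ffn, bias=False)(x)        # h -> ffn, as with --squared-relu
squared_relu_out = torch.pow(F.relu(plain), 2)             # -> (3, ffn)

print(swiglu_out.shape, squared_relu_out.shape)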
@@ -119,11 +136,13 @@ class ParallelMLP(MegatronModule):
         intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
 
         if self.bias_gelu_fusion:
-            intermediate_parallel = \
-                    bias_gelu_impl(intermediate_parallel, bias_parallel)
+            assert self.add_bias is True
+            assert self.activation_func == F.gelu
+            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
         else:
-            intermediate_parallel = \
-                self.activation_func(intermediate_parallel + bias_parallel)
+            if self.add_bias:
+                intermediate_parallel = intermediate_parallel + bias_parallel
+            intermediate_parallel = self.activation_func(intermediate_parallel)
 
         # [s, b, h]
         output, output_bias = self.dense_4h_to_h(intermediate_parallel)
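A small sketch of the non-fused branch above when bias is disabled: with bias turned off, the preceding linear hands back a None bias, so the activation is applied directly. Names and shapes here are illustrative:

# With add_bias False, bias_parallel is None and is never added.
import torch
import torch.nn.functional as F

add_bias = False
intermediate_parallel = torch.randn(4, 16)
bias_parallel = None

if add_bias:
    intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = F.gelu(intermediate_parallel)   # activation without a bias add
print(intermediate_parallel.shape)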
@@ -401,6 +420,7 @@ class ParallelAttention(MegatronModule):
             self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 3 * projection_size,
+                bias=args.add_bias_linear,
                 gather_output=False,
                 init_method=init_method,
                 async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
@@ -410,6 +430,7 @@ class ParallelAttention(MegatronModule):
             self.query = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 projection_size,
+                bias=args.add_bias_linear,
                 gather_output=False,
                 init_method=init_method,
                 async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
@@ -419,6 +440,7 @@ class ParallelAttention(MegatronModule):
             self.key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 2 * projection_size,
+                bias=args.add_bias_linear,
                 gather_output=False,
                 init_method=init_method,
                 async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
@@ -437,6 +459,7 @@ class ParallelAttention(MegatronModule):
         self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
             args.hidden_size,
+            bias=args.add_bias_linear,
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True,
@@ -584,7 +607,9 @@ class ParallelAttention(MegatronModule):
 def bias_dropout_add(x, bias, residual, prob, training):
     # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
-    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
+    if bias is not None:
+        x = x + bias
+    out = torch.nn.functional.dropout(x, p=prob, training=training)
     out = residual + out
     return out
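To illustrate the now bias-optional helper, a standalone copy of bias_dropout_add with a toy call for each case; shapes and dropout settings are illustrative:

# Same logic as the updated helper above, exercised with and without a bias.
import torch

def bias_dropout_add(x, bias, residual, prob, training):
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out

x = torch.randn(2, 3)
residual = torch.zeros(2, 3)
print(bias_dropout_add(x, None, residual, prob=0.0, training=False))           # no bias
print(bias_dropout_add(x, torch.ones(3), residual, prob=0.0, training=False))  # with bias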
@@ -649,7 +674,7 @@ class ParallelTransformerLayer(MegatronModule):
             attention_type=AttnType.self_attn,
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
-        self.bias_dropout_fusion = args.bias_dropout_fusion
+        self.bias_dropout_fusion = args.bias_dropout_fusion and args.add_bias_linear
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
 
         # Layernorm on the attention output
@@ -718,10 +743,12 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             bias_dropout_add_func = get_bias_dropout_add(self.training)
 
+        if attention_bias is not None:
+            attention_bias = attention_bias.expand_as(residual)
         with self.bias_dropout_add_exec_handler():
             layernorm_input = bias_dropout_add_func(
                 attention_output,
-                attention_bias.expand_as(residual),
+                attention_bias,
                 residual,
                 self.hidden_dropout)
         else:
@@ -744,10 +771,13 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input
 
+        if attention_bias is not None:
+            attention_bias = attention_bias.expand_as(residual)
+
         with self.bias_dropout_add_exec_handler():
             layernorm_input = bias_dropout_add_func(
                 attention_output,
-                attention_bias.expand_as(residual),
+                attention_bias,
                 residual,
                 self.hidden_dropout)
@@ -764,10 +794,12 @@ class ParallelTransformerLayer(MegatronModule):
             residual = layernorm_input
 
         if self.drop_path is None:
+            if mlp_bias is not None:
+                mlp_bias = mlp_bias.expand_as(residual)
             with self.bias_dropout_add_exec_handler():
                 output = bias_dropout_add_func(
                     mlp_output,
-                    mlp_bias.expand_as(residual),
+                    mlp_bias,
                     residual,
                     self.hidden_dropout)
@@ -782,7 +814,9 @@ class ParallelTransformerLayer(MegatronModule):
                                                  keep_graph=True)
         else:
-            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
+            if mlp_bias is not None:
+                mlp_output = mlp_output + mlp_bias
+            out = torch.nn.functional.dropout(mlp_output,
                                               p=self.hidden_dropout,
                                               training=self.training)
             output = residual + self.drop_path(out)