OpenDAS / Megatron-LM / Commits / 53931b8b

Commit 53931b8b authored Jan 27, 2022 by Vijay Korthikanti

    address review comments

Parent: 8acbbe25

Showing 3 changed files with 24 additions and 25 deletions:

    megatron/arguments.py           +10  -10
    megatron/model/transformer.py   +11  -12
    megatron/training.py             +3   -3
megatron/arguments.py

@@ -246,14 +246,14 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.fp16 or args.bf16, \
             'residual connection in fp32 only supported when using fp16 or bf16.'
-    if args.wd_incr_style == 'constant':
-        assert args.start_wd is None
-        assert args.end_wd is None
-        args.start_wd = args.weight_decay
-        args.end_wd = args.weight_decay
+    if args.weight_decay_incr_style == 'constant':
+        assert args.start_weight_decay is None
+        assert args.end_weight_decay is None
+        args.start_weight_decay = args.weight_decay
+        args.end_weight_decay = args.weight_decay
     else:
-        assert args.start_wd is not None
-        assert args.end_wd is not None
+        assert args.start_weight_decay is not None
+        assert args.end_weight_decay is not None

     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])

@@ -404,11 +404,11 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
-    group.add_argument('--start-wd', type=float,
+    group.add_argument('--start-weight-decay', type=float,
                        help='Initial weight decay coefficient for L2 regularization.')
-    group.add_argument('--end-wd', type=float,
+    group.add_argument('--end-weight-decay', type=float,
                        help='End of run weight decay coefficient for L2 regularization.')
-    group.add_argument('--wd-incr-style', type=str, default='constant',
+    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                        choices=['constant', 'linear', 'cosine'],
                        help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
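Note: this hunk only spells out the abbreviated flag names (--start-wd, --end-wd, --wd-incr-style become --start-weight-decay, --end-weight-decay, --weight-decay-incr-style). A minimal, self-contained sketch of the renamed flags, mirroring the add_argument calls above with a throwaway parser; the numeric values are assumptions, not part of this diff:

    import argparse

    # Illustrative parser only; the real one lives in megatron/arguments.py.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight-decay', type=float, default=0.01)
    parser.add_argument('--start-weight-decay', type=float)
    parser.add_argument('--end-weight-decay', type=float)
    parser.add_argument('--weight-decay-incr-style', type=str, default='constant',
                        choices=['constant', 'linear', 'cosine'])

    # Example invocation (values assumed): argparse maps the dashed flags to the
    # underscored attributes that parse_args and training.py now read.
    args = parser.parse_args(['--weight-decay-incr-style', 'cosine',
                              '--start-weight-decay', '0.01',
                              '--end-weight-decay', '0.1'])
    print(args.weight_decay_incr_style, args.start_weight_decay, args.end_weight_decay)

Per the parse_args check above, --start-weight-decay and --end-weight-decay must both be supplied whenever --weight-decay-incr-style is not 'constant', and must be left unset when it is.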
megatron/model/transformer.py

@@ -49,20 +49,20 @@ class DropPath(MegatronModule):
     (when applied in main path of residual blocks).
     """

-    def __init__(self, drop_prob=None):
+    def __init__(self, drop_prob=0.):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob

-    def forward(self, x):
+    def forward(self, hidden_state):
         if self.drop_prob == 0. or not self.training:
-            return x
+            return hidden_state
         keep_prob = 1 - self.drop_prob
         # work with diff dim tensors, not just 2D ConvNets
-        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
         random_tensor = keep_prob + \
-            torch.rand(shape, dtype=x.dtype, device=x.device)
+            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
         random_tensor.floor_()  # binarize
-        output = x.div(keep_prob) * random_tensor
+        output = hidden_state.div(keep_prob) * random_tensor
         return output
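Note: for context on the class being renamed-in-place here, DropPath.forward implements stochastic depth: each slice along the leading dimension is either zeroed or kept and rescaled. A standalone sketch of the same arithmetic; the tensor shape and drop probability are illustrative assumptions:

    import torch

    # Illustrative values, not from this commit.
    drop_prob = 0.25
    keep_prob = 1 - drop_prob

    hidden_state = torch.randn(8, 16, 32)
    # One Bernoulli(keep_prob) draw per slice along the leading dimension,
    # broadcast over the remaining dimensions.
    shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=hidden_state.dtype,
                                           device=hidden_state.device)
    random_tensor.floor_()  # 1.0 with probability keep_prob, else 0.0
    output = hidden_state.div(keep_prob) * random_tensor
    # Kept slices are rescaled by 1/keep_prob so the expectation is unchanged;
    # dropped slices are zeroed, skipping the residual branch for them.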
@@ -437,7 +437,6 @@ class ParallelTransformerLayer(MegatronModule):
         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
         self.layer_type = layer_type
-        self.drop_path_rate = drop_path_rate

         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm

@@ -460,7 +459,7 @@ class ParallelTransformerLayer(MegatronModule):
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion
-        self.drop_path = DropPath(drop_path_rate)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(

@@ -504,7 +503,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = hidden_states

-        if self.drop_path_rate == 0.0:
+        if self.drop_path is None:
             # jit scripting for a nn.module (with dropout) is not
             # trigerring the fusion kernel. For now, we use two
             # different nn.functional routines to account for varying

@@ -564,7 +563,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input

-        if self.drop_path_rate == 0.0:
+        if self.drop_path is None:
             # re-enable torch grad to enable fused optimization.
             with torch.enable_grad():
                 output = bias_dropout_add_func(

@@ -608,7 +607,7 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)

-        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, args.num_layers)]
+        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

         # Transformer layers.
         def build_layer(layer_number):

@@ -618,7 +617,7 @@ class ParallelTransformer(MegatronModule):
                 layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type,
-                drop_path_rate=self.drop_path_rates[layer_number - 1])
+                drop_path_rate=self.drop_path_rates[layer_number - 1])

         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
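Note: self.drop_path_rates (formerly self.dpr) gives each transformer layer a drop-path probability that grows linearly from 0 at the first layer to drop_path_rate at the last. A small sketch of the schedule; the layer count and top rate are illustrative assumptions:

    import torch

    # Assumed values: 8 layers and a maximum rate of 0.2 are not from this commit.
    num_layers = 8
    drop_path_rate = 0.2

    drop_path_rates = [rate.item() for rate in
                       torch.linspace(0, drop_path_rate, num_layers)]
    print([round(r, 3) for r in drop_path_rates])
    # [0.0, 0.029, 0.057, 0.086, 0.114, 0.143, 0.171, 0.2]

Combined with the constructor change above, a layer whose rate is exactly 0.0 (the first layer here) now gets self.drop_path = None and keeps taking the fused bias-dropout-add path.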
megatron/training.py

@@ -341,9 +341,9 @@ def get_learning_rate_scheduler(optimizer):
         warmup_steps=warmup_steps,
         decay_steps=decay_steps,
         decay_style=args.lr_decay_style,
-        start_wd=args.start_wd,
-        end_wd=args.end_wd,
-        wd_incr_style=args.wd_incr_style,
+        start_wd=args.start_weight_decay,
+        end_wd=args.end_weight_decay,
+        wd_incr_style=args.weight_decay_incr_style,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)
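Note: only the argparse attribute names change here; the scheduler's keyword arguments (start_wd, end_wd, wd_incr_style) are untouched. The scheduler itself is not part of this diff, so the following is only an assumed sketch of what a linear or cosine weight-decay increment between start_wd and end_wd could look like over training, not the repository's implementation:

    import math

    def weight_decay_at(step, total_steps, start_wd, end_wd, wd_incr_style):
        # Assumed illustration of the three --weight-decay-incr-style choices.
        if wd_incr_style == 'constant':
            # With 'constant', parse_args sets start_wd == end_wd == weight_decay.
            return end_wd
        frac = min(step, total_steps) / total_steps
        if wd_incr_style == 'linear':
            coeff = frac
        elif wd_incr_style == 'cosine':
            coeff = 0.5 * (1 - math.cos(math.pi * frac))
        else:
            raise ValueError(f'unknown increment style {wd_incr_style}')
        return start_wd + coeff * (end_wd - start_wd)

    # e.g. halfway through training with a linear schedule (values assumed):
    print(weight_decay_at(500, 1000, start_wd=0.01, end_wd=0.1,
                          wd_incr_style='linear'))  # 0.055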