OpenDAS / Megatron-LM

Commit 8acbbe25, authored Jan 26, 2022 by Vijay Korthikanti

    address review comments

Parent: 7e810e41
Showing 3 changed files with 14 additions and 4 deletions:

    megatron/arguments.py            +9  -4
    megatron/optimizer/__init__.py   +1  -0
    megatron/training.py             +4  -0
megatron/arguments.py
```diff
@@ -246,9 +246,14 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.fp16 or args.bf16, \
         'residual connection in fp32 only supported when using fp16 or bf16.'
+    if args.weight_decay is not None:
+        if args.wd_incr_style == 'constant':
+            assert args.start_wd is None
+            assert args.end_wd is None
+            args.start_wd = args.weight_decay
+            args.end_wd = args.weight_decay
+        else:
+            assert args.start_wd is not None
+            assert args.end_wd is not None
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
```
```diff
@@ -399,11 +404,11 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
-    group.add_argument('--start-wd', type=float, default=0.01,
+    group.add_argument('--start-wd', type=float,
                        help='Initial weight decay coefficient for L2 regularization.')
-    group.add_argument('--end-wd', type=float, default=0.01,
+    group.add_argument('--end-wd', type=float,
                        help='End of run weight decay coefficient for L2 regularization.')
-    group.add_argument('--wd-incr-style', type=str, default='linear',
+    group.add_argument('--wd-incr-style', type=str, default='constant',
                        choices=['constant', 'linear', 'cosine'],
                        help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
```
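Taken together, the two hunks mean that with the default `--wd-incr-style constant` a user sets only `--weight-decay` and `parse_args` copies it into both `args.start_wd` and `args.end_wd`, while `linear` and `cosine` require both endpoints explicitly (the defaults were dropped so the `is None` checks can fire). A minimal sketch of how such an increment schedule could be evaluated over training; the helper name `get_weight_decay`, the iteration-based interpolation, and the half-cosine ramp are assumptions for illustration, not code from this commit:

```python
import math

def get_weight_decay(start_wd, end_wd, incr_style, iteration, total_iters):
    """Hypothetical helper: interpolate weight decay from start_wd to
    end_wd over training, mirroring the --wd-incr-style choices above."""
    if incr_style == 'constant':
        # parse_args copied --weight-decay into both endpoints in this case.
        assert start_wd == end_wd
        return start_wd
    frac = min(iteration / total_iters, 1.0)
    if incr_style == 'linear':
        return start_wd + (end_wd - start_wd) * frac
    if incr_style == 'cosine':
        # half-cosine ramp from start_wd to end_wd
        return start_wd + (end_wd - start_wd) * 0.5 * (1.0 - math.cos(math.pi * frac))
    raise ValueError(f'unknown weight decay increment style: {incr_style}')
```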
megatron/optimizer/__init__.py
```diff
@@ -44,6 +44,7 @@ def get_param_groups(modules,
         if no_weight_decay_cond is not None:
             no_wd = no_weight_decay_cond(name, param)
         else:
+            # do not regularize biases nor Norm parameters
             no_wd = name.endswith(".bias") or len(param.shape) == 1
         if scale_lr_cond is not None:
```
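The newly commented default condition routes biases and 1-D tensors (LayerNorm weights and the like) into a no-weight-decay group. A self-contained sketch of that split, assuming a plain PyTorch model; the helper name `split_wd_groups` and the two-group layout are illustrative, and the real `get_param_groups` additionally handles `scale_lr_cond`:

```python
import torch

def split_wd_groups(model):
    """Partition parameters the way the condition above does: biases and
    1-D parameters (e.g. LayerNorm weight) get weight_decay 0."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        no_wd = name.endswith('.bias') or len(param.shape) == 1
        (no_decay if no_wd else decay).append(param)
    return [{'params': decay},
            {'params': no_decay, 'weight_decay': 0.0}]

# Illustrative usage: only the Linear weight ends up in the decayed group.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
opt = torch.optim.AdamW(split_wd_groups(model), lr=1e-3, weight_decay=0.01)
```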
megatron/training.py
```diff
@@ -87,6 +87,10 @@ def pretrain(train_valid_test_dataset_provider,
         the info we would like to monitor during training, for example
         `lm-loss: value`. We also require that this function add
         `batch generator` to the timers class.
+    process_non_loss_data_func: a function to post-process outputs of the
+        network. It can be used for dumping output tensors (e.g. images) to
+        tensorboard. It takes `collected data` (list of tensors),
+        `current iteration index` and `tensorboard writer` as arguments.
     extra_args_provider: a function that takes a parser and adds arguments
         to it. It is used for programs to add their own arguments.
     args_defaults: a dictionary from argument-name to argument-value. It
```
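The added docstring entry documents the `process_non_loss_data_func` hook. A hedged sketch of a callable matching that contract; only the three-argument signature comes from the docstring, while the image/histogram dispatch is an illustrative assumption:

```python
from torch.utils.tensorboard import SummaryWriter  # writer passed in by the caller

def process_non_loss_data_func(collected_data, iteration, writer):
    """Example hook matching the documented signature: dump the collected
    output tensors to tensorboard at the current iteration."""
    for idx, tensor in enumerate(collected_data):
        if tensor.dim() == 4:  # assume NCHW image batches
            writer.add_images(f'output_{idx}', tensor, iteration)
        else:
            writer.add_histogram(f'output_{idx}', tensor, iteration)
```

Per the docstring, such a function would be passed to `pretrain(...)` as the `process_non_loss_data_func` argument.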