OpenDAS / Megatron-LM · Commits

Commit fac6718a, authored Dec 02, 2020 by Jared Casper

Merge branch 'blendable_dataset' into 'megatron_sampler'

Blendable dataset

See merge request ADLR/megatron-lm!178

Parents: cebd3b8b, 1eda0a17

Showing 8 changed files with 355 additions and 51 deletions (+355, -51)
megatron/arguments.py                  +11  -6
megatron/data/blendable_dataset.py     +75  -0
megatron/data/dataset_utils.py         +74  -0
megatron/data/gpt2_dataset.py          +42  -0
megatron/data/helpers.cpp              +64  -0
megatron/data/realm_dataset_utils.py   +2   -1
megatron/learning_rates.py             +82  -39
megatron/training.py                   +5   -5
megatron/arguments.py

@@ -136,14 +136,16 @@ def parse_args(extra_args_provider=None, defaults={},
 def _print_args(args):
     """Print arguments."""
     if args.rank == 0:
-        print('-------------------- arguments --------------------', flush=True)
+        print('------------------------ arguments ------------------------', flush=True)
         str_list = []
         for arg in vars(args):
-            dots = '.' * (32 - len(arg))
+            dots = '.' * (48 - len(arg))
             str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
         for arg in sorted(str_list, key=lambda x: x.lower()):
             print(arg, flush=True)
-        print('---------------- end of arguments ----------------', flush=True)
+        print('-------------------- end of arguments ---------------------', flush=True)

 def _check_arg_is_not_none(args, arg):

@@ -278,7 +280,7 @@ def _add_learning_rate_args(parser):
                        'and initial warmup, the learning rate at each '
                        'iteration would be different.')
     group.add_argument('--lr-decay-style', type=str, default='linear',
-                       choices=['constant', 'linear', 'cosine', 'exponential'],
+                       choices=['constant', 'linear', 'cosine'],
                        help='Learning rate decay function.')
     group.add_argument('--lr-decay-iters', type=int, default=None,
                        help='number of iterations to decay learning rate over,'

@@ -400,8 +402,11 @@ def _add_validation_args(parser):
 def _add_data_args(parser):
     group = parser.add_argument_group(title='data and dataloader')

-    group.add_argument('--data-path', type=str, default=None,
-                       help='Path to combined dataset to split.')
+    group.add_argument('--data-path', nargs='*', default=None,
+                       help='Path to the training dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
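A minimal argparse sketch (not the Megatron parser itself; the corpus prefixes are hypothetical) shows how the new nargs='*' form of --data-path captures a weighted dataset list and how the alternating weight/prefix convention splits apart:

# Minimal sketch of the new `--data-path` format; prefixes are hypothetical.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data-path', nargs='*', default=None)

args = parser.parse_args(
    '--data-path 0.3 corpus-a_text_document 0.7 corpus-b_text_document'.split())

# Same convention as get_datasets_weights_and_num_samples below: even
# positions are weights, odd positions are dataset prefixes.
weights = [float(w) for w in args.data_path[0::2]]
prefixes = [p.strip() for p in args.data_path[1::2]]
print(weights, prefixes)
# [0.3, 0.7] ['corpus-a_text_document', 'corpus-b_text_document']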
megatron/data/blendable_dataset.py (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Blendable dataset."""

import time

import numpy as np
import torch

from megatron import print_rank_0
from megatron import mpu


class BlendableDataset(torch.utils.data.Dataset):

    def __init__(self, datasets, weights):

        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = 0
        for dataset in self.datasets:
            self.size += len(dataset)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        if torch.distributed.get_rank() == 0:
            from megatron.data.dataset_utils import compile_helper
            compile_helper()
        # Simple barrier
        tmp = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(tmp, group=mpu.get_data_parallel_group())
        from megatron.data import helpers
        helpers.build_blending_indices(self.dataset_index,
                                       self.dataset_sample_index,
                                       weights, num_datasets, self.size,
                                       torch.distributed.get_rank() == 0)
        print_rank_0('> elapsed time for building blendable dataset indices: '
                     '{:.2f} (sec)'.format(time.time() - start_time))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return self.datasets[dataset_idx][sample_idx]
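To make the indexing contract concrete, here is a toy illustration (an assumption-laden sketch: plain Python lists stand in for Megatron datasets, and the two index arrays are hard-coded to what the blending helper would produce for two equally weighted three-sample datasets):

# Toy sketch of the BlendableDataset lookup: dataset_index[i] picks the
# child dataset for global sample i, dataset_sample_index[i] picks the
# position inside that child. Lists stand in for real Megatron datasets.
import numpy as np

datasets = [['a0', 'a1', 'a2'], ['b0', 'b1', 'b2']]
dataset_index = np.array([0, 1, 0, 1, 0, 1], dtype=np.uint8)
dataset_sample_index = np.array([0, 0, 1, 1, 2, 2], dtype=np.int64)

blended = [datasets[d][s] for d, s in zip(dataset_index, dataset_sample_index)]
print(blended)  # ['a0', 'b0', 'a1', 'b1', 'a2', 'b2'] for 50/50 weights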
megatron/data/dataset_utils.py

@@ -18,11 +18,13 @@
 # https://github.com/google-research/albert/blob/master/create_pretraining_data.py
 # with some modifications.

+import math
 import time
 import collections

 import numpy as np

 from megatron import get_args, print_rank_0
+from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

 DSET_TYPE_STD = 'standard_bert'

@@ -31,6 +33,38 @@ DSET_TYPE_ICT = 'ict'
 DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]


+def get_datasets_weights_and_num_samples(data_prefix,
+                                         train_valid_test_num_samples):
+
+    # The data prefix should be in the format of:
+    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
+    assert len(data_prefix) % 2 == 0
+    num_datasets = len(data_prefix) // 2
+    weights = [0] * num_datasets
+    prefixes = [0] * num_datasets
+    for i in range(num_datasets):
+        weights[i] = float(data_prefix[2 * i])
+        prefixes[i] = (data_prefix[2 * i + 1]).strip()
+    # Normalize weights
+    weight_sum = 0.0
+    for weight in weights:
+        weight_sum += weight
+    assert weight_sum > 0.0
+    weights = [weight / weight_sum for weight in weights]
+
+    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
+    # not uniformly distribute the number of samples, we still have
+    # samples left to feed to the network.
+    datasets_train_valid_test_num_samples = []
+    for weight in weights:
+        datasets_train_valid_test_num_samples.append(
+            [int(math.ceil(val * weight * 1.005))
+             for val in train_valid_test_num_samples])
+
+    return prefixes, weights, datasets_train_valid_test_num_samples
+
+
 def compile_helper():
     """Compile helper function at runtime. Make sure this
     is invoked on a single process."""

@@ -360,6 +394,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     short_seq_prob, seed, skip_warmup,
                                     dataset_type='standard_bert'):

+    if len(data_prefix) == 1:
+        return _build_train_valid_test_datasets(data_prefix[0],
+                                                data_impl, splits_string,
+                                                train_valid_test_num_samples,
+                                                max_seq_length, masked_lm_prob,
+                                                short_seq_prob, seed,
+                                                skip_warmup,
+                                                dataset_type=dataset_type)
+    # Blending dataset.
+    # Parse the values.
+    output = get_datasets_weights_and_num_samples(data_prefix,
+                                                  train_valid_test_num_samples)
+    prefixes, weights, datasets_train_valid_test_num_samples = output
+
+    # Build individual datasets.
+    train_datasets = []
+    valid_datasets = []
+    test_datasets = []
+    for i in range(len(prefixes)):
+        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
+            prefixes[i], data_impl, splits_string,
+            datasets_train_valid_test_num_samples[i],
+            max_seq_length, masked_lm_prob, short_seq_prob,
+            seed, skip_warmup, dataset_type=dataset_type)
+        train_datasets.append(train_ds)
+        valid_datasets.append(valid_ds)
+        test_datasets.append(test_ds)
+
+    # Blend.
+    blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = BlendableDataset(test_datasets, weights)
+
+    return (blending_train_dataset, blending_valid_dataset,
+            blending_test_dataset)
+
+
+def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                     train_valid_test_num_samples,
+                                     max_seq_length, masked_lm_prob,
+                                     short_seq_prob, seed, skip_warmup,
+                                     dataset_type='standard_bert'):
+
     if dataset_type not in DSET_TYPES:
         raise ValueError("Invalid dataset_type: ", dataset_type)
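The weight normalization and the 1.005 oversampling factor can be checked with a self-contained sketch that reproduces the arithmetic of get_datasets_weights_and_num_samples outside Megatron (the corpus names and sample counts are hypothetical):

# Stand-alone reproduction of the get_datasets_weights_and_num_samples math;
# no Megatron imports, hypothetical inputs.
import math

data_prefix = ['3', 'corpus-a', '1', 'corpus-b']      # weight, prefix pairs
train_valid_test_num_samples = [1000, 100, 10]

weights = [float(w) for w in data_prefix[0::2]]
prefixes = [p.strip() for p in data_prefix[1::2]]
total = sum(weights)
weights = [w / total for w in weights]                # [0.75, 0.25]

# The 1.005 factor over-allocates ~0.5% per dataset so the blend never runs short.
per_dataset = [[int(math.ceil(n * w * 1.005)) for n in train_valid_test_num_samples]
               for w in weights]
print(per_dataset)  # [[754, 76, 8], [252, 26, 3]]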
megatron/data/gpt2_dataset.py

@@ -22,6 +22,8 @@ import numpy as np
 import torch

 from megatron import mpu, print_rank_0
+from megatron.data.blendable_dataset import BlendableDataset
+from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
 from megatron.data.dataset_utils import get_train_valid_test_split_
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

@@ -31,6 +33,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     seq_length, seed, skip_warmup):
     """Build train, valid, and test datasets."""

+    # Single dataset.
+    if len(data_prefix) == 1:
+        return _build_train_valid_test_datasets(data_prefix[0],
+                                                data_impl, splits_string,
+                                                train_valid_test_num_samples,
+                                                seq_length, seed, skip_warmup)
+
+    # Blending dataset.
+    # Parse the values.
+    output = get_datasets_weights_and_num_samples(data_prefix,
+                                                  train_valid_test_num_samples)
+    prefixes, weights, datasets_train_valid_test_num_samples = output
+
+    # Build individual datasets.
+    train_datasets = []
+    valid_datasets = []
+    test_datasets = []
+    for i in range(len(prefixes)):
+        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
+            prefixes[i], data_impl, splits_string,
+            datasets_train_valid_test_num_samples[i],
+            seq_length, seed, skip_warmup)
+        train_datasets.append(train_ds)
+        valid_datasets.append(valid_ds)
+        test_datasets.append(test_ds)
+
+    # Blend.
+    blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = BlendableDataset(test_datasets, weights)
+
+    return (blending_train_dataset, blending_valid_dataset,
+            blending_test_dataset)
+
+
+def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                     train_valid_test_num_samples,
+                                     seq_length, seed, skip_warmup):
+    """Build train, valid, and test datasets."""
+
     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
                                            data_impl,
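As a rough sketch of the new control flow (toy lists stand in for the per-prefix GPT2Dataset objects; the prefixes are hypothetical), each prefix is built once and then the train, valid, and test splits are blended separately while sharing one weight vector:

# Sketch only: mirrors the blending path in build_train_valid_test_datasets,
# with toy lists instead of real indexed datasets.
prefixes = ['corpus-a', 'corpus-b']      # hypothetical dataset prefixes
weights = [0.75, 0.25]

train_datasets, valid_datasets, test_datasets = [], [], []
for prefix in prefixes:
    # In Megatron, _build_train_valid_test_datasets(prefix, ...) returns
    # real (train, valid, test) datasets; toy lists stand in here.
    train_datasets.append(['%s-train-%d' % (prefix, i) for i in range(4)])
    valid_datasets.append(['%s-valid-%d' % (prefix, i) for i in range(2)])
    test_datasets.append(['%s-test-%d' % (prefix, i) for i in range(1)])

# Mirrors BlendableDataset(train_datasets, weights), etc.
for name, split in (('train', train_datasets), ('valid', valid_datasets),
                    ('test', test_datasets)):
    print(name, '-> blend of', [len(d) for d in split], 'samples, weights', weights)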
megatron/data/helpers.cpp

@@ -33,6 +33,69 @@ using namespace std;
 const int32_t LONG_SENTENCE_LEN = 512;


+void build_blending_indices(py::array_t<uint8_t>& dataset_index,
+                            py::array_t<int64_t>& dataset_sample_index,
+                            const py::array_t<double>& weights,
+                            const int32_t num_datasets,
+                            const int64_t size, const bool verbose) {
+  /* Given multiple datasets and a weighting array, build samples
+     such that it follows those weights.*/
+
+  if (verbose) {
+    std::cout << "> building indices for blendable datasets ..." << std::endl;
+  }
+
+  // Get the pointer access without the checks.
+  auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
+  auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
+  auto weights_ptr = weights.unchecked<1>();
+
+  // Initialize buffer for number of samples used for each dataset.
+  int64_t current_samples[num_datasets];
+  for (int64_t i = 0; i < num_datasets; ++i) {
+    current_samples[i] = 0;
+  }
+
+  // For each sample:
+  for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
+
+    // Determine where the max error in sampling is happening.
+    auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
+    int64_t max_error_index = 0;
+    double max_error = weights_ptr[0] * sample_idx_double -
+      static_cast<double>(current_samples[0]);
+    for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
+      double error = weights_ptr[dataset_idx] * sample_idx_double -
+        static_cast<double>(current_samples[dataset_idx]);
+      if (error > max_error) {
+        max_error = error;
+        max_error_index = dataset_idx;
+      }
+    }
+
+    // Populate the indices.
+    dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
+    dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
+
+    // Update the total samples.
+    current_samples[max_error_index] += 1;
+  }
+
+  // print info
+  if (verbose) {
+    std::cout << " > sample ratios:" << std::endl;
+    for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
+      auto ratio = static_cast<double>(current_samples[dataset_idx]) /
+        static_cast<double>(size);
+      std::cout << "   dataset " << dataset_idx << ", input: " <<
+        weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
+    }
+  }
+}
+
+
 py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                            const py::array_t<int32_t>& doc_idx_,
                            const int32_t seq_length,

@@ -640,4 +703,5 @@ PYBIND11_MODULE(helpers, m) {
     m.def("build_mapping", &build_mapping);
     m.def("build_blocks_mapping", &build_blocks_mapping);
     m.def("build_sample_idx", &build_sample_idx);
+    m.def("build_blending_indices", &build_blending_indices);
 }
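For readers who do not want to compile the extension just to inspect its behavior, the following NumPy reference mirrors the greedy "largest sampling error" rule above; it is a sketch, not the compiled helper that Megatron actually calls:

# NumPy reference of build_blending_indices: at each step, pick the dataset
# whose achieved sample count lags its target (weight * step) the most.
import numpy as np

def build_blending_indices_py(weights, size):
    weights = np.asarray(weights, dtype=np.float64)
    num_datasets = len(weights)
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    current_samples = np.zeros(num_datasets, dtype=np.int64)
    for sample_idx in range(size):
        step = max(float(sample_idx), 1.0)
        errors = weights * step - current_samples
        chosen = int(np.argmax(errors))   # first max wins, as in the C++ loop
        dataset_index[sample_idx] = chosen
        dataset_sample_index[sample_idx] = current_samples[chosen]
        current_samples[chosen] += 1
    return dataset_index, dataset_sample_index

idx, _ = build_blending_indices_py([0.75, 0.25], 1000)
print(np.bincount(idx) / 1000.0)   # achieved ratios, close to [0.75, 0.25]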
megatron/data/realm_dataset_utils.py

@@ -6,7 +6,6 @@ import torch
 from megatron import mpu, print_rank_0
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron.data.samplers import DistributedBatchSampler
 from megatron import get_args, get_tokenizer, print_rank_0, mpu

@@ -23,6 +22,8 @@ def get_one_epoch_dataloader(dataset, batch_size=None):
     sampler = torch.utils.data.SequentialSampler(dataset)
     # importantly, drop_last must be False to get all the data.
+    assert False, 'DistributedBatchSampler deprecated, change the implementation'
+    from megatron.data.samplers import DistributedBatchSampler
     batch_sampler = DistributedBatchSampler(sampler,
                                             batch_size=global_batch_size,
                                             drop_last=False,
megatron/learning_rates.py

@@ -19,77 +19,101 @@ import math
 from megatron import print_rank_0


 class AnnealingLR(object):
     """Anneals the learning rate."""

-    def __init__(self, optimizer, start_lr,
-                 warmup_iter, total_iters,
-                 decay_style, last_iter, min_lr=0.0,
+    def __init__(self, optimizer, max_lr, min_lr,
+                 warmup_steps, decay_steps,
+                 decay_style, num_steps,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):

         # Class values.
         self.optimizer = optimizer
-        self.start_lr = start_lr
+
+        self.max_lr = float(max_lr)
         self.min_lr = min_lr
-        self.warmup_iter = warmup_iter
-        self.num_iters = last_iter
-        self.end_iter = total_iters
-        assert self.end_iter > 0
+        assert self.min_lr >= 0.0
+        assert self.max_lr >= self.min_lr
+
+        self.warmup_steps = warmup_steps
+        self.num_steps = num_steps
+        self.decay_steps = decay_steps
+        assert self.decay_steps > 0
+        assert self.warmup_steps < self.decay_steps
+
         self.decay_style = decay_style
+
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         if self.override_lr_scheduler:
             assert not self.use_checkpoint_lr_scheduler, 'both override and ' \
                 'use-checkpoint are set.'
+
         # Set the learning rate
-        self.step(self.num_iters)
+        self.step(step_num=self.num_steps)

         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))

     def get_lr(self):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

-        num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
-        # Warmup.
-        if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
-            return float(self.start_lr) * num_iters_ / self.warmup_iter
-
-        num_iters_ = num_iters_ - self.warmup_iter
-        if self.decay_style == 'linear':
-            lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter
-        elif self.decay_style == 'cosine':
-            lr = self.start_lr / 2.0 * (math.cos(
-                math.pi * num_iters_ / self.end_iter) + 1)
-        elif self.decay_style == 'exponential':
-            # exp(-0.693) = 1/2
-            lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
-        else:
-            lr = self.start_lr
-        return max(lr, self.min_lr)
+        # Use linear warmup for the initial part.
+        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
+            return self.max_lr * float(self.num_steps) / \
+                float(self.warmup_steps)
+
+        # If the learning rate is constant, just return the initial value.
+        if self.decay_style == 'constant':
+            return self.max_lr
+
+        # For any steps larger than `self.decay_steps`, use `self.min_lr`.
+        if self.num_steps > self.decay_steps:
+            return self.min_lr
+
+        # If we are done with the warmup period, use the decay style.
+        num_steps_ = self.num_steps - self.warmup_steps
+        decay_steps_ = self.decay_steps - self.warmup_steps
+        decay_ratio = float(num_steps_) / float(decay_steps_)
+        assert decay_ratio >= 0.0
+        assert decay_ratio <= 1.0
+        delta_lr = self.max_lr - self.min_lr
+
+        if self.decay_style == 'linear':
+            coeff = (1.0 - decay_ratio)
+        elif self.decay_style == 'cosine':
+            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
+        else:
+            raise Exception('{} decay style is not supported.'.format(
+                self.decay_style))
+
+        return self.min_lr + coeff * delta_lr

-    def step(self, step_num=None):
+    def step(self, increment=1, step_num=None):
         """Set lr for all parameters groups."""
         if step_num is None:
-            step_num = self.num_iters + 1
-        self.num_iters = step_num
+            step_num = self.num_steps + increment
+        self.num_steps = step_num
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr

     def state_dict(self):
         state_dict = {
-            'start_lr': self.start_lr,
-            'warmup_iter': self.warmup_iter,
-            'num_iters': self.num_iters,
+            'max_lr': self.max_lr,
+            'warmup_steps': self.warmup_steps,
+            'num_steps': self.num_steps,
             'decay_style': self.decay_style,
-            'end_iter': self.end_iter,
+            'decay_steps': self.decay_steps,
             'min_lr': self.min_lr
         }
         return state_dict

     def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""

@@ -104,20 +128,39 @@ class AnnealingLR(object):
                              name))
             return sd_value

     def load_state_dict(self, sd):

-        self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
-                                            'learning rate')
+        if 'start_lr' in sd:
+            max_lr_ = sd['start_lr']
+        else:
+            max_lr_ = sd['max_lr']
+        self.max_lr = self._check_and_set(self.max_lr, max_lr_,
+                                          'learning rate')
+
         self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                           'minimum learning rate')
-        self.warmup_iter = self._check_and_set(self.warmup_iter,
-                                               sd['warmup_iter'],
-                                               'warmup iterations')
-        self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'],
-                                            'total number of iterations')
+
+        if 'warmup_iter' in sd:
+            warmup_steps_ = sd['warmup_iter']
+        else:
+            warmup_steps_ = sd['warmup_steps']
+        self.warmup_steps = self._check_and_set(self.warmup_steps,
+                                                warmup_steps_,
+                                                'warmup iterations')
+
+        if 'end_iter' in sd:
+            decay_steps_ = sd['end_iter']
+        else:
+            decay_steps_ = sd['decay_steps']
+        self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
+                                               'total number of iterations')
+
         self.decay_style = self._check_and_set(self.decay_style,
                                                sd['decay_style'],
                                                'decay style')

-        self.num_iters = sd['num_iters']
-        self.step(self.num_iters)
+        if 'num_iters' in sd:
+            self.num_steps = sd['num_iters']
+        else:
+            self.num_steps = sd['num_steps']
+        self.step(step_num=self.num_steps)
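The new schedule can be summarized as a pure function; the sketch below reimplements the decay math from get_lr outside the class (the parameter values in the example are illustrative, not Megatron defaults):

# Pure-function sketch of the new AnnealingLR.get_lr behavior.
import math

def annealed_lr(num_steps, max_lr, min_lr, warmup_steps, decay_steps,
                decay_style='cosine'):
    # Linear warmup from 0 up to max_lr.
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return max_lr * float(num_steps) / float(warmup_steps)
    if decay_style == 'constant':
        return max_lr
    # Past the decay window, hold at min_lr.
    if num_steps > decay_steps:
        return min_lr
    # Decay between max_lr and min_lr over (decay_steps - warmup_steps).
    decay_ratio = float(num_steps - warmup_steps) / float(decay_steps - warmup_steps)
    if decay_style == 'linear':
        coeff = 1.0 - decay_ratio
    elif decay_style == 'cosine':
        coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    else:
        raise Exception('{} decay style is not supported.'.format(decay_style))
    return min_lr + coeff * (max_lr - min_lr)

# Example: 3e-4 peak, 3e-5 floor, 2k warmup steps, 100k decay steps.
for step in (0, 1000, 2000, 50000, 100000, 150000):
    print(step, annealed_lr(step, 3e-4, 3e-5, 2000, 100000))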
megatron/training.py

@@ -194,12 +194,12 @@ def get_learning_rate_scheduler(optimizer):
     warmup_iter = args.warmup * num_iters
     lr_scheduler = AnnealingLR(
         optimizer,
-        start_lr=args.lr,
-        warmup_iter=warmup_iter,
-        total_iters=num_iters,
-        decay_style=args.lr_decay_style,
-        last_iter=init_step,
-        min_lr=args.min_lr,
+        max_lr=args.lr,
+        min_lr=args.min_lr,
+        warmup_steps=warmup_iter,
+        decay_steps=num_iters,
+        decay_style=args.lr_decay_style,
+        num_steps=init_step,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)