OpenDAS / Megatron-LM / Commits / 00ac56ab

Commit 00ac56ab
Authored Dec 09, 2020 by mohammad; committed by Deepak Narayanan, Dec 19, 2020
Parent: 22ab91bb

Address Jared's comments

Showing 3 changed files with 158 additions and 132 deletions:

    megatron/arguments.py      +4    -2
    megatron/global_vars.py    +3    -130
    megatron/microbatches.py   +151  -0
megatron/arguments.py

...
@@ -264,9 +264,11 @@ def _add_training_args(parser):
     group.add_argument('--micro-batch-size', type=int, default=None,
                        help='Batch size per model instance (local batch size). '
                        'Global batch size is local batch size times data '
-                       'parallel size.')
+                       'parallel size times number of micro batches.')
     group.add_argument('--global-batch-size', type=int, default=None,
-                       help='Training batch size. If this value is None, then '
+                       help='Training batch size. If set, it should be a '
+                       'multiple of micro-batch-size times data-parallel-size. '
+                       'If this value is None, then '
                        'use micro-batch-size * data-parallel-size as the '
                        'global batch size. This choice will result in 1 for '
                        'number of micro-batches.')
...
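For concreteness, here is a minimal sketch (not part of the commit; the flag values are purely illustrative) of the arithmetic the updated help text describes:

    # Hypothetical flag values, chosen only to illustrate the help text above.
    micro_batch_size = 2       # --micro-batch-size
    data_parallel_size = 8     # number of data-parallel replicas
    global_batch_size = 64     # --global-batch-size

    # The global batch size must be a multiple of micro-batch-size times
    # data-parallel-size; the quotient is the number of micro-batches.
    assert global_batch_size % (micro_batch_size * data_parallel_size) == 0
    num_micro_batches = global_batch_size // (micro_batch_size * data_parallel_size)
    print(num_micro_batches)   # 4

If --global-batch-size is left at its default of None, it falls back to micro-batch-size * data-parallel-size, which yields exactly 1 micro-batch, as the help text notes.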
megatron/global_vars.py

...
@@ -15,18 +15,15 @@
 """Megatron global variables."""

-from abc import ABC
-from abc import abstractmethod
-import math
 import os
 import sys
 import time

-import numpy as np
 import torch

 from megatron.tokenizer import build_tokenizer
 from .arguments import parse_args
+from .microbatches import build_num_microbatches_calculator

 _GLOBAL_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
...
@@ -104,132 +101,8 @@ def _build_num_microbatches_calculator(args):
     _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                                    'num microbatches calculator')

-    # Constant num micro-batches.
-    if args.rampup_batch_size is None:
-        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
-            args.global_batch_size, args.micro_batch_size,
-            args.data_parallel_size)
-        if args.rank == 0:
-            print('setting number of micro-batches to constant {}'.format(
-                _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()), flush=True)
-    else:
-        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
-            'format: --rampup-batch-size <start batch size> ' \
-            '<batch size increment> <ramp-up samples>'
-        start_batch_size = int(args.rampup_batch_size[0])
-        batch_size_increment = int(args.rampup_batch_size[1])
-        ramup_samples = int(args.rampup_batch_size[2])
-        if args.rank == 0:
-            print('will use batch size rampup starting from global batch '
-                  'size {} to global batch size {} with batch size increments '
-                  '{} over {} samples.'.format(start_batch_size,
-                                               args.global_batch_size,
-                                               batch_size_increment,
-                                               ramup_samples), flush=True)
-        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = RampupBatchsizeNumMicroBatches(
-            start_batch_size, batch_size_increment, ramup_samples,
-            args.global_batch_size, args.micro_batch_size,
-            args.data_parallel_size)
-
-
-class NumMicroBatchesCalculator(ABC):
-    ...
-
-class ConstantNumMicroBatches(NumMicroBatchesCalculator):
-    ...
-
-class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
-    ...
-    (the removed class bodies are identical to those added in the new
-    megatron/microbatches.py, reproduced in full below)
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
+        args)


 def _build_tokenizer(args):
...
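Since the construction logic now lives in its own module, it can be exercised in isolation. A minimal sketch (not part of the commit), with types.SimpleNamespace standing in for the namespace that parse_args() would normally return:

    from types import SimpleNamespace

    from megatron.microbatches import build_num_microbatches_calculator

    # Hypothetical argument values, for illustration only.
    args = SimpleNamespace(rank=0, rampup_batch_size=None,
                           global_batch_size=64, micro_batch_size=2,
                           data_parallel_size=8)

    calculator = build_num_microbatches_calculator(args)
    print(calculator.get())   # 64 // (2 * 8) = 4 micro-batches

global_vars.py still owns the singleton; only the construction logic and the calculator classes move into the new module.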
megatron/microbatches.py  0 → 100644  (new file)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron number of micro-batches calculators."""

from abc import ABC
from abc import abstractmethod


def build_num_microbatches_calculator(args):

    # Constant num micro-batches.
    if args.rampup_batch_size is None:
        num_microbatches_calculator = ConstantNumMicroBatches(
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)
        if args.rank == 0:
            print('setting number of micro-batches to constant {}'.format(
                num_microbatches_calculator.get()), flush=True)

    # Batch size ramp-up: --rampup-batch-size <start> <increment> <samples>.
    else:
        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
            'format: --rampup-batch-size <start batch size> ' \
            '<batch size increment> <ramp-up samples>'
        start_batch_size = int(args.rampup_batch_size[0])
        batch_size_increment = int(args.rampup_batch_size[1])
        ramup_samples = int(args.rampup_batch_size[2])
        if args.rank == 0:
            print('will use batch size rampup starting from global batch '
                  'size {} to global batch size {} with batch size increments '
                  '{} over {} samples.'.format(start_batch_size,
                                               args.global_batch_size,
                                               batch_size_increment,
                                               ramup_samples), flush=True)
        num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
            start_batch_size, batch_size_increment, ramup_samples,
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)

    return num_microbatches_calculator


class NumMicroBatchesCalculator(ABC):

    def __init__(self):
        self.num_micro_batches = None

    def get(self):
        return self.num_micro_batches

    @abstractmethod
    def update(self, consumed_samples):
        pass


class ConstantNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, global_batch_size, micro_batch_size,
                 data_parallel_size):
        micro_batch_times_data_parallel = micro_batch_size * \
            data_parallel_size
        assert global_batch_size % micro_batch_times_data_parallel == 0, \
            'global batch size ({}) is not divisible by micro batch size ({})' \
            ' times data parallel size ({})'.format(global_batch_size,
                                                    micro_batch_size,
                                                    data_parallel_size)
        self.num_micro_batches = global_batch_size // \
            micro_batch_times_data_parallel
        assert self.num_micro_batches >= 1

    def update(self, consumed_samples):
        pass


class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
                 global_batch_size, micro_batch_size, data_parallel_size):
        """Batch size ramp-up.

        Over
            steps = (global-batch-size - start-batch-size) / batch-size-increment
        increments, grow the batch size from start-batch-size to
        global-batch-size, using
            rampup-samples / steps
        samples for each increment.

        Arguments:
            start_batch_size: global batch size to start with.
            batch_size_increment: global batch size increment.
            ramup_samples: number of samples used to ramp up the global
                batch size from `start_batch_size` to `global_batch_size`.
            global_batch_size: global batch size post ramp-up.
            micro_batch_size: micro batch size.
            data_parallel_size: data parallel size.
        """

        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
            self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        assert start_batch_size > 0
        self.start_batch_size = start_batch_size

        assert global_batch_size > 0
        self.global_batch_size = global_batch_size
        diff_batch_size = self.global_batch_size - self.start_batch_size
        assert diff_batch_size >= 0
        assert batch_size_increment > 0
        self.batch_size_increment = batch_size_increment
        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
            'global batch size interval ({}) to be divisible by global batch ' \
            'size increment ({})'.format(diff_batch_size, batch_size_increment)

        num_increments = diff_batch_size // self.batch_size_increment
        self.ramup_samples = ramup_samples
        assert self.ramup_samples >= 0
        self.rampup_samples_per_increment = self.ramup_samples / num_increments

        # Initialize the number of micro-batches.
        self.update(0)

    def update(self, consumed_samples):
        # Past the ramp-up window, use the full global batch size.
        if consumed_samples > self.ramup_samples:
            current_global_batch_size = self.global_batch_size
        else:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            current_global_batch_size = self.start_batch_size + \
                steps * self.batch_size_increment
            assert current_global_batch_size <= self.global_batch_size

        assert current_global_batch_size % \
            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
            'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
            'data parallel size ({})'.format(current_global_batch_size,
                                             self.micro_batch_size,
                                             self.data_parallel_size)
        self.num_micro_batches = current_global_batch_size // \
            self.micro_batch_times_data_parallel_size
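To make the ramp-up schedule concrete, here is a small usage sketch (illustrative numbers, not from the commit). Starting at a global batch size of 32 and growing by 32 up to 256, there are (256 - 32) / 32 = 7 increments, so 700 ramp-up samples give 100 samples per increment:

    from megatron.microbatches import RampupBatchsizeNumMicroBatches

    calculator = RampupBatchsizeNumMicroBatches(
        start_batch_size=32, batch_size_increment=32, ramup_samples=700,
        global_batch_size=256, micro_batch_size=4, data_parallel_size=8)

    for consumed_samples in (0, 350, 701):
        calculator.update(consumed_samples)
        print(consumed_samples, calculator.get())
    # 0   -> 1 micro-batch   (global batch size 32)
    # 350 -> 4 micro-batches (32 + 3 * 32 = 128)
    # 701 -> 8 micro-batches (past ramp-up, global batch size 256)

Because update() floors consumed_samples / rampup_samples_per_increment, the global batch size rises in steps of batch_size_increment, and every intermediate value stays divisible by micro_batch_size * data_parallel_size.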