OpenDAS / Megatron-LM, commit 00ac56ab

Address Jared's comments

Authored Dec 09, 2020 by mohammad; committed by Deepak Narayanan on Dec 19, 2020.
Parent: 22ab91bb
Changes: 3 changed files with 158 additions and 132 deletions (+158, -132).
  megatron/arguments.py       +4    -2
  megatron/global_vars.py     +3    -130
  megatron/microbatches.py    +151  -0
megatron/arguments.py
@@ -264,9 +264,11 @@ def _add_training_args(parser):

     group.add_argument('--micro-batch-size', type=int, default=None,
                        help='Batch size per model instance (local batch size). '
                        'Global batch size is local batch size times data '
-                       'parallel size.')
+                       'parallel size times number of micro batches.')
     group.add_argument('--global-batch-size', type=int, default=None,
-                       help='Training batch size. If this value is None, then '
+                       help='Training batch size. If set, it should be a '
+                       'multiple of micro-batch-size times data-parallel-size. '
+                       'If this value is None, then '
                        'use micro-batch-size * data-parallel-size as the '
                        'global batch size. This choice will result in 1 for '
                        'number of micro-batches.')
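The revised help strings pin down the relationship between the two flags: the global batch size must be a multiple of micro-batch-size times data-parallel-size, and the quotient is the number of micro-batches. A minimal sketch of that arithmetic; the concrete values are illustrative and not part of the diff:

    # Hypothetical values, chosen only to illustrate the constraint in the help text.
    micro_batch_size = 4        # --micro-batch-size
    data_parallel_size = 8      # number of data-parallel replicas
    global_batch_size = 512     # --global-batch-size

    # --global-batch-size must be a multiple of micro-batch-size * data-parallel-size.
    assert global_batch_size % (micro_batch_size * data_parallel_size) == 0
    num_micro_batches = global_batch_size // (micro_batch_size * data_parallel_size)
    print(num_micro_batches)    # 16

    # If --global-batch-size is left as None, the fallback described in the help
    # text is micro_batch_size * data_parallel_size, i.e. exactly 1 micro-batch.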
megatron/global_vars.py
@@ -15,18 +15,15 @@

 """Megatron global variables."""

-from abc import ABC
-from abc import abstractmethod
-import math
 import os
 import sys
 import time
-import numpy as np
 import torch

 from megatron.tokenizer import build_tokenizer
 from .arguments import parse_args
+from .microbatches import build_num_microbatches_calculator

 _GLOBAL_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
@@ -104,132 +101,8 @@ def _build_num_microbatches_calculator(args):

     _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                                    'num microbatches calculator')

-    # Constant num micro-batches.
-    if args.rampup_batch_size is None:
-        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
-            args.global_batch_size, args.micro_batch_size,
-            args.data_parallel_size)
-        if args.rank == 0:
-            print('setting number of micro-batches to constant {}'.format(
-                _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()), flush=True)
-    else:
-        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
-            'format: --rampup-batch-size <start batch size> ' \
-            '<batch size incerement> <ramp-up samples>'
-        start_batch_size = int(args.rampup_batch_size[0])
-        batch_size_increment = int(args.rampup_batch_size[1])
-        ramup_samples = int(args.rampup_batch_size[2])
-        if args.rank == 0:
-            print('will use batch size rampup starting from global batch '
-                  'size {} to global batch size {} with batch size increments '
-                  '{} over {} samples.'.format(start_batch_size,
-                                               args.global_batch_size,
-                                               batch_size_increment,
-                                               ramup_samples), flush=True)
-        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = RampupBatchsizeNumMicroBatches(
-            start_batch_size, batch_size_increment, ramup_samples,
-            args.global_batch_size, args.micro_batch_size,
-            args.data_parallel_size)
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
+        args)
The remainder of the hunk deletes the NumMicroBatchesCalculator, ConstantNumMicroBatches, and RampupBatchsizeNumMicroBatches class definitions from global_vars.py; they move, unchanged, into the new megatron/microbatches.py shown below.
 def _build_tokenizer(args):
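With this refactor, global_vars.py keeps only the module-level singleton and delegates its construction to megatron.microbatches. A minimal sketch of how the calculator could be built and driven, assuming only the args fields this diff reads; the SimpleNamespace and the toy loop are illustrative and not part of the commit:

    from types import SimpleNamespace

    from megatron.microbatches import build_num_microbatches_calculator

    # Only the attributes the changed code actually reads are filled in here.
    args = SimpleNamespace(rampup_batch_size=None, global_batch_size=512,
                           micro_batch_size=4, data_parallel_size=8, rank=0)

    calculator = build_num_microbatches_calculator(args)

    consumed_samples = 0
    for _ in range(3):                           # stand-in for the training loop
        num_micro_batches = calculator.get()     # 16 for the constant case above
        consumed_samples += (num_micro_batches * args.micro_batch_size *
                             args.data_parallel_size)
        calculator.update(consumed_samples)      # a no-op for ConstantNumMicroBatches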
megatron/microbatches.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron number of micro-batches calculators."""

from abc import ABC
from abc import abstractmethod


def build_num_microbatches_calculator(args):

    # Constant num micro-batches.
    if args.rampup_batch_size is None:
        num_microbatches_calculator = ConstantNumMicroBatches(
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)
        if args.rank == 0:
            print('setting number of micro-batches to constant {}'.format(
                num_microbatches_calculator.get()), flush=True)

    else:
        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
            'format: --rampup-batch-size <start batch size> ' \
            '<batch size increment> <ramp-up samples>'
        start_batch_size = int(args.rampup_batch_size[0])
        batch_size_increment = int(args.rampup_batch_size[1])
        ramup_samples = int(args.rampup_batch_size[2])
        if args.rank == 0:
            print('will use batch size rampup starting from global batch '
                  'size {} to global batch size {} with batch size increments '
                  '{} over {} samples.'.format(start_batch_size,
                                               args.global_batch_size,
                                               batch_size_increment,
                                               ramup_samples), flush=True)
        num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
            start_batch_size, batch_size_increment, ramup_samples,
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)

    return num_microbatches_calculator


class NumMicroBatchesCalculator(ABC):

    def __init__(self):
        self.num_micro_batches = None

    def get(self):
        return self.num_micro_batches

    @abstractmethod
    def update(self, consumed_samples):
        pass


class ConstantNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, global_batch_size, micro_batch_size,
                 data_parallel_size):
        micro_batch_times_data_parallel = micro_batch_size * \
            data_parallel_size
        assert global_batch_size % micro_batch_times_data_parallel == 0, \
            'global batch size ({}) is not divisible by micro batch size ({})' \
            ' times data parallel size ({})'.format(global_batch_size,
                                                    micro_batch_size,
                                                    data_parallel_size)
        self.num_micro_batches = global_batch_size // \
            micro_batch_times_data_parallel
        assert self.num_micro_batches >= 1

    def update(self, consumed_samples):
        pass


class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
                 global_batch_size, micro_batch_size, data_parallel_size):
        """Batch size ramp up.

        Over
            steps = (global-batch-size - start-batch-size) / batch_size_increment
        increments, go from start-batch-size to global-batch-size, using
            rampup-samples / steps
        samples for each increment.

        Arguments:
            start_batch_size: global batch size to start with.
            batch_size_increment: global batch size increment.
            ramup_samples: number of samples to use to ramp up the global
                batch size from `start_batch_size` to `global_batch_size`.
            global_batch_size: global batch size post ramp-up.
            micro_batch_size: micro batch size.
            data_parallel_size: data parallel size.
        """

        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
            self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        assert start_batch_size > 0
        self.start_batch_size = start_batch_size

        assert global_batch_size > 0
        self.global_batch_size = global_batch_size
        diff_batch_size = self.global_batch_size - self.start_batch_size
        assert diff_batch_size >= 0
        assert batch_size_increment > 0
        self.batch_size_increment = batch_size_increment
        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
            'global batch size interval ({}) to be divisible by global batch ' \
            'size increment ({})'.format(diff_batch_size, batch_size_increment)
        num_increments = diff_batch_size // self.batch_size_increment

        self.ramup_samples = ramup_samples
        assert self.ramup_samples >= 0
        self.rampup_samples_per_increment = self.ramup_samples / num_increments

        # Initialize number of micro-batches.
        self.update(0)

    def update(self, consumed_samples):

        if consumed_samples > self.ramup_samples:
            current_global_batch_size = self.global_batch_size
        else:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            current_global_batch_size = self.start_batch_size + \
                steps * self.batch_size_increment
            assert current_global_batch_size <= self.global_batch_size

        assert current_global_batch_size % \
            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
            'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
            'data parallel size ({})'.format(current_global_batch_size,
                                             self.micro_batch_size,
                                             self.data_parallel_size)
        self.num_micro_batches = current_global_batch_size // \
            self.micro_batch_times_data_parallel_size
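As a quick sanity check of the ramp-up arithmetic above, here is a small driver; the concrete numbers are made up for illustration and are not part of the commit:

    from megatron.microbatches import RampupBatchsizeNumMicroBatches

    # Ramp the global batch size from 32 to 64 in increments of 8,
    # spread over 1000 consumed samples (4 increments, 250 samples each).
    calc = RampupBatchsizeNumMicroBatches(
        start_batch_size=32, batch_size_increment=8, ramup_samples=1000,
        global_batch_size=64, micro_batch_size=4, data_parallel_size=2)

    for consumed in (0, 249, 250, 999, 1000, 5000):
        calc.update(consumed)
        # num_micro_batches = current_global_batch_size // (4 * 2)
        print(consumed, calc.get())
    # Expected progression: 4, 4, 5, 7, 8, 8 micro-batches respectively.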