Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
de0b70a0
Commit
de0b70a0
authored
Dec 08, 2020
by
mohammad
Committed by
Deepak Narayanan
Dec 19, 2020
Browse files
Support for ramping up the batch size
parent
c30ba0f7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
67 additions
and
4 deletions
+67
-4
megatron/global_vars.py
megatron/global_vars.py
+67
-4
No files found.
megatron/global_vars.py
View file @
de0b70a0
...
...
@@ -17,10 +17,12 @@
from
abc
import
ABC
from
abc
import
abstractmethod
import
math
import
os
import
sys
import
time
import
numpy
as
np
import
torch
from
megatron.tokenizer
import
build_tokenizer
...
...
@@ -127,11 +129,11 @@ class NumMicroBatchesCalculator(ABC):
def
__init__
(
self
,
name
):
self
.
name
=
name
self
.
num_micro_batches
=
None
super
(
NumMicroBatchesCalculator
,
self
).
__init__
()
@
abstractmethod
def
get
(
self
):
pas
s
return
self
.
num_micro_batche
s
@
abstractmethod
def
update
(
self
,
consumed_samples
):
...
...
@@ -149,9 +151,70 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
def
update
(
self
,
consumed_samples
):
pass
def
get
(
self
):
return
self
.
num_micro_batches
class
RampupBatchsizeNumMicroBatches
(
NumMicroBatchesCalculator
):
def
__init__
(
self
,
start_batch_size
,
batch_size_increment
,
ramup_samples
,
global_batch_size
,
micro_batch_size
,
data_parallel_size
):
"""Batch size ramp up.
Over
steps = (global-batch-size - start-batch-size) / batch_size_increment
increment batch size from start-batch-size to global-batch-size using
rampup-samples / steps
samples.
Arguments:
start_batch_size: global batch size to start with
batch_size_increment: global batch size increments
ramup_samples: number of samples to use ramp up global
batch size from `start_batch_size` to `global_batch_size`
global_batch_size: global batch size post rampup
micro_batch_size: micro batch size
data_parallel_size: data parallel size.
"""
self
.
micro_batch_size
=
micro_batch_size
self
.
data_parallel_size
=
data_parallel_size
self
.
micro_batch_times_data_parallel_size
=
self
.
micro_batch_size
*
\
self
.
data_parallel_size
assert
self
.
micro_batch_times_data_parallel_size
>
0
assert
start_batch_size
>
0
self
.
start_batch_size
=
start_batch_size
assert
global_batch_size
>
0
self
.
global_batch_size
=
global_batch_size
diff_batch_size
=
self
.
global_batch_size
-
self
.
start_batch_size
assert
diff_batch_size
>=
0
assert
batch_size_increment
>
0
self
.
batch_size_increment
=
batch_size_increment
assert
diff_batch_size
%
batch_size_increment
==
0
,
'expected '
\
'global batch size interval ({}) to be divisible by global batch '
\
'size increment ({})'
.
format
(
diff_batch_size
,
batch_size_increment
)
num_increments
=
diff_batch_size
//
self
.
batch_size_increment
assert
ramup_samples
>=
0
self
.
rampup_samples_per_increment
=
ramup_samples
/
num_increments
# Initialize number of microbatches.
self
.
update
(
0
)
def
update
(
self
,
consumed_samples
):
steps
=
int
(
consumed_samples
/
self
.
rampup_samples_per_increment
)
current_global_batch_size
=
self
.
start_batch_size
+
\
steps
*
self
.
batch_size_increment
current_global_batch_size
=
min
(
current_global_batch_size
,
self
.
global_batch_size
)
assert
current_global_batch_size
%
\
self
.
micro_batch_times_data_parallel_size
==
0
,
'current global '
\
'batch size ({}) is not divisible by micro-batch-size ({}) times'
\
'data parallel size ({})'
.
format
(
current_global_batch_size
,
self
.
micro_batch_size
,
self
.
data_parallel_size
)
self
.
num_micro_batches
=
current_global_batch_size
//
\
self
.
micro_batch_times_data_parallel_size
def
_build_tokenizer
(
args
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment