OpenDAS / Megatron-LM

Commit de0b70a0: Support for ramping up the batch size
Authored Dec 08, 2020 by mohammad; committed by Deepak Narayanan on Dec 19, 2020.
Parent: c30ba0f7
Showing 1 changed file with 67 additions and 4 deletions.

megatron/global_vars.py (+67, -4)
@@ -17,10 +17,12 @@
 from abc import ABC
 from abc import abstractmethod
+import math
 import os
 import sys
 import time

+import numpy as np
 import torch

 from megatron.tokenizer import build_tokenizer
@@ -127,11 +129,11 @@ class NumMicroBatchesCalculator(ABC):

     def __init__(self, name):
         self.name = name
+        self.num_micro_batches = None
         super(NumMicroBatchesCalculator, self).__init__()

-    @abstractmethod
     def get(self):
-        pass
+        return self.num_micro_batches

     @abstractmethod
     def update(self, consumed_samples):
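Note on this hunk: `get()` stops being abstract and simply returns `self.num_micro_batches`, which every subclass's `update()` is now responsible for setting. A minimal standalone sketch of that contract (the class names and constructor below are trimmed stand-ins for illustration, not the module's actual construction API):

from abc import ABC, abstractmethod

class Calculator(ABC):
    # Mirrors the NumMicroBatchesCalculator contract from this diff.
    def __init__(self, name):
        self.name = name
        self.num_micro_batches = None

    def get(self):
        # Concrete in the base class: return whatever update() last computed.
        return self.num_micro_batches

    @abstractmethod
    def update(self, consumed_samples):
        pass

class FixedCalculator(Calculator):
    # Hypothetical subclass with a constant schedule, akin to
    # ConstantNumMicroBatches in this file.
    def __init__(self, num_micro_batches):
        super().__init__('constant')
        self.num_micro_batches = num_micro_batches

    def update(self, consumed_samples):
        pass  # nothing to recompute for a constant schedule

calc = FixedCalculator(8)
calc.update(0)
print(calc.get())  # -> 8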
@@ -149,9 +151,70 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):

     def update(self, consumed_samples):
         pass

-    def get(self):
-        return self.num_micro_batches
+
+class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
+
+    def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
+                 global_batch_size, micro_batch_size, data_parallel_size):
+        """Batch size ramp-up.
+        Over
+            steps = (global-batch-size - start-batch-size) / batch-size-increment
+        increments, increase the global batch size from start-batch-size to
+        global-batch-size, consuming
+            rampup-samples / steps
+        samples per increment.
+
+        Arguments:
+            start_batch_size: global batch size to start with.
+            batch_size_increment: global batch size increment per step.
+            ramup_samples: number of samples over which to ramp up the global
+                batch size from `start_batch_size` to `global_batch_size`.
+            global_batch_size: global batch size after ramp-up.
+            micro_batch_size: micro batch size.
+            data_parallel_size: data parallel size.
+        """
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_size = data_parallel_size
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
+                                                    self.data_parallel_size
+        assert self.micro_batch_times_data_parallel_size > 0
+
+        assert start_batch_size > 0
+        self.start_batch_size = start_batch_size
+
+        assert global_batch_size > 0
+        self.global_batch_size = global_batch_size
+        diff_batch_size = self.global_batch_size - self.start_batch_size
+        assert diff_batch_size >= 0
+        assert batch_size_increment > 0
+        self.batch_size_increment = batch_size_increment
+        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
+            'global batch size interval ({}) to be divisible by global batch ' \
+            'size increment ({})'.format(diff_batch_size, batch_size_increment)
+
+        num_increments = diff_batch_size // self.batch_size_increment
+        assert ramup_samples >= 0
+        self.rampup_samples_per_increment = ramup_samples / num_increments
+
+        # Initialize number of microbatches.
+        self.update(0)
+
+    def update(self, consumed_samples):
+        # Number of completed ramp-up increments so far.
+        steps = int(consumed_samples / self.rampup_samples_per_increment)
+        current_global_batch_size = self.start_batch_size + \
+            steps * self.batch_size_increment
+        # Cap at the target global batch size once ramp-up is done.
+        current_global_batch_size = min(current_global_batch_size,
+                                        self.global_batch_size)
+
+        assert current_global_batch_size % \
+            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
+            'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
+            'data parallel size ({})'.format(current_global_batch_size,
+                                             self.micro_batch_size,
+                                             self.data_parallel_size)
+
+        self.num_micro_batches = current_global_batch_size // \
+            self.micro_batch_times_data_parallel_size


 def _build_tokenizer(args):
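For a concrete feel of the schedule this class produces, here is a hedged standalone sketch; the import path and every numeric value are illustrative assumptions, not part of the commit:

# Assumes the class added by this diff is importable from the touched module.
from megatron.global_vars import RampupBatchsizeNumMicroBatches

calc = RampupBatchsizeNumMicroBatches(
    start_batch_size=32,       # global batch size at the first sample
    batch_size_increment=32,   # grow the global batch size by 32 per step
    ramup_samples=3_000_000,   # ramp up over 3M consumed samples
    global_batch_size=1024,    # final global batch size
    micro_batch_size=4,
    data_parallel_size=8)

# steps = (1024 - 32) / 32 = 31 increments, so the global batch size grows
# by 32 roughly every 3_000_000 / 31 ~= 96,774 consumed samples.
for consumed in (0, 100_000, 1_500_000, 3_000_000):
    calc.update(consumed)
    global_batch = calc.get() * 4 * 8  # num_micro_batches * micro * dp
    print(consumed, calc.get(), global_batch)
# -> num_micro_batches 1, 2, 16, 32; global batch 32, 64, 512, 1024

Note the invariant the asserts enforce: the current global batch size is always num_micro_batches * micro_batch_size * data_parallel_size, and min() holds it at global_batch_size once ramp-up completes.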