OpenDAS / Megatron-LM

Commit 6a68502d
Authored Dec 08, 2020 by mohammad
Committed by Deepak Narayanan, Dec 19, 2020

Minor fixes for batch size rampup

Parent: de0b70a0
Showing 1 changed file with 24 additions and 4 deletions

megatron/global_vars.py  +24 -4
@@ -120,9 +120,25 @@ def _build_num_microbatches_calculator(args):
                 num_micro_batches), flush=True)
         _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
             num_micro_batches)
-        return
 
-    raise Exception('should not be here.')
+    else:
+        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
+            'format: --rampup-batch-size <start batch size> ' \
+            '<batch size incerement> <ramp-up samples>'
+        start_batch_size = int(args.rampup_batch_size[0])
+        batch_size_increment = int(args.rampup_batch_size[1])
+        ramup_samples = int(args.rampup_batch_size[2])
+        if args.rank == 0:
+            print('will use batch size rampup starting from global batch '
+                  'size {} to global batch size {} with batch size increments '
+                  '{} over {} samples.'.format(start_batch_size,
+                                               args.global_batch_size,
+                                               batch_size_increment,
+                                               ramup_samples), flush=True)
+        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = RampupBatchsizeNumMicroBatches(
+            start_batch_size, batch_size_increment, ramup_samples,
+            args.global_batch_size, args.micro_batch_size,
+            args.data_parallel_size)
 
 
 class NumMicroBatchesCalculator(ABC):
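For reference, a minimal sketch of the three-value format the new branch expects from --rampup-batch-size, parsed in isolation. The argparse setup and the numeric values (32 16 5000000) are illustrative assumptions, not part of this commit:

# Illustrative only: mimics the <start> <increment> <samples> format that the
# added assert enforces for --rampup-batch-size.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rampup-batch-size', nargs='*', default=None)
args = parser.parse_args(['--rampup-batch-size', '32', '16', '5000000'])

assert len(args.rampup_batch_size) == 3
start_batch_size = int(args.rampup_batch_size[0])      # 32
batch_size_increment = int(args.rampup_batch_size[1])  # 16
ramup_samples = int(args.rampup_batch_size[2])         # 5000000
print(start_batch_size, batch_size_increment, ramup_samples)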
@@ -143,10 +159,10 @@ class NumMicroBatchesCalculator(ABC):
 
 class ConstantNumMicroBatches(NumMicroBatchesCalculator):
 
     def __init__(self, num_micro_batches=1):
-        super(ConstantNumMicroBatches, self).__init__(
-            'constant: {}'.format(num_micro_batches))
         assert num_micro_batches >= 1
         self.num_micro_batches = num_micro_batches
+        super(ConstantNumMicroBatches, self).__init__(
+            'constant: {}'.format(self.num_micro_batches))
 
     def update(self, consumed_samples):
         pass
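A small self-contained sketch of the reordering above: with the attribute assigned before the super() call, the name string is built from the validated self.num_micro_batches rather than the raw argument. The _CalculatorBase stand-in is hypothetical; the real NumMicroBatchesCalculator base class lives elsewhere in megatron/global_vars.py and is not shown in this diff:

class _CalculatorBase:
    # Hypothetical stand-in for NumMicroBatchesCalculator, illustration only.
    def __init__(self, name):
        self.name = name

class ConstantNumMicroBatches(_CalculatorBase):
    def __init__(self, num_micro_batches=1):
        assert num_micro_batches >= 1
        self.num_micro_batches = num_micro_batches
        # After this commit, the name is formatted from the attribute set above.
        super(ConstantNumMicroBatches, self).__init__(
            'constant: {}'.format(self.num_micro_batches))

print(ConstantNumMicroBatches(8).name)  # prints: constant: 8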
@@ -172,6 +188,10 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
         data_parallel_size: data parallel size.
         """
 
+        super(RampupBatchsizeNumMicroBatches, self).__init__(
+            'batch size ramup: {}, {}, {}'.format(start_batch_size,
+                                                  batch_size_increment,
+                                                  ramup_samples))
         self.micro_batch_size = micro_batch_size
         self.data_parallel_size = data_parallel_size
         self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
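The constructor arguments and the log message in the first hunk suggest a schedule that grows the global batch size from start_batch_size toward args.global_batch_size in steps of batch_size_increment, spread over ramup_samples consumed samples. The update logic itself is not part of this diff, so the sketch below is only an assumed reading of that schedule, with illustrative numbers:

def rampup_num_micro_batches(consumed_samples, start_batch_size,
                             batch_size_increment, ramup_samples,
                             global_batch_size, micro_batch_size,
                             data_parallel_size):
    # Assumed schedule: increase the global batch size in fixed increments,
    # spread evenly over ramup_samples, then hold it at global_batch_size.
    if consumed_samples >= ramup_samples:
        current = global_batch_size
    else:
        steps_total = max(
            (global_batch_size - start_batch_size) // batch_size_increment, 1)
        samples_per_step = ramup_samples // steps_total
        steps_done = consumed_samples // samples_per_step
        current = min(start_batch_size + steps_done * batch_size_increment,
                      global_batch_size)
    # Micro-batches per data-parallel replica for the current global batch size.
    return current // (micro_batch_size * data_parallel_size)

# Illustrative values only (start 32, increment 16, ramp over 5M samples,
# target global batch 1024, micro batch 4, data-parallel size 8).
print(rampup_num_micro_batches(0, 32, 16, 5_000_000, 1024, 4, 8))          # 1
print(rampup_num_micro_batches(5_000_000, 32, 16, 5_000_000, 1024, 4, 8))  # 32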