OpenDAS / Megatron-LM / Commits

Commit 22ab91bb
Sample based learning rate computation

Authored Dec 08, 2020 by mohammad; committed by Deepak Narayanan, Dec 19, 2020.
Parent: 6a68502d

Changes: 4 changed files with 142 additions and 62 deletions.
megatron/arguments.py        +39   -4
megatron/global_vars.py      +26  -31
megatron/learning_rates.py    +8  -11
megatron/training.py         +69  -16
megatron/arguments.py  (+39 -4)

@@ -125,6 +125,30 @@ def parse_args(extra_args_provider=None, defaults={},
         else:
             setattr(args, key, defaults[key])

+    # Iteration-based training.
+    if args.train_iters:
+        # If we use iteration-based training, make sure the
+        # sample-based options are off.
+        assert args.train_samples is None, \
+            'expected iteration-based training'
+        assert args.lr_decay_samples is None, \
+            'expected iteration-based learning rate decay'
+        assert args.lr_warmup_samples == 0, \
+            'expected iteration-based learnig rate warmup'
+        assert args.rampup_batch_size is None, \
+            'expected no batch-size rampup for iteration-based training'
+
+    # Sample-based training.
+    if args.train_samples:
+        # If we use sample-based training, make sure the
+        # iteration-based options are off.
+        assert args.train_iters is None, \
+            'expected sample-based training'
+        assert args.lr_decay_iters is None, \
+            'expected sample-based learning rate decay'
+        assert args.lr_warmup_iters == 0, \
+            'expected sample-based learnig rate warmup'
+
     # Check required arguments.
     required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                      'max_position_embeddings']

@@ -269,7 +293,12 @@ def _add_training_args(parser):
                        help='chunk size (number of layers) for checkpointing.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
-                       'training runs.')
+                       'training runs. Note that either train-iters or '
+                       'train-samples should be provided.')
+    group.add_argument('--train-samples', type=int, default=None,
+                       help='Total number of samples to train over all '
+                       'training runs. Note that either train-iters or '
+                       'train-samples should be provided.')
     group.add_argument('--log-interval', type=int, default=100,
                        help='Report loss and timing interval.')
     group.add_argument('--exit-interval', type=int, default=None,

@@ -319,12 +348,18 @@ def _add_learning_rate_args(parser):
     group.add_argument('--lr-decay-iters', type=int, default=None,
                        help='number of iterations to decay learning rate over,'
                        ' If None defaults to `--train-iters`')
+    group.add_argument('--lr-decay-samples', type=int, default=None,
+                       help='number of samples to decay learning rate over,'
+                       ' If None defaults to `--train-samples`')
+    group.add_argument('--lr-warmup-iters', type=int, default=0,
+                       help='number of iterations to linearly warmup '
+                       'learning rate over.')
+    group.add_argument('--lr-warmup-samples', type=int, default=0,
+                       help='number of samples to linearly warmup '
+                       'learning rate over.')
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minumum value for learning rate. The scheduler'
                        'clip values below this threshold.')
-    group.add_argument('--warmup', type=float, default=0.01,
-                       help='Percentage of total iterations to warmup on '
-                       '(.01 = 1 percent of all training iters).')
     group.add_argument('--override-lr-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
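The parse-time checks above make iteration-based and sample-based configuration mutually exclusive, so a run must be specified entirely in one unit. For intuition, the two views are related through the global batch size; a minimal sketch (not code from this commit, assuming a fixed global batch size with no batch-size ramp-up):

# Minimal sketch (not from this commit): converting between an
# iteration-based and a sample-based configuration, assuming a
# fixed global batch size and no batch-size ramp-up.

def iters_to_samples(train_iters, global_batch_size):
    # Every iteration consumes exactly one global batch.
    return train_iters * global_batch_size

def samples_to_iters(train_samples, global_batch_size):
    # Partial final batches are dropped, matching the integer
    # division used by update_train_iters() in this commit.
    return train_samples // global_batch_size

# Example: 500k iterations at a global batch size of 1024
# correspond to 512M training samples, and vice versa.
assert iters_to_samples(500000, 1024) == 512000000
assert samples_to_iters(512000000, 1024) == 500000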
megatron/global_vars.py  (+26 -31)

@@ -106,20 +106,12 @@ def _build_num_microbatches_calculator(args):

     # Constant num micro-batches.
     if args.rampup_batch_size is None:
-        micro_batch_times_data_parallel = args.micro_batch_size * \
-            args.data_parallel_size
-        assert args.global_batch_size % micro_batch_times_data_parallel == 0, \
-            'global batch size ({}) is not divisible by micro batch size ({})' \
-            ' times data parallel size ({})'.format(args.global_batch_size,
-                                                    args.micro_batch_size,
-                                                    args.data_parallel_size)
-        num_micro_batches = args.global_batch_size // \
-            micro_batch_times_data_parallel
+        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
+            args.global_batch_size, args.micro_batch_size,
+            args.data_parallel_size)
         if args.rank == 0:
             print('setting number of micro-batches to constant {}'.format(
-                num_micro_batches), flush=True)
-        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
-            num_micro_batches)
+                _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()), flush=True)

     else:
         assert len(args.rampup_batch_size) == 3, 'expected the following ' \

@@ -143,10 +135,8 @@ def _build_num_microbatches_calculator(args):

 class NumMicroBatchesCalculator(ABC):

-    def __init__(self, name):
-        self.name = name
+    def __init__(self):
         self.num_micro_batches = None
-        super(NumMicroBatchesCalculator, self).__init__()

     def get(self):
         return self.num_micro_batches

@@ -158,11 +148,17 @@ class NumMicroBatchesCalculator(ABC):

 class ConstantNumMicroBatches(NumMicroBatchesCalculator):

-    def __init__(self, num_micro_batches=1):
-        super(ConstantNumMicroBatches, self).__init__(
-            'constant: {}'.format(num_micro_batches))
-        assert num_micro_batches >= 1
-        self.num_micro_batches = num_micro_batches
+    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
+        micro_batch_times_data_parallel = micro_batch_size * \
+            data_parallel_size
+        assert global_batch_size % micro_batch_times_data_parallel == 0, \
+            'global batch size ({}) is not divisible by micro batch size ({})' \
+            ' times data parallel size ({})'.format(global_batch_size,
+                                                    micro_batch_size,
+                                                    data_parallel_size)
+        self.num_micro_batches = global_batch_size // \
+            micro_batch_times_data_parallel
+        assert self.num_micro_batches >= 1

     def update(self, consumed_samples):
         pass

@@ -188,10 +184,6 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
         data_parallel_size: data parallel size.
         """

-        super(RampupBatchsizeNumMicroBatches, self).__init__(
-            'batch size ramup: {}, {}, {}'.format(start_batch_size,
-                                                  batch_size_increment,
-                                                  ramup_samples))
         self.micro_batch_size = micro_batch_size
         self.data_parallel_size = data_parallel_size
         self.micro_batch_times_data_parallel_size = self.micro_batch_size * \

@@ -212,8 +204,9 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
             'size increment ({})'.format(diff_batch_size, batch_size_increment)
         num_increments = diff_batch_size // self.batch_size_increment
-        assert ramup_samples >= 0
-        self.rampup_samples_per_increment = ramup_samples / num_increments
+        self.ramup_samples = ramup_samples
+        assert self.ramup_samples >= 0
+        self.rampup_samples_per_increment = self.ramup_samples / num_increments

         # Initialize number of microbatches.
         self.update(0)

@@ -221,11 +214,13 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):

     def update(self, consumed_samples):

-        steps = int(consumed_samples / self.rampup_samples_per_increment)
-        current_global_batch_size = self.start_batch_size + \
-            steps * self.batch_size_increment
-        current_global_batch_size = min(current_global_batch_size,
-                                        self.global_batch_size)
+        if consumed_samples > self.ramup_samples:
+            current_global_batch_size = self.global_batch_size
+        else:
+            steps = int(consumed_samples / self.rampup_samples_per_increment)
+            current_global_batch_size = self.start_batch_size + \
+                steps * self.batch_size_increment
+            assert current_global_batch_size <= self.global_batch_size

         assert current_global_batch_size % \
             self.micro_batch_times_data_parallel_size == 0, 'current global ' \
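The rewritten ConstantNumMicroBatches now derives the number of micro-batches from the global batch size directly, while RampupBatchsizeNumMicroBatches grows the effective global batch size as samples are consumed. The ramp-up schedule can be summarized as a pure function of the consumed-sample count; the helper below is an illustrative sketch of the same arithmetic (not part of the commit):

# Standalone sketch of the ramp-up schedule implemented by
# RampupBatchsizeNumMicroBatches.update() (illustrative helper only).

def rampup_global_batch_size(consumed_samples, start_batch_size,
                             batch_size_increment, rampup_samples,
                             final_global_batch_size):
    if consumed_samples > rampup_samples:
        return final_global_batch_size
    # Spread the batch-size increments evenly over the ramp-up samples.
    num_increments = (final_global_batch_size
                      - start_batch_size) // batch_size_increment
    samples_per_increment = rampup_samples / num_increments
    steps = int(consumed_samples / samples_per_increment)
    current = start_batch_size + steps * batch_size_increment
    assert current <= final_global_batch_size
    return current

# Example: ramp from 32 to 1024 in steps of 32 over 1,000,000 samples.
print(rampup_global_batch_size(0, 32, 32, 1000000, 1024))        # 32
print(rampup_global_batch_size(500000, 32, 32, 1000000, 1024))   # 512
print(rampup_global_batch_size(2000000, 32, 32, 1000000, 1024))  # 1024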
megatron/learning_rates.py  (+8 -11)

@@ -23,8 +23,7 @@ class AnnealingLR(object):
     """Anneals the learning rate."""

     def __init__(self, optimizer, max_lr, min_lr,
                  warmup_steps, decay_steps,
                  decay_style,
-                 num_steps,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):

@@ -37,7 +36,7 @@ class AnnealingLR(object):
         assert self.max_lr >= self.min_lr
         self.warmup_steps = warmup_steps
-        self.num_steps = num_steps
+        self.num_steps = 0
         self.decay_steps = decay_steps
         assert self.decay_steps > 0
         assert self.warmup_steps < self.decay_steps

@@ -51,7 +50,7 @@ class AnnealingLR(object):
             'use-checkpoint are set.'

         # Set the learning rate
-        self.step(step_num=self.num_steps)
+        self.step(0)

         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))

@@ -92,11 +91,9 @@ class AnnealingLR(object):
         return self.min_lr + coeff * delta_lr

-    def step(self, increment=1, step_num=None):
+    def step(self, increment):
         """Set lr for all parameters groups."""
-        if step_num is None:
-            step_num = self.num_steps + increment
-        self.num_steps = step_num
+        self.num_steps += increment
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr

@@ -160,7 +157,7 @@ class AnnealingLR(object):
                 'decay style')

         if 'num_iters' in sd:
-            self.num_steps = sd['num_iters']
+            num_steps = sd['num_iters']
         else:
-            self.num_steps = sd['num_steps']
-        self.step(step_num=self.num_steps)
+            num_steps = sd['num_steps']
+        self.step(increment=num_steps)
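After this change AnnealingLR no longer takes an externally managed step number: num_steps starts at zero, step(increment) accumulates whatever unit the caller supplies (consumed samples for sample-based training), and warmup_steps/decay_steps are expressed in that same unit, including when a checkpoint is restored via step(increment=num_steps). A toy scheduler in the same spirit, as an illustrative sketch only (linear warmup followed by linear decay, which is just one possible decay style):

# Toy scheduler in the spirit of AnnealingLR after this commit:
# progress is measured in whatever unit step() accumulates
# (samples for sample-based training). Illustrative only.

class ToyAnnealingLR:
    def __init__(self, max_lr, min_lr, warmup_steps, decay_steps):
        assert 0 <= warmup_steps < decay_steps
        self.max_lr, self.min_lr = max_lr, min_lr
        self.warmup_steps, self.decay_steps = warmup_steps, decay_steps
        self.num_steps = 0          # starts at zero, as in the commit

    def step(self, increment):
        # increment = samples consumed by this training iteration
        self.num_steps += increment
        return self.get_lr()

    def get_lr(self):
        if self.num_steps < self.warmup_steps:
            # Linear warmup.
            return self.max_lr * self.num_steps / self.warmup_steps
        if self.num_steps >= self.decay_steps:
            return self.min_lr
        # Linear decay from max_lr down to min_lr.
        frac = (self.num_steps - self.warmup_steps) / \
               (self.decay_steps - self.warmup_steps)
        return self.max_lr - frac * (self.max_lr - self.min_lr)

# One training iteration with a global batch of 1024 samples:
sched = ToyAnnealingLR(max_lr=1e-4, min_lr=1e-5,
                       warmup_steps=1_000_000, decay_steps=100_000_000)
lr = sched.step(increment=1024)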
megatron/training.py  (+69 -16)

@@ -116,6 +116,37 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    test_data_iterator, model,
                                    0, True)


+def update_train_iters(args):
+
+    # For iteration-based training, we don't need to do anything
+    if args.train_iters:
+        return
+
+    # Constant batch size with sample-based training.
+    if args.rampup_batch_size is None:
+        args.train_iters = args.train_samples // args.global_batch_size
+
+    else:
+        # Sample based training with rampup batch size.
+        iterations = 0
+        consumed_samples = 0
+        # Rampup phase.
+        while consumed_samples <= int(args.rampup_batch_size[2]):
+            update_num_microbatches(consumed_samples)
+            consumed_samples += get_num_microbatches() * \
+                                args.micro_batch_size * \
+                                args.data_parallel_size
+            iterations += 1
+        # Reset
+        update_num_microbatches(0)
+        # Constant phase
+        # Note that we throw away any partial last batch.
+        iterations += (args.train_samples - consumed_samples) // \
+                      args.global_batch_size
+        args.train_iters = iterations
+
+    print_rank_0('setting training iterations to {}'.format(args.train_iters))
+
+
 def get_model(model_provider_func):
     """Build the model."""

@@ -188,22 +219,33 @@ def get_learning_rate_scheduler(optimizer):
     """Build the learning rate scheduler."""
     args = get_args()

-    # Add linear learning rate scheduler.
-    if args.lr_decay_iters is not None:
-        num_iters = args.lr_decay_iters
+    # Iteration-based training.
+    if args.train_iters:
+        if args.lr_decay_iters is None:
+            args.lr_decay_iters = args.train_iters
+        warmup_steps = args.lr_warmup_iters * args.global_batch_size
+        decay_steps = args.lr_decay_iters * args.global_batch_size
+    # Sample-based training.
+    elif args.train_samples:
+        # We need to set training iters for later use. Technically
+        # we need to adjust the training samples too (due to last
+        # batch being incomplete) but we leave it as is for now.
+        update_train_iters(args)
+        if args.lr_decay_samples is None:
+            args.lr_decay_samples = args.train_samples
+        warmup_steps = args.lr_warmup_samples
+        decay_steps = args.lr_decay_samples
     else:
-        num_iters = args.train_iters
-    num_iters = max(1, num_iters)
-    init_step = 0
-    warmup_iter = args.warmup * num_iters
+        raise Exception(
+            'either train-iters or train-samples should be provided.')

     lr_scheduler = AnnealingLR(
         optimizer,
         max_lr=args.lr,
         min_lr=args.min_lr,
-        warmup_steps=warmup_iter,
-        decay_steps=num_iters,
+        warmup_steps=warmup_steps,
+        decay_steps=decay_steps,
         decay_style=args.lr_decay_style,
-        num_steps=init_step,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)
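get_learning_rate_scheduler() now branches on the training mode and always hands the scheduler sample counts: for iteration-based runs the iteration counts are multiplied by the global batch size, while for sample-based runs the sample counts are used directly after update_train_iters() has derived args.train_iters. The ramp-up branch of update_train_iters() has to walk the schedule iteration by iteration, because each iteration during the ramp consumes a different (growing) number of samples; only the constant phase reduces to a single division. The function below is a self-contained, illustrative sketch of that count (it mimics the microbatch calculator rather than calling it):

# Illustrative sketch of the iteration count derived by
# update_train_iters() for sample-based training with a batch-size
# ramp-up (self-contained; the real code drives the global
# microbatch calculator instead).

def estimate_train_iters(train_samples, start_batch_size,
                         batch_size_increment, rampup_samples,
                         global_batch_size):
    num_increments = (global_batch_size
                      - start_batch_size) // batch_size_increment
    samples_per_increment = rampup_samples / num_increments

    iterations = 0
    consumed = 0
    # Ramp-up phase: each iteration consumes the current (growing)
    # global batch.
    while consumed <= rampup_samples:
        steps = int(consumed / samples_per_increment)
        batch = min(start_batch_size + steps * batch_size_increment,
                    global_batch_size)
        consumed += batch
        iterations += 1
    # Constant phase: remaining samples at the full global batch size;
    # any partial final batch is dropped.
    iterations += (train_samples - consumed) // global_batch_size
    return iterations

# Example: ramp 32 -> 1024 in steps of 32 over 1,000,000 samples,
# then continue to 10,000,000 samples total.
print(estimate_train_iters(10_000_000, 32, 32, 1_000_000, 1024))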
@@ -568,7 +610,10 @@ def train_step(forward_step_func, data_iterator,
     # Update learning rate.
     skipped_iter = 0
     if not (args.fp16 and optimizer.overflow):
-        lr_scheduler.step()
+        increment = get_num_microbatches() * \
+                    args.micro_batch_size * \
+                    args.data_parallel_size
+        lr_scheduler.step(increment=increment)
     else:
         skipped_iter = 1

@@ -649,8 +694,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     if writer and torch.distributed.get_rank() == 0:
         writer.add_scalar('iteration_time',
                           elapsed_time / args.log_interval, iteration)
-    log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
-                                                   args.train_iters)
+    log_string = ' iteration {:8d}/{:8d} |'.format(
+        iteration, args.train_iters)
     log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
         elapsed_time * 1000.0 / args.log_interval)
     log_string += ' learning rate: {:.3E} |'.format(learning_rate)

@@ -837,8 +882,12 @@ def build_train_valid_test_data_iterators(
     # Backward compatibility, assume fixed batch size.
     if args.iteration > 0 and args.consumed_train_samples == 0:
+        assert args.train_samples is None, \
+            'only backward compatiblity support for iteration-based training'
         args.consumed_train_samples = args.iteration * args.global_batch_size
     if args.iteration > 0 and args.consumed_valid_samples == 0:
+        assert args.train_samples is None, \
+            'only backward compatiblity support for iteration-based training'
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
             args.eval_iters * args.global_batch_size

@@ -846,10 +895,14 @@ def build_train_valid_test_data_iterators(
     if mpu.get_tensor_model_parallel_rank() == 0:

         # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        if args.train_samples:
+            train_samples = args.train_samples
+        else:
+            train_samples = args.train_iters * args.global_batch_size
+        eval_iters = (args.train_iters // args.eval_interval + 1) * \
+                     args.eval_iters
         test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * args.global_batch_size,
+        train_val_test_num_samples = [train_samples,
                                       eval_iters * args.global_batch_size,
                                       test_iters * args.global_batch_size]
         print_rank_0(' > datasets target sizes (minimum size):')
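Finally, train_step() advances the scheduler by the number of samples the step actually consumed, get_num_microbatches() * micro_batch_size * data_parallel_size, which equals the current global batch size and therefore grows during a batch-size ramp-up; with a constant batch size the accumulated count is simply iterations times the global batch size. A tiny sketch with made-up values:

# Sketch of the per-iteration scheduler increment introduced in
# train_step(). All values below are made up for illustration.

num_microbatches = 8
micro_batch_size = 4
data_parallel_size = 16
global_batch_size = num_microbatches * micro_batch_size * data_parallel_size

consumed_samples = 0
for iteration in range(10):
    # What train_step() now passes to lr_scheduler.step(increment=...).
    increment = num_microbatches * micro_batch_size * data_parallel_size
    consumed_samples += increment

# With a constant batch size, the sample count after N iterations is
# simply N * global_batch_size.
assert consumed_samples == 10 * global_batch_size  # 5120 samples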