OpenDAS / Megatron-LM / Commits / cdde4338

Commit cdde4338, authored Apr 09, 2021 by Mostofa Patwary

    Merge branch 'main' into main_retriver_merge_dpr

Parents: 10ff0607, 8cfef1bf

Showing 8 changed files with 88 additions and 112 deletions (+88 -112):

    tasks/finetune_utils.py         +34 -26
    tasks/glue/finetune.py           +4 -15
    tasks/main.py                    +5  -0
    tasks/race/data.py               +2  -0
    tasks/race/finetune.py           +5 -12
    tasks/vision/finetune_utils.py   +1  -1
    tasks/zeroshot_gpt/evaluate.py  +24 -39
    tools/generate_samples_gpt.py   +13 -19
tasks/finetune_utils.py

@@ -15,6 +15,8 @@
 """Finetune utilities."""

+from functools import partial
+
 import torch

 from megatron import get_args
@@ -46,7 +48,20 @@ def process_batch(batch):
     return tokens, types, labels, attention_mask


-def _cross_entropy_forward_step(batch, model, input_tensor):
+def cross_entropy_loss_func(labels, output_tensor):
+    logits = output_tensor
+
+    # Cross-entropy loss.
+    loss_func = torch.nn.CrossEntropyLoss()
+    loss = loss_func(logits.contiguous().float(), labels)
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss, {'lm loss': averaged_loss[0]}
+
+
+def _cross_entropy_forward_step(batch, model):
     """Simple forward step with cross-entropy loss."""
     timers = get_timers()
@@ -60,25 +75,9 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
     timers('batch-generator').stop()

     # Forward model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask)
-
-    if mpu.is_pipeline_last_stage():
-        logits = output_tensor
-
-        # Cross-entropy loss.
-        loss_func = torch.nn.CrossEntropyLoss()
-        loss = loss_func(logits.contiguous().float(), labels)
-
-        # Reduce loss for logging.
-        averaged_loss = average_losses_across_data_parallel_group([loss])
-
-        return loss, {'lm loss': averaged_loss[0]}
-    return output_tensor
+    output_tensor = model(tokens, attention_mask, tokentype_ids=types)
+
+    return output_tensor, partial(cross_entropy_loss_func, labels)


 def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
@@ -135,7 +134,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     # This is necessary so pipeline transfers know what size they are
     # and the LR schedule, which is based on samples seen, gets set
     # correctly.
+    args.orig_micro_batch_size = args.micro_batch_size
+    args.orig_global_batch_size = args.global_batch_size
     if hasattr(train_dataset, 'sample_multiplier'):
+        # If our dataset has a sample_multiplier attribute, that means
+        # each "sample" from the dataset actually has multiple samples
+        # that will collapse into the batch dimension (for example in
+        # the RACE dataset that has several options), and we need to
+        # account for that when setting the micro batch size.
         args.micro_batch_size *= train_dataset.sample_multiplier
         args.global_batch_size *= train_dataset.sample_multiplier
@@ -149,7 +155,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     timers = get_timers()

     # Turn on training mode which enables dropout.
-    model.train()
+    for m in model:
+        m.train()

     # Tracking loss.
     losses_dict_sum = {}
@@ -163,7 +170,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     report_memory_flag = True

     # For each remaining epoch
-    timers('interval time').start()
+    timers('interval-time').start()
     for epoch in range(start_epoch, args.epochs):
         print_rank_0('working on epoch {} ...'.format(epoch + 1))
@@ -180,10 +187,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0

             # Train for one step.
-            losses_dict, skipped_iter, grad_norm = train_step(forward_step,
-                batch, model, optimizer, lr_scheduler)
+            out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
+            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out

             iteration += 1

             # Logging.
@@ -195,7 +200,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                               iteration,
                                               optimizer.get_loss_scale().item(),
                                               report_memory_flag, skipped_iter,
-                                              grad_norm, params_norm)
+                                              grad_norm, params_norm, num_zeros_in_grad)

             # Autoresume
             if args.adlr_autoresume and \
@@ -231,6 +236,9 @@ def finetune(train_valid_datasets_provider, model_provider,
     args = get_args()
     timers = get_timers()

+    assert args.rampup_batch_size is None, \
+        'batch size scaling is not supported for finetuning'
+
     # Train and validation data loaders.
     timers('train/valid/test dataset/dataloder').start()
     if args.epochs > 0:
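Note on the tasks/finetune_utils.py change: the forward step no longer branches on pipeline stage or computes the loss inline; it returns the raw output tensor plus a loss closure bound to the labels with functools.partial, and the schedule applies that closure only where final logits exist. A minimal, self-contained sketch of that contract follows; the single-stage driver at the bottom is a stand-in for Megatron's schedule, and the logging dict uses a plain detached loss where the real code averages across the data-parallel group.

    from functools import partial

    import torch


    def cross_entropy_loss_func(labels, output_tensor):
        # Loss is computed only where the final logits are available.
        loss = torch.nn.CrossEntropyLoss()(output_tensor.contiguous().float(),
                                           labels)
        # Megatron reduces this across the data-parallel group for logging;
        # here we just detach it.
        return loss, {'lm loss': loss.detach()}


    def forward_step(batch, model):
        tokens, labels = batch
        output_tensor = model(tokens)
        # Bind labels now; the schedule supplies the logits later.
        return output_tensor, partial(cross_entropy_loss_func, labels)


    # Stand-in driver: on the last pipeline stage the returned closure is
    # applied to the output tensor; earlier stages just pass activations on.
    model = torch.nn.Linear(8, 4)
    batch = (torch.randn(2, 8), torch.tensor([1, 3]))
    output_tensor, loss_fn = forward_step(batch, model)
    loss, stats = loss_fn(output_tensor)
    loss.backward()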
tasks/glue/finetune.py

@@ -19,7 +19,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.model.classification import Classification, ClassificationFirstStage, ClassificationIntermediateStage, ClassificationLastStage
+from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
@@ -39,25 +39,14 @@ def glue_classification(num_classes, Dataset,
         return train_dataset, valid_dataset

-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""
         args = get_args()

         print_rank_0('building classification model for {} ...'.format(
             args.task))
-        if mpu.get_pipeline_model_parallel_world_size() > 1:
-            # Determine model based on position of stage in pipeline.
-            if mpu.is_pipeline_first_stage():
-                model = ClassificationFirstStage(
-                    num_classes=num_classes, num_tokentypes=2)
-            elif mpu.is_pipeline_last_stage():
-                model = ClassificationLastStage(
-                    num_classes=num_classes, num_tokentypes=2)
-            else:
-                model = ClassificationIntermediateStage(
-                    num_classes=num_classes, num_tokentypes=2)
-        else:
-            model = Classification(num_classes=num_classes, num_tokentypes=2)
+        model = Classification(num_classes=num_classes, num_tokentypes=2,
+                               pre_process=pre_process,
+                               post_process=post_process)

         return model
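Note on the model_provider change (the same refactor appears in tasks/race/finetune.py and the GPT scripts below): the per-stage classes (ClassificationFirstStage and friends) give way to one class that takes pre_process/post_process flags describing the stage's position in the pipeline. A hedged sketch of the idea using a hypothetical stand-in model, not Megatron's actual Classification class:

    import torch


    class StageAwareClassifier(torch.nn.Module):
        """Hypothetical stand-in: one class covers every pipeline position."""

        def __init__(self, num_classes, pre_process=True, post_process=True):
            super().__init__()
            self.pre_process = pre_process    # first stage: embeds raw tokens
            self.post_process = post_process  # last stage: emits class logits
            self.embed = torch.nn.Embedding(100, 16) if pre_process else None
            self.body = torch.nn.Linear(16, 16)
            self.head = torch.nn.Linear(16, num_classes) if post_process else None

        def forward(self, x):
            if self.pre_process:
                x = self.embed(x)
            x = self.body(x)
            if self.post_process:
                x = self.head(x)
            return x


    def model_provider(pre_process=True, post_process=True):
        # Mirrors the new signature: the trainer tells the provider which
        # ends of the pipeline this stage owns.
        return StageAwareClassifier(num_classes=3, pre_process=pre_process,
                                    post_process=post_process)


    first = model_provider(pre_process=True, post_process=False)
    hidden = first(torch.randint(0, 100, (2, 5)))  # tokens -> hidden states
    last = model_provider(pre_process=False, post_process=True)
    logits = last(hidden)                          # hidden states -> logits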
tasks/main.py

@@ -99,6 +99,11 @@ if __name__ == '__main__':
     initialize_megatron(extra_args_provider=get_tasks_args)

     args = get_args()
+
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
+        exit()
+
     if args.task == 'RACE':
         from race.finetune import main
     elif args.task in ['MNLI', 'QQP']:
tasks/race/data.py

@@ -39,6 +39,8 @@ class RaceDataset(Dataset):
         print_rank_0('  >> total number of samples: {}'.format(
             len(self.samples)))
+        # This indicates that each "sample" has multiple samples that
+        # will collapse into batch dimension
         self.sample_multiplier = NUM_CHOICES

     def __len__(self):
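Note on sample_multiplier: this attribute is what the new _build_train_valid_dataloaders logic in tasks/finetune_utils.py reads. Each RACE item expands into NUM_CHOICES sequences along the batch dimension, so the effective batch sizes get scaled up while the originals are preserved in orig_micro_batch_size/orig_global_batch_size. A small arithmetic sketch with stand-in names, not Megatron's args object:

    NUM_CHOICES = 4

    class FakeRaceDataset:
        sample_multiplier = NUM_CHOICES  # each "sample" yields 4 sequences

    micro_batch_size, global_batch_size = 8, 64
    orig_micro_batch_size, orig_global_batch_size = micro_batch_size, global_batch_size

    train_dataset = FakeRaceDataset()
    if hasattr(train_dataset, 'sample_multiplier'):
        micro_batch_size *= train_dataset.sample_multiplier   # 32 sequences/step
        global_batch_size *= train_dataset.sample_multiplier  # 256 sequences/batch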
tasks/race/finetune.py

@@ -19,7 +19,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.model.multiple_choice import MultipleChoice, MultipleChoiceFirstStage, MultipleChoiceIntermediateStage, MultipleChoiceLastStage
+from megatron.model.multiple_choice import MultipleChoice
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
 from tasks.race.data import RaceDataset
@@ -38,20 +38,13 @@ def train_valid_datasets_provider():
     return train_dataset, valid_dataset

-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building multichoice model for RACE ...')
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = MultipleChoiceFirstStage(num_tokentypes=2)
-        elif mpu.is_pipeline_last_stage():
-            model = MultipleChoiceLastStage(num_tokentypes=2)
-        else:
-            model = MultipleChoiceIntermediateStage(num_tokentypes=2)
-    else:
-        model = MultipleChoice(num_tokentypes=2)
+    model = MultipleChoice(num_tokentypes=2,
+                           pre_process=pre_process,
+                           post_process=post_process)

     return model
tasks/vision/finetune_utils.py

@@ -149,7 +149,7 @@ def _train(
     report_memory_flag = True

     # For each remaining epoch
-    timers("interval time").start()
+    timers("interval-time").start()
     for epoch in range(start_epoch, args.epochs):
         print_rank_0("working on epoch {} ...".format(epoch + 1))
tasks/zeroshot_gpt/evaluate.py

@@ -24,19 +24,24 @@ from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
-from megatron.model import GPTModel, GPTModelFirstStage, GPTModelLastStage, GPTModelIntermediateStage
-from megatron.training import get_model, communicate
-from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.model import GPTModel
+from megatron.training import get_model
+from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
+from megatron.p2p_communication import recv_forward, send_forward
 from tasks.finetune_utils import build_data_loader
 from .datasets import build_dataset

+# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module
+

 def get_model_provider(eval_metric):
     """Based on evaluation metric set the parallel-output flag and
     return the model provider."""

-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""

         if eval_metric == 'loss':
@@ -48,17 +53,8 @@ def get_model_provider(eval_metric):
                             'is not supported.'.format(eval_metric))

         print_rank_0('building GPT model ...')
-        if mpu.get_pipeline_model_parallel_world_size() > 1:
-            # Determine model based on position of stage in pipeline.
-            if mpu.is_pipeline_first_stage():
-                model = GPTModelFirstStage(num_tokentypes=0)
-            elif mpu.is_pipeline_last_stage():
-                model = GPTModelLastStage(parallel_output=parallel_output,
-                                          num_tokentypes=0)
-            else:
-                model = GPTModelIntermediateStage(num_tokentypes=0)
-        else:
-            model = GPTModel(num_tokentypes=0, parallel_output=parallel_output)
+        model = GPTModel(num_tokentypes=0, parallel_output=parallel_output,
+                         pre_process=pre_process, post_process=post_process)

         return model
@@ -97,33 +93,15 @@ def forward_step(batch, model, eval_metric):
     args = get_args()
     args.micro_batch_size = len(labels)

-    # Forward model.
-    if not mpu.is_pipeline_first_stage():
-        input_tensor, _ = communicate(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_forward=True,
-            recv_backward=False)
-    else:
-        input_tensor = None
-
-    # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output = model(tokens, position_ids, attention_mask)
-        else:
-            output = model(tokens, position_ids, attention_mask)
-    else:
-        assert input_tensor is not None
-        output = model(input_tensor, attention_mask)
-
-    if not mpu.is_pipeline_last_stage():
-        communicate(tensor_send_next=output,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
-        return None
+    input_tensor = recv_forward()
+
+    # Forward pass through the model.
+    unwrapped_model = unwrap_model(
+        model, (torchDDP, LocalDDP, Float16Module))
+    unwrapped_model.set_input_tensor(input_tensor)
+    output = model(tokens, position_ids, attention_mask)
+
+    send_forward(output)

     if mpu.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
@@ -214,6 +192,10 @@ def main():
     """Main program."""
     args = get_args()

+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+
     if args.task == 'LAMBADA':
         eval_metric = 'accuracy'
     elif args.task == 'WIKITEXT103':
@@ -227,6 +209,9 @@ def main():
     if args.load is not None:
         _ = load_checkpoint(model, None, None)

+    assert len(model) == 1, "Above condition should have caught this"
+    model = model[0]
+
     # Data stuff.
     dataset = build_dataset(args.task)
     dataloader = build_data_loader(dataset, args.micro_batch_size,
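Note on the forward_step rewrite above: instead of branching on the pipeline stage, every rank now receives activations (None on the first stage), injects them with set_input_tensor on the unwrapped module, runs the same model call, and forwards the output. A self-contained sketch with the p2p calls stubbed out; in Megatron they come from megatron.p2p_communication and actually move tensors between ranks:

    import torch


    def recv_forward():
        return None  # first stage; later stages would receive a tensor here


    def send_forward(output_tensor):
        pass  # last stage keeps the output; earlier stages would transmit it


    class TinyStage(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)
            self.input_tensor = None

        def set_input_tensor(self, tensor):
            # None means "consume the real batch"; a tensor means "resume
            # from the previous stage's activations".
            self.input_tensor = tensor

        def forward(self, tokens):
            x = tokens if self.input_tensor is None else self.input_tensor
            return self.linear(x)


    model = TinyStage()
    model.set_input_tensor(recv_forward())
    output = model(torch.randn(2, 8))
    send_forward(output)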
tools/generate_samples_gpt.py

@@ -26,33 +26,19 @@ from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import (GPTModel,
-                            GPTModelFirstStage,
-                            GPTModelLastStage,
-                            GPTModelIntermediateStage)
+from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
 from megatron.text_generation_utils import generate_samples_interactive


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building GPT model ...')
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = GPTModelFirstStage(num_tokentypes=0)
-        elif mpu.is_pipeline_last_stage():
-            model = GPTModelLastStage(num_tokentypes=0, parallel_output=False)
-        else:
-            model = GPTModelIntermediateStage(num_tokentypes=0)
-    else:
-        model = GPTModel(num_tokentypes=0, parallel_output=False)
+    model = GPTModel(num_tokentypes=0, parallel_output=False,
+                     pre_process=pre_process, post_process=post_process)

     return model
@@ -96,12 +82,20 @@ def main():
                                       'no_load_rng': True,
                                       'no_load_optim': True})

+    args = get_args()
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+
     # Set up model and load checkpoint.
     model = get_model(model_provider)

-    args = get_args()
     if args.load is not None:
         _ = load_checkpoint(model, None, None)

+    assert len(model) == 1, "Above condition should have caught this"
+    model = model[0]
+
     # Generate samples.
     if args.num_samples == 0:
         args.micro_batch_size = 1
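Note on `model = model[0]` here and in tasks/zeroshot_gpt/evaluate.py: after this merge, get_model returns a list of model chunks (more than one only under the interleaved schedule, which the guard above rejects), so single-chunk callers assert on the length and unwrap. The same convention explains `for m in model: m.train()` in tasks/finetune_utils.py. A stand-in illustration:

    import torch

    model = [torch.nn.Linear(4, 4)]  # what single-chunk get_model returns

    assert len(model) == 1, "interleaved schedule would give multiple chunks"
    model = model[0]
    out = model(torch.randn(1, 4))  # callers now use the module directly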