OpenDAS / Megatron-LM / Commits

Commit 27e14f82
Authored Mar 28, 2020 by Mohammad
Parent: 3f58649b

refactored training

Showing 4 changed files with 139 additions and 115 deletions
megatron/arguments.py   +2    -0
megatron/training.py    +101  -90
pretrain_bert.py        +18   -14
pretrain_gpt2.py        +18   -11
megatron/arguments.py

@@ -234,6 +234,8 @@ def _add_mixed_precision_args(parser):
                             'attention-softmax-in-fp32 to true')
     group.add_argument('--attention-softmax-in-fp32', action='store_true',
                        help='Run attention masking and softmax in fp32.')
     group.add_argument('--fp32-allreduce', action='store_true',
                        help='All-reduce in fp32')
     group.add_argument('--hysteresis', type=int, default=2,
                        help='hysteresis for dynamic loss scaling')
     group.add_argument('--loss-scale', type=float, default=None,
megatron/training.py

@@ -13,62 +13,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Pretrain utilities"""
+"""Pretrain utilities."""

 from datetime import datetime
 import math
 import sys

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam

-from megatron.global_vars import get_args
-from megatron.global_vars import get_timers
-from megatron.global_vars import get_tensorboard_writer
-from megatron.global_vars import get_adlr_autoresume
-from megatron.initialize import initialize_megatron
+from megatron import get_args
+from megatron import get_timers
+from megatron import get_tensorboard_writer
 from megatron import mpu
+from megatron import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
 from megatron.fp16 import FP16_Module
 from megatron.fp16 import FP16_Optimizer
+from megatron.initialize import initialize_megatron
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.checkpointing import load_checkpoint
-from megatron import print_rank_0
 from megatron.utils import report_memory
-from megatron.checkpointing import save_checkpoint
-def run(top_level_message, train_val_test_data_provider,
-        model_provider, forward_step_func, extra_args_provider=None,
-        args_defaults={}):
+def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
+             extra_args_provider=None, args_defaults={}):
     """Main training program.

     This function will run the following in the order provided:
-        1) get input arguments.
-        2) initialize distributed and seeds.
+        1) initialize Megatron.
+        2) setup model, optimizer and lr schedule using the model_provider.
         3) call train_val_test_data_provider to get train/val/test datasets.
-        4) setup model, optimizer and lr schedule using the model_provider.
-        5) train the model using the forward_step_func.
+        4) train the model using the forward_step_func.

     Arguments:
-        top_level_message: a message to print at the top of the run.
-        train_val_test_data_provider: a function that takes `args` as input
-            and returns `train, val, test` dataloaders. Note that args are
-            passed and can be modified in case we need to use some parameters
-            later. For example, we can set vocab size using
-                args.vocab_size = ...
-            and later use this value in `model_provider`.
-        model_provider: a function that takes `args` and returns a vanilla
-            version of the model. By vanilla we mean a simple model on cpu
-            with no fp16 or ddp.
-        forward_step_func: a function that takes a `data iterator`, `model`,
-            `args`, and `timers` and returns a `loss` scalar with a dictionary
-            with key:values being the info we would like to monitor during
-            training, for example `lm-loss: value`. We also require that this
-            function add `batch generator` to the timers class.
+        train_val_test_data_provider: a function that builds datasets and
+            returns `train, val, test` dataloaders.
+        model_provider: a function that returns a vanilla version of the
+            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
+        forward_step_func: a function that takes a `data iterator` and `model`,
+            and returns a `loss` scalar with a dictionary with key:values being
+            the info we would like to monitor during training, for example
+            `lm-loss: value`. We also require that this function add
+            `batch generator` to the timers class.
         extra_args_provider: a function that takes a parser and adds arguments
             to it. It is used for programs to add their own arguments.
         args_defaults: a dictionary from argument-name to argument-value. It
            is used to set arguments that are already parsed.
     """

     # Initialize and get arguments, timers, and Tensorboard writer.
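Taken together, the refactored entry point is driven as in the sketch below: both providers take no arguments, forward_step takes only (data_iterator, model), and everything else is pulled from get_args()/get_timers(). This is an illustrative sketch, not part of the commit; the torch.nn.Linear stand-in model, the placeholder dataloaders, and the toy loss are invented here, and the real call sites are the updated pretrain_bert.py and pretrain_gpt2.py further down.

    import torch

    from megatron import get_args
    from megatron import get_timers
    from megatron import print_rank_0
    from megatron.training import pretrain


    def model_provider():
        """Return a vanilla CPU model; pretrain() takes care of fp16/DDP wrapping."""
        args = get_args()
        print_rank_0('building toy model ...')
        # Stand-in for BertModel/GPT2Model: any module built from args.
        return torch.nn.Linear(args.hidden_size, args.hidden_size)


    def get_train_val_test_data():
        """Build and return train/val/test dataloaders; note: no `args` parameter."""
        # Placeholder: a real script would build dataloaders here using get_args().
        return None, None, None


    def forward_step(data_iterator, model):
        """Return (loss, stats dict) and time batch generation, as the docstring requires."""
        timers = get_timers()
        timers('batch generator').start()
        batch = next(data_iterator)  # toy: assumes the iterator yields input tensors
        timers('batch generator').stop()
        loss = model(batch).sum()
        return loss, {'lm loss': loss}


    if __name__ == '__main__':
        pretrain(get_train_val_test_data, model_provider, forward_step)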
@@ -76,36 +71,44 @@ def run(top_level_message, train_val_test_data_provider,
                         args_defaults=args_defaults)

     args = get_args()
     timers = get_timers()
-    writer = get_tensorboard_writer()

-    # Data stuff.
-    train_data, val_data, test_data = train_val_test_data_provider(args)

     # Model, optimizer, and learning rate.
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
-                                                               args)
+    timers('model and optimizer').start()
+    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    timers('model and optimizer').stop()

+    # Data stuff.
+    timers('train/valid/test dataset').start()
+    train_data, val_data, test_data = train_val_test_data_provider()
+    timers('train/valid/test dataset').stop()

     # Train, validation, and test data.
+    timers('train/valid/test dataloader').start()
     train_data_iterator, val_data_iterator, \
         test_data_iterator = get_train_val_test_data_iterators(
-            train_data, val_data, test_data, args)
+            train_data, val_data, test_data)
+    timers('train/valid/test dataloader').stop()

+    # Print setup timing.
+    print_rank_0('done with setups ...')
+    timers.log(['model and optimizer', 'train/valid/test dataset',
+                'train/valid/test dataloader'])
+    print_rank_0('training ...')

     iteration = 0
-    if args.train_iters > 0:
     if args.do_train:
-        iteration, _ = train(forward_step_func, model, optimizer, lr_scheduler,
-                             train_data_iterator, val_data_iterator,
-                             timers, args, writer)
+        iteration, _ = train(forward_step_func, model, optimizer,
+                             lr_scheduler, train_data_iterator,
+                             val_data_iterator)

     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
-                                   val_data_iterator, model, args, writer,
-                                   iteration, timers, False)
+                                   val_data_iterator, model,
+                                   iteration, False)

     if args.save and iteration != 0:
         save_checkpoint(iteration, model, optimizer, lr_scheduler)
@@ -115,14 +118,15 @@ def run(top_level_message, train_val_test_data_provider,
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, forward_step_func,
-                                   test_data_iterator, model, args, None,
-                                   0, timers, True)
+                                   test_data_iterator, model,
+                                   0, True)


-def get_model(model_provider_func, args):
+def get_model(model_provider_func):
     """Build the model."""
+    args = get_args()

     # Build model on cpu.
-    model = model_provider_func(args)
+    model = model_provider_func()

     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
@@ -140,26 +144,24 @@ def get_model(model_provider_func, args):
     # Wrap model for distributed training."""
     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
-        args.DDP_type = torchDDP
-        model = args.DDP_type(model, device_ids=[i], output_device=i,
-                              process_group=mpu.get_data_parallel_group())
+        model = torchDDP(model, device_ids=[i], output_device=i,
+                         process_group=mpu.get_data_parallel_group())
+        return model

     if args.DDP_impl == 'local':
-        args.DDP_type = LocalDDP
-        model = args.DDP_type(model)
+        model = LocalDDP(model)
+        return model

     print_rank_0('Unknown DDP implementation specified: {}. '
                  'Exiting.'.format(args.DDP_impl))
-    exit()
-
-    return model
+    sys.exit()


-def get_optimizer(model, args):
+def get_optimizer(model):
     """Set up the optimizer."""
+    args = get_args()

     # Build parameter groups (weight decay and non-decay).
-    while isinstance(model, (args.DDP_type, FP16_Module)):
+    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
         model = model.module
     param_groups = get_params_for_weight_decay_optimization(model)
@@ -170,8 +172,7 @@ def get_optimizer(model, args):
             param.model_parallel = False

     # Use Adam.
-    optimizer = Adam(param_groups,
-                     lr=args.lr, weight_decay=args.weight_decay)
+    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)

     # Wrap into fp16 optimizer.
     if args.fp16:
@@ -186,8 +187,9 @@ def get_optimizer(model, args):
     return optimizer


-def get_learning_rate_scheduler(optimizer, args):
+def get_learning_rate_scheduler(optimizer):
     """Build the learning rate scheduler."""
+    args = get_args()

     # Add linear learning rate scheduler.
     if args.lr_decay_iters is not None:
@@ -211,12 +213,13 @@ def get_learning_rate_scheduler(optimizer, args):
     return lr_scheduler


-def setup_model_and_optimizer(model_provider_func, args):
+def setup_model_and_optimizer(model_provider_func):
     """Setup model and optimizer."""
+    args = get_args()

-    model = get_model(model_provider_func, args)
-    optimizer = get_optimizer(model, args)
-    lr_scheduler = get_learning_rate_scheduler(optimizer, args)
+    model = get_model(model_provider_func)
+    optimizer = get_optimizer(model)
+    lr_scheduler = get_learning_rate_scheduler(optimizer)

     if args.load is not None:
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
@@ -226,8 +229,10 @@ def setup_model_and_optimizer(model_provider_func, args):
     return model, optimizer, lr_scheduler


-def backward_step(optimizer, model, loss, args, timers):
+def backward_step(optimizer, model, loss):
     """Backward step."""
+    args = get_args()
+    timers = get_timers()

     # Backward pass.
     optimizer.zero_grad()
@@ -255,18 +260,20 @@ def backward_step(optimizer, model, loss, args, timers):
         optimizer.clip_master_grads(args.clip_grad)


-def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
-               args, timers):
+def train_step(forward_step_func, data_iterator, model,
+               optimizer, lr_scheduler):
     """Single training step."""
+    args = get_args()
+    timers = get_timers()

     # Forward model for one step.
     timers('forward').start()
-    loss, loss_reduced = forward_step_func(data_iterator, model, args, timers)
+    loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()

     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
-    backward_step(optimizer, model, loss, args, timers)
+    backward_step(optimizer, model, loss)
     timers('backward').stop()

     # Update parameters.
@@ -285,7 +292,11 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,

 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
-                 loss_scale, report_memory_flag, writer, args, timers):
+                 loss_scale, report_memory_flag):
     """Log training information such as losses, timing, ...."""
+    args = get_args()
+    timers = get_timers()
+    writer = get_tensorboard_writer()

     # Update losses.
     for key in loss_dict:
@@ -341,8 +352,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,

 def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator, timers, args, writer):
+          train_data_iterator, val_data_iterator):
     """Train the model function."""
+    args = get_args()
+    timers = get_timers()

     # Turn on training mode which enables dropout.
     model.train()
@@ -361,8 +374,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                              train_data_iterator, model, optimizer,
-                                             lr_scheduler, args, timers)
+                                             lr_scheduler)
         skipped_iters += skipped_iter
         iteration += 1
@@ -370,8 +382,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             report_memory_flag = training_log(loss_dict, total_loss_dict,
                                               optimizer.param_groups[0]['lr'],
                                               iteration, optimizer.loss_scale,
-                                              report_memory_flag, writer,
-                                              args, timers)
+                                              report_memory_flag)

         # Autoresume
         if (iteration % args.adlr_autoresume_interval == 0) and \
@@ -389,23 +400,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
            args.do_valid:
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step_func,
-                                       val_data_iterator, model, args,
-                                       writer, iteration, timers, False)
+                                       val_data_iterator, model,
+                                       iteration, False)

         if args.exit_interval and iteration % args.exit_interval == 0:
             torch.distributed.barrier()
             time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
             rank = torch.distributed.get_rank()
-            print('rank: {} | time: {} | exiting the program at '
-                  'iteration {}'.format(rank, time_str, iteration), flush=True)
-            exit()
+            print_rank_0('rank: {} | time: {} | exiting the program at '
+                         'iteration {}'.format(rank, time_str, iteration))
+            sys.exit()

     return iteration, skipped_iters


-def evaluate(forward_step_func, data_iterator, model,
-             args, timers, verbose=False):
+def evaluate(forward_step_func, data_iterator, model, verbose=False):
     """Evaluation."""
+    args = get_args()

     # Turn on evaluation mode which disables dropout.
     model.eval()
@@ -420,8 +431,7 @@ def evaluate(forward_step_func, data_iterator, model,
                 print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                             args.eval_iters))
             # Forward evaluation.
-            _, loss_dict = forward_step_func(data_iterator, model,
-                                             args, timers)
+            _, loss_dict = forward_step_func(data_iterator, model)

             # Reduce across processes.
             for key in loss_dict:
                 total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
@@ -437,11 +447,11 @@ def evaluate(forward_step_func, data_iterator, model,

 def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
-                               args, writer, iteration,
-                               timers, verbose=False):
+                               iteration, verbose=False):
     """Helper function to evaluate and dump results on screen."""
-    total_loss_dict = evaluate(forward_step_func, data_iterator,
-                               model, args, timers, verbose)
+    writer = get_tensorboard_writer()
+    total_loss_dict = evaluate(forward_step_func, data_iterator,
+                               model, verbose)
     string = ' validation loss at {} | '.format(prefix)
     for key in total_loss_dict:
         string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())

@@ -459,8 +469,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
     print_rank_0('-' * length)


-def get_train_val_test_data_iterators(train_data, val_data, test_data, args):
+def get_train_val_test_data_iterators(train_data, val_data, test_data):
     """Build train/validation/test iterators"""
+    args = get_args()

     # Shift the start iterations.
     if train_data is not None:
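The dropped args/timers/writer parameters throughout training.py rely on module-level accessors (get_args, get_timers, get_tensorboard_writer) that are now imported straight from the megatron package. The self-contained sketch below illustrates that accessor pattern in miniature; it is not the actual megatron.global_vars implementation, only an assumed, minimal model of the design (the names _GLOBAL_ARGS, Timers, and initialize are invented for illustration).

    import argparse
    import time

    _GLOBAL_ARGS = None
    _GLOBAL_TIMERS = None


    class _Timer:
        """Tiny stand-in for the objects returned by timers('name')."""

        def __init__(self):
            self.elapsed = 0.0
            self._start = None

        def start(self):
            self._start = time.time()

        def stop(self):
            self.elapsed += time.time() - self._start


    class Timers:
        """Callable registry: timers('forward').start() / .stop()."""

        def __init__(self):
            self._timers = {}

        def __call__(self, name):
            return self._timers.setdefault(name, _Timer())


    def initialize(argv=None):
        """Parse arguments once and stash them, roughly what initialize_megatron does."""
        global _GLOBAL_ARGS, _GLOBAL_TIMERS
        parser = argparse.ArgumentParser()
        parser.add_argument('--lr', type=float, default=0.01)
        _GLOBAL_ARGS = parser.parse_args(argv)
        _GLOBAL_TIMERS = Timers()


    def get_args():
        assert _GLOBAL_ARGS is not None, 'call initialize() first'
        return _GLOBAL_ARGS


    def get_timers():
        assert _GLOBAL_TIMERS is not None, 'call initialize() first'
        return _GLOBAL_TIMERS


    def train_step(model, batch):
        # No args/timers parameters: fetch them from the globals instead,
        # which is the shape of the refactored train_step/backward_step above.
        args = get_args()
        timers = get_timers()
        timers('forward').start()
        loss = sum(batch) * args.lr  # placeholder "forward pass"
        timers('forward').stop()
        return loss


    if __name__ == '__main__':
        initialize([])
        print(train_step(model=None, batch=[1.0, 2.0, 3.0]))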
pretrain_bert.py

@@ -18,24 +18,28 @@
 import torch
 import torch.nn.functional as F

+from megatron import get_args
+from megatron import get_timers
 from megatron import mpu
 from megatron.model import BertModel
 from megatron import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
-from megatron.training import run
+from megatron.training import pretrain
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.data_utils.samplers import DistributedBatchSampler


-def model_provider(args):
+def model_provider():
     """Build the model."""
+    args = get_args()

     print_rank_0('building BERT model ...')

     model = BertModel(num_layers=args.num_layers,
-                      vocab_size=args.vocab_size,
+                      vocab_size=args.padded_vocab_size,
                       hidden_size=args.hidden_size,
                       num_attention_heads=args.num_attention_heads,
                       embedding_dropout_prob=args.hidden_dropout,
@@ -46,7 +50,7 @@ def model_provider(args):
                       checkpoint_num_layers=args.checkpoint_num_layers,
                       add_binary_head=True,
                       layernorm_epsilon=args.layernorm_epsilon,
-                      num_tokentypes=args.tokentype_size,
+                      num_tokentypes=2,
                       parallel_output=True,
                       apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
                       attention_softmax_in_fp32=args.attention_softmax_in_fp32)
@@ -54,19 +58,17 @@ def model_provider(args):
     return model


-def get_batch(data_iterator, timers):
+def get_batch(data_iterator):

     # Items and their type.
     keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
     datatype = torch.int64

     # Broadcast data.
-    timers('data loader').start()
     if data_iterator is not None:
         data = next(data_iterator)
     else:
         data = None
-    timers('data loader').stop()
     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
@@ -80,13 +82,14 @@ def get_batch(data_iterator, timers):
     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


-def forward_step(data_iterator, model, args, timers):
+def forward_step(data_iterator, model):
     """Forward step."""
+    timers = get_timers()

     # Get the batch.
     timers('batch generator').start()
     tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
-        = get_batch(data_iterator, timers)
+        = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
@@ -108,9 +111,10 @@ def forward_step(data_iterator, model, args, timers):
     return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}


-def get_train_val_test_data(args):
+def get_train_val_test_data():
     """Load the data on rank zero and broadcast number of tokens to all GPUs."""
+    args = get_args()

     (train_data, valid_data, test_data) = (None, None, None)

     # Data loader only on rank 0 of each model parallel group.
@@ -202,6 +206,6 @@ if __name__ == "__main__":
                  'tokenizer_type': 'BertWordPieceLowerCase'})
     exit()
     '''
-    run('Pretrain BERT model', get_train_val_test_data, model_provider,
-        forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+    pretrain(get_train_val_test_data, model_provider, forward_step,
+             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
pretrain_gpt2.py

@@ -17,6 +17,10 @@
 import torch

+from megatron import get_args
+from megatron import get_timers
 from configure_data import configure_data
 from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
@@ -25,15 +29,16 @@ from megatron.utils import get_ltor_masks_and_position_ids
 from megatron import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
-from megatron.training import run
+from megatron.training import pretrain


-def model_provider(args):
+def model_provider():
     """Build the model."""
+    args = get_args()

     print_rank_0('building GPT2 model ...')

     model = GPT2Model(num_layers=args.num_layers,
-                      vocab_size=args.vocab_size,
+                      vocab_size=args.padded_vocab_size,
                       hidden_size=args.hidden_size,
                       num_attention_heads=args.num_attention_heads,
                       embedding_dropout_prob=args.hidden_dropout,
@@ -50,20 +55,19 @@ def model_provider(args):
     return model


-def get_batch(data_iterator, args, timers):
+def get_batch(data_iterator):
     """Generate a batch"""
+    args = get_args()

     # Items and their type.
     keys = ['text']
     datatype = torch.int64

     # Broadcast data.
-    timers('data loader').start()
     if data_iterator is not None:
         data = next(data_iterator)
     else:
         data = None
-    timers('data loader').stop()
     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
@@ -85,8 +89,10 @@ def get_batch(data_iterator, args, timers):
     return tokens, labels, loss_mask, attention_mask, position_ids


-def forward_step(data_iterator, model, args, timers):
+def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
+    timers = get_timers()

     # Get the batch.
     timers('batch generator').start()
@@ -107,9 +113,10 @@ def forward_step(data_iterator, model, args, timers):
     return loss, {'lm loss': reduced_loss[0]}


-def get_train_val_test_data(args):
+def get_train_val_test_data():
     """Load the data on rank zero and broadcast number of tokens to all GPUs."""
+    args = get_args()

     (train_data, val_data, test_data) = (None, None, None)

     # Data loader only on rank 0 of each model parallel group.
@@ -162,5 +169,5 @@ def get_train_val_test_data(args):

 if __name__ == "__main__":

-    run('Pretrain GPT-2 model', get_train_val_test_data,
-        model_provider, forward_step)
+    pretrain(get_train_val_test_data, model_provider, forward_step)