Commit b938ec51 authored by Jared Casper

Tasks seem to be working.

parent 3b91262e
@@ -58,7 +58,7 @@ class MultipleChoice(MegatronModule):
                                      init_method)
         self._multichoice_head_key = 'multichoice_head'

-    def set_input_tensor(self, input_tensor)
+    def set_input_tensor(self, input_tensor):
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, model_input, attention_mask, tokentype_ids=None):
@@ -127,4 +127,3 @@ class MultipleChoice(MegatronModule):
             print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                             'initializing to random'.format(
                                 self._multichoice_head_key))
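The `set_input_tensor` hook whose missing colon this hunk fixes is the seam between the pipeline schedule and the model: before calling `forward`, the schedule hands every non-first stage the activation tensor received from the previous stage. A minimal, self-contained sketch of the convention, using a toy module rather than Megatron's classes:

import torch


class ToyStage(torch.nn.Module):
    """Illustrative stand-in for a pipeline-stage module."""

    def __init__(self, hidden, first_stage):
        super().__init__()
        self.linear = torch.nn.Linear(hidden, hidden)
        self.first_stage = first_stage
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Called by the schedule before forward(); non-first stages compute
        # on the activation received from the previous stage.
        self.input_tensor = input_tensor

    def forward(self, batch_input):
        hidden = batch_input if self.first_stage else self.input_tensor
        return self.linear(hidden)


stage = ToyStage(hidden=8, first_stage=False)
stage.set_input_tensor(torch.randn(2, 8))  # as if received via recv_forward()
out = stage(None)  # the batch input is ignored off the first stage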
@@ -24,6 +24,18 @@ from megatron import mpu
 from megatron import p2p_communication


+def get_forward_backward_func():
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            forward_backward_func = forward_backward_pipelining_with_interleaving
+        else:
+            forward_backward_func = forward_backward_pipelining_without_interleaving
+    else:
+        forward_backward_func = forward_backward_no_pipelining
+    return forward_backward_func
+
+
 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
     """Forward step for passed-in model.
@@ -34,6 +46,7 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
     timers = get_timers()
     timers('forward-compute').start()

+    # TODO
     model.module.module.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
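`get_forward_backward_func` centralizes schedule selection: interleaved pipelining when a virtual pipeline size is configured, non-interleaved pipelining otherwise, and a no-pipelining path for a single stage. A hedged usage sketch; the argument order follows the call sites later in this commit, and `my_forward_step` is a hypothetical stand-in:

from megatron.schedules import get_forward_backward_func


def my_forward_step(data_iterator, model):
    # Hypothetical forward step obeying the (output_tensor, loss_func) contract.
    batch = next(data_iterator)
    output_tensor = model(batch)
    return output_tensor, lambda out: (out.sum(), {'loss': out.sum().item()})


def run_forward_only(data_iterator, model):
    # model is a list of chunks; forward_only=True skips the backward pass.
    forward_backward_func = get_forward_backward_func()
    return forward_backward_func(my_forward_step, data_iterator, model,
                                 optimizer=None, timers=None, forward_only=True)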
...
@@ -26,9 +26,8 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.training import communicate
 from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.p2p_communication import recv_forward, send_forward


 def get_batch(context_tokens):
     """Generate batch from context tokens."""
@@ -395,55 +394,26 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                  layer_past=None, get_key_value=None,
                  forward_method_parallel_output=None):

-    # Hidden size changes when not using recompute, need to tell communicate()
-    # the correct size
+    # Hidden size changes when not using recompute, need to tell p2p_communicate
+    # functions the correct size
     args = get_args()
     orig_seq_length = args.seq_length
     args.seq_length = tokens.shape[1]

-    if not mpu.is_pipeline_first_stage():
-        input_tensor, _ = communicate(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_forward=True,
-            recv_backward=False)
-    else:
-        input_tensor = None
+    input_tensor = recv_forward()

     # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  tokentype_ids=tokentype_ids,
-                                  layer_past=layer_past,
-                                  get_key_value=get_key_value,
-                                  forward_method_parallel_output=forward_method_parallel_output)
-        else:
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  tokentype_ids=tokentype_ids,
-                                  layer_past=layer_past,
-                                  get_key_value=get_key_value)
-    elif mpu.is_pipeline_last_stage():
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask,
-                              layer_past=layer_past,
-                              get_key_value=get_key_value,
-                              forward_method_parallel_output=forward_method_parallel_output)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask,
-                              layer_past=layer_past,
-                              get_key_value=get_key_value)
+    model.set_input_tensor(input_tensor)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          tokentype_ids=tokentype_ids,
+                          layer_past=layer_past,
+                          get_key_value=get_key_value,
+                          forward_method_parallel_output=forward_method_parallel_output)

     if get_key_value:
         output_tensor, layer_past = output_tensor

-    if not mpu.is_pipeline_last_stage():
-        communicate(tensor_send_next=output_tensor,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
+    send_forward(output_tensor)

     args.seq_length = orig_seq_length
     if get_key_value:
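The directional helpers replace `communicate()`'s four boolean flags with intent-revealing calls. Very roughly, and only as an illustration (the real `megatron.p2p_communication` functions derive buffer shape and dtype from `args` and no-op on the first/last stage), they behave like:

import torch
import torch.distributed as dist


def recv_forward_sketch(shape, dtype, prev_rank):
    # Illustrative only: blocking receive of the previous stage's activations;
    # a first stage would return None instead.
    buf = torch.empty(shape, dtype=dtype)
    dist.recv(buf, src=prev_rank)
    return buf


def send_forward_sketch(tensor, next_rank):
    # Illustrative only: blocking send of this stage's output downstream;
    # a last stage would skip the send.
    dist.send(tensor, dst=next_rank)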
...
@@ -17,13 +17,14 @@

 import os
 import time
+from functools import partial

 import torch

 from megatron import get_args
 from megatron import print_rank_last, is_last_rank
 from megatron import mpu
-from megatron.training import communicate
+from megatron.schedules import get_forward_backward_func
 from tasks.finetune_utils import build_data_loader
 from tasks.finetune_utils import process_batch
@@ -38,7 +39,7 @@ def accuracy_func_provider(single_dataset_provider):
     for datapath in datapaths:
         dataset = single_dataset_provider(datapath)
         dataloader = build_data_loader(
-            dataset, args.micro_batch_size, num_workers=args.num_workers,
+            dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
             drop_last=(mpu.get_data_parallel_world_size() > 1))
         dataloaders.append((dataset.dataset_name, dataloader))
@@ -73,14 +74,61 @@ def accuracy_func_provider(single_dataset_provider):
     return metrics_func


 def calculate_correct_answers(name, model, dataloader,
                               epoch, output_predictions):
     """Calculate correct over total answers and return prediction if the
     `output_predictions` is true."""

     args = get_args()
+    forward_backward_func = get_forward_backward_func()
     start_time = time.time()
-    model.eval()
-    saved_batch_size = args.micro_batch_size
+    for m in model:
+        m.eval()
+    saved_micro_batch_size = args.micro_batch_size
+    saved_global_batch_size = args.global_batch_size
+
+    ds = dataloader.dataset
+    if hasattr(ds, 'sample_multiplier'):
+        sample_multiplier = ds.sample_multiplier
+    else:
+        sample_multiplier = 1
+    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
+    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
+
+    def loss_func(output_predictions, labels, output_tensor):
+        logits = output_tensor
+
+        loss_dict = {}
+        # Add output predictions.
+        if output_predictions:
+            assert False
+            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
+                logits.float()).data.cpu().numpy().tolist()
+            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
+            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
+        # Compute the correct answers.
+        predicted = torch.argmax(logits, dim=-1)
+        corrects = (predicted == labels)
+        # Add to the counters.
+        loss_dict['total'] = labels.size(0)
+        loss_dict['correct'] = corrects.sum().item()
+
+        return 0, loss_dict
+
+    # defined inside to capture output_predictions
+    def correct_answers_forward_step(batch, model):
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        tokens, types, labels, attention_mask = process_batch(batch_)
+
+        # Forward model.
+        args = get_args()
+        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
+
+        return output_tensor, partial(loss_func, output_predictions, labels)
+
     with torch.no_grad():
         # For all the batches in the dataset.
         total = 0
@@ -92,60 +140,30 @@ def calculate_correct_answers(name, model, dataloader,
         labels = []
         ids = []
         for _, batch in enumerate(dataloader):
-            # Run the model forward.
-            tokens, types, labels_, attention_mask = process_batch(batch)
-
             # For evaluation only mode we use drop_last = False to get all the
             # samples, which means we might not have a full batch, so we
             # adjust batch_size here to actual batch size of data
-            actual_batch_size = len(labels_)
+            actual_batch_size = len(batch['label'])
             # ... applying sample_multiplier if necessary
-            ds = dataloader.dataset
-            if hasattr(ds, 'sample_multiplier'):
-                actual_batch_size *= ds.sample_multiplier
-            args.micro_batch_size = actual_batch_size
-
-            if not mpu.is_pipeline_first_stage():
-                input_tensor, _ = communicate(
-                    tensor_send_next=None,
-                    tensor_send_prev=None,
-                    recv_forward=True,
-                    recv_backward=False)
-            else:
-                input_tensor = None
-
-            # Forward model.
-            if mpu.is_pipeline_first_stage():
-                assert input_tensor is None
-                output_tensor = model(tokens, attention_mask, tokentype_ids=types)
-            else:
-                assert input_tensor is not None
-                output_tensor = model(input_tensor, attention_mask)
-
-            if mpu.is_pipeline_last_stage():
-                logits = output_tensor
-
-                # Add output predictions.
-                if output_predictions:
-                    softmaxes.extend(torch.nn.Softmax(dim=-1)(
-                        logits.float()).data.cpu().numpy().tolist())
-                    labels.extend(labels_.data.cpu().numpy().tolist())
-                    ids.extend(batch['uid'].cpu().numpy().tolist())
-                # Compute the correct answers.
-                predicted = torch.argmax(logits, dim=-1)
-                corrects = (predicted == labels_)
-                # Add to the counters.
-                total += labels_.size(0)
-                correct += corrects.sum().item()
-            else:
-                communicate(
-                    tensor_send_next=output_tensor,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
-
-    model.train()
-    args.micro_batch_size = saved_batch_size
+            args.micro_batch_size = actual_batch_size * sample_multiplier
+            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches
+
+            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
+                                               optimizer=None, timers=None, forward_only=True)
+
+            for loss_dict in loss_dicts:
+                if output_predictions:
+                    softmaxes.extend(loss_dict['softmaxes'])
+                    labels.extend(loss_dict['labels'])
+                    ids.extend(loss_dict['ids'])
+                total += loss_dict['total']
+                correct += loss_dict['correct']
+
+    for m in model:
+        m.train()
+    args.micro_batch_size = saved_micro_batch_size
+    args.global_batch_size = saved_global_batch_size

     # Reduce.
     if mpu.is_pipeline_last_stage():
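With the schedule owning all pipeline communication, the loop above no longer branches on stage; it just reduces the per-micro-batch dicts the schedule hands back. The aggregation itself is plain Python (made-up numbers):

loss_dicts = [
    {'total': 8, 'correct': 6},
    {'total': 8, 'correct': 7},
    {'total': 5, 'correct': 4},  # final partial batch (drop_last=False)
]
total = sum(d['total'] for d in loss_dicts)
correct = sum(d['correct'] for d in loss_dicts)
print('accuracy: {:.3f}'.format(correct / total))  # accuracy: 0.810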
...
@@ -15,6 +15,8 @@

 """Finetune utilities."""

+from functools import partial
+
 import torch

 from megatron import get_args
@@ -46,7 +48,20 @@ def process_batch(batch):
     return tokens, types, labels, attention_mask


-def _cross_entropy_forward_step(batch, model, input_tensor):
+def cross_entropy_loss_func(labels, output_tensor):
+    logits = output_tensor
+
+    # Cross-entropy loss.
+    loss_func = torch.nn.CrossEntropyLoss()
+    loss = loss_func(logits.contiguous().float(), labels)
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss, {'lm loss': averaged_loss[0]}
+
+
+def _cross_entropy_forward_step(batch, model):
     """Simple forward step with cross-entropy loss."""
     timers = get_timers()
@@ -60,25 +75,9 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
     timers('batch-generator').stop()

     # Forward model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask)
-
-    if mpu.is_pipeline_last_stage():
-        logits = output_tensor
-
-        # Cross-entropy loss.
-        loss_func = torch.nn.CrossEntropyLoss()
-        loss = loss_func(logits.contiguous().float(), labels)
-
-        # Reduce loss for logging.
-        averaged_loss = average_losses_across_data_parallel_group([loss])
-
-        return loss, {'lm loss': averaged_loss[0]}
-
-    return output_tensor
+    output_tensor = model(tokens, attention_mask, tokentype_ids=types)
+
+    return output_tensor, partial(cross_entropy_loss_func, labels)
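This hunk establishes the forward-step contract used throughout the commit: a forward step returns `(output_tensor, loss_func)`, with the labels pre-bound into `loss_func` via `functools.partial` so the schedule can invoke it on the last pipeline stage with nothing but the output tensor. A self-contained toy version of the pattern (omitting Megatron's data-parallel loss averaging):

from functools import partial

import torch
import torch.nn.functional as F


def toy_loss_func(labels, output_tensor):
    # The schedule calls this with only output_tensor; labels arrive pre-bound.
    loss = F.cross_entropy(output_tensor.float(), labels)
    return loss, {'lm loss': loss.detach()}


def toy_forward_step(batch, model):
    tokens, labels = batch
    output_tensor = model(tokens)
    return output_tensor, partial(toy_loss_func, labels)


model = torch.nn.Linear(4, 3)
output, loss_func = toy_forward_step((torch.randn(2, 4), torch.tensor([0, 2])), model)
loss, stats = loss_func(output)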
 def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
@@ -135,6 +134,8 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     # This is necessary so pipeline transfers know what size they are
     # and the LR schedule, which is based on samples seen, gets set
     # correctly.
+    args.orig_micro_batch_size = args.micro_batch_size
+    args.orig_global_batch_size = args.global_batch_size
     if hasattr(train_dataset, 'sample_multiplier'):
         args.micro_batch_size *= train_dataset.sample_multiplier
         args.global_batch_size *= train_dataset.sample_multiplier
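Saving the original batch sizes before `sample_multiplier` is applied lets the evaluation code (see `eval_utils` above) rebuild dataloaders at the dataloader-level batch size. A worked example with hypothetical numbers:

micro_batch_size = 8      # args.micro_batch_size as configured
global_batch_size = 64    # args.global_batch_size as configured
sample_multiplier = 4     # hypothetical, e.g. one row per multiple-choice option

orig_micro_batch_size = micro_batch_size    # kept for eval dataloaders: 8
orig_global_batch_size = global_batch_size  # kept for micro-batch count: 64
micro_batch_size *= sample_multiplier       # rows per forward pass: 32
global_batch_size *= sample_multiplier      # rows per optimizer step: 256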
@@ -149,7 +150,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     timers = get_timers()

     # Turn on training mode which enables dropout.
-    model.train()
+    for m in model:
+        m.train()

     # Tracking loss.
     losses_dict_sum = {}
@@ -180,10 +182,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0

             # Train for one step.
-            losses_dict, skipped_iter, grad_norm = train_step(forward_step,
-                                                              batch, model,
-                                                              optimizer,
-                                                              lr_scheduler)
+            out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
+            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
             iteration += 1

             # Logging.
@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                     iteration,
                     optimizer.get_loss_scale().item(),
                     report_memory_flag, skipped_iter,
-                    grad_norm, params_norm)
+                    grad_norm, params_norm, num_zeros_in_grad)

             # Autoresume
             if args.adlr_autoresume and \
...
@@ -19,7 +19,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.model.classification import Classification, ClassificationFirstStage, ClassificationIntermediateStage, ClassificationLastStage
+from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
@@ -39,25 +39,14 @@ def glue_classification(num_classes, Dataset,
         return train_dataset, valid_dataset

-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""
         args = get_args()

         print_rank_0('building classification model for {} ...'.format(
             args.task))
-
-        if mpu.get_pipeline_model_parallel_world_size() > 1:
-            # Determine model based on position of stage in pipeline.
-            if mpu.is_pipeline_first_stage():
-                model = ClassificationFirstStage(
-                    num_classes=num_classes, num_tokentypes=2)
-            elif mpu.is_pipeline_last_stage():
-                model = ClassificationLastStage(
-                    num_classes=num_classes, num_tokentypes=2)
-            else:
-                model = ClassificationIntermediateStage(
-                    num_classes=num_classes, num_tokentypes=2)
-        else:
-            model = Classification(num_classes=num_classes, num_tokentypes=2)
+        model = Classification(num_classes=num_classes, num_tokentypes=2,
+                               pre_process=pre_process, post_process=post_process)

         return model
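A single `Classification` class parameterized by `pre_process`/`post_process` replaces the First/Intermediate/Last stage classes. The shape of the idea, sketched with a toy module rather than Megatron's implementation:

import torch


class ToyStagedClassifier(torch.nn.Module):
    """Illustrative only: one class covers every pipeline stage."""

    def __init__(self, hidden=8, num_classes=2, pre_process=True, post_process=True):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        # Only the first stage owns the embedding; only the last owns the head.
        self.embed = torch.nn.Embedding(100, hidden) if pre_process else None
        self.body = torch.nn.Linear(hidden, hidden)
        self.head = torch.nn.Linear(hidden, num_classes) if post_process else None
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        self.input_tensor = input_tensor

    def forward(self, tokens):
        hidden = self.embed(tokens) if self.pre_process else self.input_tensor
        hidden = self.body(hidden)
        return self.head(hidden) if self.post_process else hidden


first = ToyStagedClassifier(pre_process=True, post_process=False)
last = ToyStagedClassifier(pre_process=False, post_process=True)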
...
@@ -70,6 +70,11 @@ if __name__ == '__main__':
     initialize_megatron(extra_args_provider=get_tasks_args)

     args = get_args()
+
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
+        exit()
+
     if args.task == 'RACE':
         from race.finetune import main
     elif args.task in ['MNLI', 'QQP']:
...
@@ -19,7 +19,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.model.multiple_choice import MultipleChoice, MultipleChoiceFirstStage, MultipleChoiceIntermediateStage, MultipleChoiceLastStage
+from megatron.model.multiple_choice import MultipleChoice
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
 from tasks.race.data import RaceDataset
@@ -38,20 +38,13 @@ def train_valid_datasets_provider():
     return train_dataset, valid_dataset


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building multichoice model for RACE ...')
-
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = MultipleChoiceFirstStage(num_tokentypes=2)
-        elif mpu.is_pipeline_last_stage():
-            model = MultipleChoiceLastStage(num_tokentypes=2)
-        else:
-            model = MultipleChoiceIntermediateStage(num_tokentypes=2)
-    else:
-        model = MultipleChoice(num_tokentypes=2)
+    model = MultipleChoice(num_tokentypes=2,
+                           pre_process=pre_process,
+                           post_process=post_process)

     return model
...
@@ -25,8 +25,9 @@ from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.model import GPTModel, GPTModelFirstStage, GPTModelLastStage, GPTModelIntermediateStage
-from megatron.training import get_model, communicate
+from megatron.training import get_model
 from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.p2p_communication import recv_forward, send_forward
 from tasks.finetune_utils import build_data_loader
 from .datasets import build_dataset
@@ -98,14 +99,7 @@ def forward_step(batch, model, eval_metric):
     args.micro_batch_size = len(labels)

     # Forward model.
-    if not mpu.is_pipeline_first_stage():
-        input_tensor, _ = communicate(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_forward=True,
-            recv_backward=False)
-    else:
-        input_tensor = None
+    input_tensor = recv_forward()

     # Forward pass through the model.
     if mpu.is_pipeline_first_stage():
@@ -118,12 +112,7 @@ def forward_step(batch, model, eval_metric):
         assert input_tensor is not None
         output = model(input_tensor, attention_mask)

-    if not mpu.is_pipeline_last_stage():
-        communicate(tensor_send_next=output,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
-        return None
+    send_forward(output)

     if mpu.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
@@ -214,6 +203,10 @@ def main():
     """Main program."""
     args = get_args()

+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+
     if args.task == 'LAMBADA':
         eval_metric = 'accuracy'
     elif args.task == 'WIKITEXT103':
@@ -227,6 +220,9 @@ def main():
     if args.load is not None:
         _ = load_checkpoint(model, None, None)

+    assert len(model) == 1, "Above condition should have caught this"
+    model = model[0]
+
     # Data stuff.
     dataset = build_dataset(args.task)
     dataloader = build_data_loader(dataset, args.micro_batch_size,
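Since `get_model` now returns a list of model chunks (more than one only under the interleaved schedule, which this script rejects above), single-model code paths unwrap element zero. A tiny hypothetical helper capturing the pattern:

def unwrap_single_chunk(model_chunks):
    # Hypothetical helper: with interleaving ruled out, get_model() returns
    # exactly one chunk, so take it directly.
    assert len(model_chunks) == 1, 'interleaved schedule not supported here'
    return model_chunks[0]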
...
@@ -96,12 +96,20 @@ def main():
                                        'no_load_rng': True,
                                        'no_load_optim': True})

+    args = get_args()
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+
     # Set up model and load checkpoint.
     model = get_model(model_provider)
-    args = get_args()

     if args.load is not None:
         _ = load_checkpoint(model, None, None)

+    assert len(model) == 1, "Above condition should have caught this"
+    model = model[0]
+
     # Generate samples.
     if args.num_samples == 0:
         args.micro_batch_size = 1
...