"vscode:/vscode.git/clone" did not exist on "9ad98c87c314877541187724a620c81332339a87"
Commit 9019bbf4 authored by mohammad, committed by Deepak Narayanan

Rename --batch-size to --micro-batch-size and drop in-minibatch from --num-micro-batches-in-minibatch
parent a84a5fa0
@@ -63,8 +63,6 @@ def parse_args(extra_args_provider=None, defaults={},
         if "ring_exchange" not in dir(torch.distributed):
             raise Exception('PyTorch with torch.distributed.ring_exchange needed '
                             'to run pipeline MP!')
-    if args.num_microbatches_in_minibatch is None:
-        args.num_microbatches_in_minibatch = 1
     if args.rank == 0:
         print('using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '.format(
             args.world_size, args.tensor_model_parallel_size, args.pipeline_model_parallel_size))
@@ -212,11 +210,11 @@ def _add_regularization_args(parser):
 def _add_training_args(parser):
     group = parser.add_argument_group(title='training')
-    group.add_argument('--batch-size', type=int, default=None,
+    group.add_argument('--micro-batch-size', type=int, default=None,
                        help='Batch size per model instance (local batch size). '
                        'Global batch size is local batch size times data '
                        'parallel size.')
-    group.add_argument('--num-microbatches-in-minibatch', type=int, default=None,
+    group.add_argument('--num-microbatches', type=int, default=1,
                        help='Number of microbatches in minibatch')
     group.add_argument('--checkpoint-activations', action='store_true',
                        help='Checkpoint activation to allow for training '
...
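With --num-microbatches now defaulting to 1, the None-check removed at the top of this diff is redundant. A minimal argparse sketch of the renamed flags, for illustration only (this parser is a stand-in, not Megatron's real one):

    import argparse

    # Defaults mirror the diff; the parser itself is hypothetical.
    parser = argparse.ArgumentParser()
    parser.add_argument('--micro-batch-size', type=int, default=None)
    parser.add_argument('--num-microbatches', type=int, default=1)

    args = parser.parse_args(['--micro-batch-size', '8'])
    # Because --num-microbatches now defaults to 1, the old
    # "if args.num_microbatches_in_minibatch is None" fallback is no longer needed.
    assert args.micro_batch_size == 8 and args.num_microbatches == 1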
@@ -30,7 +30,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
     args = get_args()
     world_size = mpu.get_data_parallel_world_size()
-    global_batch_size = args.batch_size * world_size
+    global_batch_size = args.micro_batch_size * world_size
     # Megatron sampler
     batch_sampler = MegatronPretrainingSampler(
...
@@ -9,15 +9,15 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
-def get_one_epoch_dataloader(dataset, batch_size=None):
+def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     """Specifically one epoch to be used in an indexing job."""
     args = get_args()
     world_size = mpu.get_data_parallel_world_size()
     rank = mpu.get_data_parallel_rank()
-    if batch_size is None:
-        batch_size = args.batch_size
-    global_batch_size = batch_size * world_size
+    if micro_batch_size is None:
+        micro_batch_size = args.micro_batch_size
+    global_batch_size = micro_batch_size * world_size
     num_workers = args.num_workers
     sampler = torch.utils.data.SequentialSampler(dataset)
...
@@ -80,7 +80,7 @@ __global__ void scaled_masked_softmax_warp_forward(
     const input_t *src,
     const uint8_t *mask,
     const acc_t scale,
-    int batch_size,
+    int micro_batch_size,
     int stride,
     int element_count,
     int pad_batches)
@@ -102,9 +102,9 @@ __global__ void scaled_masked_softmax_warp_forward(
         pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
     }
-    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
     // many batches have to computed within this WARP.
-    int local_batches = batch_size - first_batch;
+    int local_batches = micro_batch_size - first_batch;
     if (local_batches > WARP_BATCH)
         local_batches = WARP_BATCH;
@@ -184,7 +184,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     input_t *grad,
     const input_t *output,
     acc_t scale,
-    int batch_size,
+    int micro_batch_size,
     int stride,
     int element_count)
 {
@@ -199,9 +199,9 @@ __global__ void scaled_masked_softmax_warp_backward(
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
     int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
-    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
     // many batches have to computed within this WARP.
-    int local_batches = batch_size - first_batch;
+    int local_batches = micro_batch_size - first_batch;
     if (local_batches > WARP_BATCH)
         local_batches = WARP_BATCH;
...
@@ -79,7 +79,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     output_t *dst,
     const input_t *src,
     const acc_t scale,
-    int batch_size,
+    int micro_batch_size,
     int stride,
     int element_count)
 {
@@ -94,9 +94,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     int local_seq = blockIdx.x + 1;
     int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;
-    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
     // many batches have to computed within this WARP.
-    int local_batches = batch_size - first_batch;
+    int local_batches = micro_batch_size - first_batch;
     if (local_batches > WARP_BATCH)
         local_batches = WARP_BATCH;
@@ -173,7 +173,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     input_t *grad,
     const input_t *output,
     acc_t scale,
-    int batch_size,
+    int micro_batch_size,
     int stride,
     int element_count)
 {
@@ -187,9 +187,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;
-    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
     // many batches have to computed within this WARP.
-    int local_batches = batch_size - first_batch;
+    int local_batches = micro_batch_size - first_batch;
     if (local_batches > WARP_BATCH)
         local_batches = WARP_BATCH;
...
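In both fused-softmax kernels the rename is mechanical; the clamp described in the comments just means "a warp processes at most WARP_BATCH rows, and the last warp may own fewer when the micro batch is not a multiple of WARP_BATCH". A Python restatement of that arithmetic (the function name and the WARP_BATCH value are illustrative, not taken from the kernels):

    WARP_BATCH = 4  # illustrative; the kernels fix this per template instantiation

    def rows_for_this_warp(micro_batch_size, first_batch):
        # Equivalent to: local_batches = micro_batch_size - first_batch,
        # clamped to WARP_BATCH as in the CUDA code above.
        return min(WARP_BATCH, micro_batch_size - first_batch)

    # e.g. micro_batch_size=10: warps starting at rows 0, 4, 8 own 4, 4, 2 rows
    assert [rows_for_this_warp(10, fb) for fb in (0, 4, 8)] == [4, 4, 2]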
@@ -45,7 +45,7 @@ def init_checkpointed_activations_memory_buffer():
     """Initializ the memory buffer for the checkpointed activations."""
     args = get_args()
-    per_layer = args.batch_size * args.max_position_embeddings * \
+    per_layer = args.micro_batch_size * args.max_position_embeddings * \
         args.hidden_size // args.model_parallel_size
     assert args.num_layers % args.checkpoint_num_layers == 0, \
         'number of layers is not divisible by checkpoint-num-layers'
...
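The buffer sizing above now scales with the micro-batch rather than the old, ambiguous batch size. A quick worked example of the per-layer element count, using made-up settings rather than anything from the commit:

    # Hypothetical settings, for illustration only
    micro_batch_size = 8
    max_position_embeddings = 1024
    hidden_size = 1024
    model_parallel_size = 2

    per_layer = micro_batch_size * max_position_embeddings * hidden_size // model_parallel_size
    assert per_layer == 4_194_304  # activation elements buffered per checkpointed layer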
@@ -138,7 +138,7 @@ def get_model(model_provider_func):
         model = FP16_Module(model)
     # Wrap model for distributed training."""
-    if args.num_microbatches_in_minibatch > 1:
+    if args.num_microbatches > 1:
         assert args.DDP_impl == 'local'
     if args.DDP_impl == 'torch':
@@ -246,7 +246,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    tensor_shape = (args.seq_length, args.batch_size, args.hidden_size)
+    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
     if recv_forward:
         tensor_recv_prev = torch.empty(tensor_shape,
                                        requires_grad=True,
@@ -315,7 +315,7 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
-        output_tensor = loss / args.num_microbatches_in_minibatch
+        output_tensor = loss / args.num_microbatches
         losses_reduced.append(loss_reduced)
     else:
         timers('forward-send').start()
@@ -375,7 +375,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
-        output_tensor = loss / args.num_microbatches_in_minibatch
+        output_tensor = loss / args.num_microbatches
         output_tensor_grad = None
         losses_reduced.append(loss_reduced)
     else:
@@ -419,10 +419,10 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     args = get_args()
     losses_reduced = []
-    for i in range(args.num_microbatches_in_minibatch):
+    for i in range(args.num_microbatches):
         timers('forward-compute').start()
         loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
-        output_tensor = loss / args.num_microbatches_in_minibatch
+        output_tensor = loss / args.num_microbatches
         losses_reduced.append(loss_reduced)
         timers('forward-compute').stop()
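The division by args.num_microbatches in the loops above keeps the accumulated gradient equal to the mean over the whole minibatch rather than the sum. A stripped-down sketch of that accumulation pattern, independent of the Megatron plumbing (model, optimizer, and microbatches here are placeholders):

    import torch

    def train_one_minibatch(model, optimizer, microbatches):
        optimizer.zero_grad()
        num_microbatches = len(microbatches)
        for inputs, targets in microbatches:
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
            # Scale so the summed per-microbatch gradients equal the minibatch-mean
            # gradient, matching `loss / args.num_microbatches` in the diff.
            (loss / num_microbatches).backward()
        optimizer.step()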
@@ -441,15 +441,15 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     args = get_args()
     # Compute number of warmup microbatches.
-    num_microbatches_in_minibatch = args.num_microbatches_in_minibatch
+    num_microbatches = args.num_microbatches
     num_warmup_microbatches = \
         (mpu.get_pipeline_model_parallel_world_size() -
          mpu.get_pipeline_model_parallel_rank() - 1)
     num_warmup_microbatches = min(
         num_warmup_microbatches,
-        num_microbatches_in_minibatch)
-    num_microbatches_in_minibatch_remaining = \
-        num_microbatches_in_minibatch - num_warmup_microbatches
+        num_microbatches)
+    num_microbatches_remaining = \
+        num_microbatches - num_warmup_microbatches
     input_tensors = []
     output_tensors = []
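The warmup count computed above is what drives the 1F1B schedule: stages nearer the front of the pipeline run extra forward steps before settling into the steady state. A small sketch of the same arithmetic (pure Python, example values invented):

    def warmup_and_remaining(pipeline_world_size, pipeline_rank, num_microbatches):
        # Earlier pipeline stages need more warmup forwards; never more than
        # the total number of microbatches.
        num_warmup = min(pipeline_world_size - pipeline_rank - 1, num_microbatches)
        return num_warmup, num_microbatches - num_warmup

    # e.g. 4 pipeline stages, 8 microbatches: the first stage warms up with 3
    # forwards, the last stage with none.
    assert warmup_and_remaining(4, 0, 8) == (3, 5)
    assert warmup_and_remaining(4, 3, 8) == (0, 8)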
@@ -465,7 +465,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
-    if num_microbatches_in_minibatch_remaining > 0:
+    if num_microbatches_remaining > 0:
         if mpu.is_pipeline_first_stage():
             input_tensor = None
         else:
@@ -477,8 +477,8 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             timers('forward-recv').stop()
     # Run 1F1B.
-    for i in range(num_microbatches_in_minibatch_remaining):
-        last_iteration = (i == (num_microbatches_in_minibatch_remaining - 1))
+    for i in range(num_microbatches_remaining):
+        last_iteration = (i == (num_microbatches_remaining - 1))
         input_tensor = \
             forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                            optimizer,
@@ -702,8 +702,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                lr_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
-                                       args.batch_size * \
-                                       args.num_microbatches_in_minibatch
+                                       args.micro_batch_size * \
+                                       args.num_microbatches
         # Logging.
         loss_scale = None
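Sample accounting follows the same rename: each training iteration consumes data_parallel_size * micro_batch_size * num_microbatches samples, which is also the global batch size used later in build_train_valid_test_data_iterators. A tiny illustration with made-up numbers:

    data_parallel_size = 8
    micro_batch_size = 4
    num_microbatches = 16

    samples_per_iteration = data_parallel_size * micro_batch_size * num_microbatches
    assert samples_per_iteration == 512  # global batch size consumed per iteration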
@@ -761,7 +761,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                             args.eval_iters))
-            for _ in range(args.num_microbatches_in_minibatch):
+            for _ in range(args.num_microbatches):
                 if not mpu.is_pipeline_first_stage():
                     input_tensor, _ = communicate(
                         tensor_send_next=None,
@@ -788,13 +788,13 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                         recv_backward=False)
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
-                                           * args.batch_size \
-                                           * args.num_microbatches_in_minibatch
+                                           * args.micro_batch_size \
+                                           * args.num_microbatches
     # Move model back to the train mode.
     model.train()
     for key in total_loss_dict:
-        total_loss_dict[key] /= args.eval_iters * args.num_microbatches_in_minibatch
+        total_loss_dict[key] /= args.eval_iters * args.num_microbatches
     return total_loss_dict
@@ -834,7 +834,7 @@ def build_train_valid_test_data_iterators(
     # Rank and global batch size.
     data_parallel_size = mpu.get_data_parallel_world_size()
-    global_batch_size = args.batch_size * data_parallel_size * args.num_microbatches_in_minibatch
+    global_batch_size = args.micro_batch_size * data_parallel_size * args.num_microbatches
     # Backward compatibility, assume fixed batch size.
     if args.iteration > 0 and args.consumed_train_samples == 0:
         args.consumed_train_samples = args.iteration * global_batch_size
...
@@ -98,11 +98,11 @@ def get_ltor_masks_and_position_ids(data,
     """Build masks and position id for left to right model."""
     # Extract batch size and sequence length.
-    batch_size, seq_length = data.size()
+    micro_batch_size, seq_length = data.size()
     # Attention mask (lower triangular).
     if reset_attention_mask:
-        att_mask_batch = batch_size
+        att_mask_batch = micro_batch_size
     else:
         att_mask_batch = 1
     attention_mask = torch.tril(torch.ones(
@@ -124,7 +124,7 @@ def get_ltor_masks_and_position_ids(data,
     if reset_position_ids or reset_attention_mask:
         # Loop through the batches:
-        for b in range(batch_size):
+        for b in range(micro_batch_size):
             # Find indecies where EOD token is.
             eod_index = position_ids[b, data[b] == eod_token]
...
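The function above derives the mask shape from the micro-batch tensor. A self-contained sketch of the lower-triangular mask it starts from (sizes invented for the example; the real function additionally handles EOD resets and position ids):

    import torch

    micro_batch_size, seq_length = 2, 4
    # One shared causal mask when attention is not reset at EOD tokens
    # (att_mask_batch == 1 in the diff above).
    attention_mask = torch.tril(
        torch.ones(1, seq_length, seq_length)).unsqueeze(1)
    assert attention_mask.shape == (1, 1, seq_length, seq_length)
    # Row i may attend to positions 0..i only.
    assert attention_mask[0, 0, 1].tolist() == [1.0, 1.0, 0.0, 0.0]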
@@ -87,8 +87,8 @@ def forward_step(data_iterator, model, input_tensor):
     # Forward model.
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
-    local_batch_size = query_logits.shape[0]
-    global_batch_size = dist.get_world_size() * local_batch_size  # recall we assert that tensor_model_parallel_size == 1
+    micro_batch_size = query_logits.shape[0]
+    global_batch_size = dist.get_world_size() * micro_batch_size  # recall we assert that tensor_model_parallel_size == 1
     all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
     all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
...
@@ -37,7 +37,7 @@ def accuracy_func_provider(single_dataset_provider):
     for datapath in datapaths:
         dataset = single_dataset_provider(datapath)
         dataloader = build_data_loader(
-            dataset, args.batch_size, num_workers=args.num_workers,
+            dataset, args.micro_batch_size, num_workers=args.num_workers,
             drop_last=(mpu.get_data_parallel_world_size() > 1))
         dataloaders.append((dataset.dataset_name, dataloader))
...
@@ -71,7 +71,7 @@ def _cross_entropy_forward_step(batch, model):
     return loss, {'lm loss': averaged_loss[0]}
-def build_data_loader(dataset, batch_size, num_workers, drop_last):
+def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""
     # Sampler.
@@ -82,7 +82,7 @@ def build_data_loader(dataset, batch_size, num_workers, drop_last):
     # Data loader. Note that batch size is the per GPU batch size.
     data_loader = torch.utils.data.DataLoader(dataset,
-                                              batch_size=batch_size,
+                                              batch_size=micro_batch_size,
                                               sampler=sampler,
                                               shuffle=False,
                                               num_workers=num_workers,
@@ -109,14 +109,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     print_rank_0('building train and validation dataloaders ...')
     # Training dataset.
-    train_dataloader = build_data_loader(train_dataset, args.batch_size,
+    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                          args.num_workers, not args.keep_last)
     # Set the training iterations.
     args.train_iters_per_epoch = len(train_dataloader)
     args.train_iters = args.epochs * args.train_iters_per_epoch
     # Validation dataset. For this dataset, we do not need to set up
     # shuffling so we can just use a simple infinite loop.
-    valid_dataloader_ = build_data_loader(valid_dataset, args.batch_size,
+    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                           args.num_workers, not args.keep_last)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
...
@@ -186,7 +186,7 @@ def main():
     # Data stuff.
     dataset = build_dataset(args.task)
-    dataloader = build_data_loader(dataset, args.batch_size,
+    dataloader = build_data_loader(dataset, args.micro_batch_size,
                                    args.num_workers, drop_last=False)
     # Run evaluation.
...
@@ -86,7 +86,7 @@ def main():
     # Generate samples.
     if args.num_samples == 0:
-        args.batch_size = 1
+        args.micro_batch_size = 1
         if args.sample_input_file != None:
            generate_samples_input_from_file(model)
        else:
...