Commit 41a64613 authored by Rewon Child

Merge main

parents 8676baca 83d26f03
@@ -44,10 +44,13 @@ from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory
@@ -107,23 +110,32 @@ def pretrain(train_valid_test_dataset_provider,
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
@@ -185,13 +197,16 @@ def get_model(model_provider_func):
    # Build model on cpu.
    model = model_provider_func()
    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
@@ -199,22 +214,25 @@ def get_model(model_provider_func):
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum([sum([p.nelement() for p in model_module.parameters()])
                       for model_module in model])), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = [FP16Module(model_module) for model_module in model]

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = [torchDDP(model_module, device_ids=[i], output_device=i,
                          process_group=mpu.get_data_parallel_group())
                 for model_module in model]
        return model

    if args.DDP_impl == 'local':
        model = [LocalDDP(model_module) for model_module in model]
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
@@ -270,9 +288,8 @@ def setup_model_and_optimizer(model_provider_func):
    model = get_model(model_provider_func)

    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, FP16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)
@@ -282,305 +299,31 @@ def setup_model_and_optimizer(model_provider_func):
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1:
        assert args.DDP_impl == 'local'
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    if args.iteration == 0 and len(unwrapped_model) == 1 \
       and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler

def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
"""Communicate tensors between stages."""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_forward:
tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_backward:
tensor_recv_next = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
# Send tensors in both the forward and backward directions as appropriate.
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# Temporary workaround for batch_isend_irecv() race condition.
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next
def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
"""Backward step."""
args = get_args()
timers = get_timers()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
def forward_step_with_communication(forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers):
args = get_args()
if not mpu.is_pipeline_first_stage():
timers('forward-recv').start()
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
timers('forward-recv').stop()
else:
input_tensor = None
# Forward model for one step.
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward-compute').stop()
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
else:
timers('forward-send').start()
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
timers('forward-send').stop()
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
timers('backward-recv').start()
_, output_tensor_grad = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=False,
recv_backward=True)
timers('backward-recv').stop()
# Backward pass for one step.
timers('backward-compute').start()
input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
timers('backward-compute').stop()
if not mpu.is_pipeline_first_stage():
timers('backward-send').start()
communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
recv_forward=False,
recv_backward=False)
timers('backward-send').stop()
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
input_tensor, last_microbatch,
input_tensors, output_tensors,
losses_reduced, timers):
args = get_args()
# Forward model for one step.
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward-compute').stop()
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
output_tensor_grad = None
losses_reduced.append(loss_reduced)
else:
timers('forward-send-backward-recv').start()
_, output_tensor_grad = communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=True)
timers('forward-send-backward-recv').stop()
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
# Backward pass for one step.
timers('backward-compute').start()
input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
timers('backward-compute').stop()
if not mpu.is_pipeline_first_stage():
timers('backward-send-forward-recv').start()
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
recv_forward=(not last_microbatch),
recv_backward=False)
timers('backward-send-forward-recv').stop()
else:
input_tensor = None
return input_tensor
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run forward and backward passes without inter-stage communication."""
args = get_args()
losses_reduced = []
for i in range(get_num_microbatches()):
timers('forward-compute').start()
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
timers('backward-compute').start()
output_tensor_grad = None
backward_step(optimizer, model, input_tensor=None,
output_tensor=output_tensor, output_tensor_grad=None)
timers('backward-compute').stop()
return losses_reduced
def forward_backward_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
args = get_args()
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
forward_step_with_communication(
forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
timers('forward-recv').start()
input_tensor, _ = communicate(tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
timers('forward-recv').stop()
# Run 1F1B.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
input_tensor = \
forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
input_tensor, last_iteration,
input_tensors, output_tensors,
losses_reduced, timers)
# Run cooldown backward passes.
for i in range(num_warmup_microbatches):
backward_step_with_communication(
optimizer, model, input_tensors, output_tensors, timers)
return losses_reduced
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
@@ -591,17 +334,25 @@ def train_step(forward_step_func, data_iterator,
    optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_params(reduce_after=False,
                                          fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
@@ -609,11 +360,15 @@ def train_step(forward_step_func, data_iterator,
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
        mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, FP16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
@@ -623,11 +378,15 @@ def train_step(forward_step_func, data_iterator,
    # Update parameters.
    timers('optimizer').start()
<<<<<<< HEAD
    update_successfull, grad_norm, num_zeros = optimizer.step()
=======
    update_successful, grad_norm = optimizer.step()
>>>>>>> main
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
@@ -636,7 +395,7 @@ def train_step(forward_step_func, data_iterator,
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
@@ -692,11 +451,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-backward-send-forward-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-send-backward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
@@ -749,7 +509,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                       normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and torch.distributed.get_rank() == 0:
            if args.log_timers_to_tensorboard:
@@ -800,11 +560,11 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    # Extra barrier is added to make sure
    # all ranks report the max time.
    torch.distributed.barrier()
    timers('save-checkpoint').start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers('save-checkpoint').stop()
    timers.log(['save-checkpoint'])
def train(forward_step_func, model, optimizer, lr_scheduler,
@@ -817,7 +577,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}
@@ -825,7 +586,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    # Iterations.
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
@@ -906,7 +667,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}
@@ -918,37 +680,30 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            if mpu.get_pipeline_model_parallel_world_size() > 1:
                if args.virtual_pipeline_model_parallel_size is not None:
                    forward_backward_func = forward_backward_pipelining_with_interleaving
                else:
                    forward_backward_func = forward_backward_pipelining_without_interleaving
            else:
                forward_backward_func = forward_backward_no_pipelining
            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
......
@@ -18,6 +18,7 @@
import sys
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
@@ -26,11 +27,25 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def unwrap_model(model, module_instances=(torchDDP)):
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model


def calc_params_l2_norm(model):
    """Calculate l2 norm of parameters """
    # Remove duplicate params.
@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):
def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from megatron.checkpointing import save_checkpoint
    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
......
@@ -38,7 +38,7 @@ def model_provider():
    args = get_args()
    num_tokentypes = 2 if args.bert_binary_head else 0

    def model_provider_pipelined():
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = BertModelFirstStage(
@@ -51,6 +51,17 @@ def model_provider():
        else:
            model = BertModelIntermediateStage(
                num_tokentypes=num_tokentypes)
        return model

    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            model = []
            for i in range(args.virtual_pipeline_model_parallel_size):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model.append(model_provider_pipelined())
        else:
            model = model_provider_pipelined()
    else:
        model = BertModel(
            num_tokentypes=num_tokentypes,
@@ -92,8 +103,8 @@ def forward_step(data_iterator, model, input_tensor):
    # Get the batch.
    timers('batch-generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    if not args.bert_binary_head:
......
@@ -35,8 +35,8 @@ def model_provider():
    """Build the model."""

    print_rank_0('building GPT model ...')

    def model_provider_pipelined():
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = GPTModelFirstStage(num_tokentypes=0)
@@ -46,6 +46,17 @@ def model_provider():
        else:
            model = GPTModelIntermediateStage(
                num_tokentypes=0)
        return model

    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            model = []
            for i in range(args.virtual_pipeline_model_parallel_size):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model.append(model_provider_pipelined())
        else:
            model = model_provider_pipelined()
    else:
        model = GPTModel(num_tokentypes=0, parallel_output=True)
......
@@ -14,6 +14,7 @@
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

import math
import torch
import torch.distributed as dist
@@ -23,17 +24,21 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def pretrain_ict_model_provider():
    args = get_args()
    model = biencoder_model_provider(
        only_context_model=False,
        only_query_model=False,
        biencoder_shared_query_context_model=\
        args.biencoder_shared_query_context_model)
    return model


def get_group_world_size_rank():
@@ -72,7 +77,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
        output = output_list[rank].contiguous()
        return output

def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
@@ -80,37 +84,57 @@ def forward_step(data_iterator, model, input_tensor):
    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_mask, \
    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and Context Types
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    query_logits, context_logits = model(query_tokens, query_mask,
                                         query_types, context_tokens,
                                         context_mask, context_types)

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1
    assert mpu.get_tensor_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    global_batch_size = dist.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
                                    torch.transpose(all_context_logits, 0, 1))

    # scaling the retriever scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                             k=softmax_scores.shape[1], sorted=True)

    def topk_accuracy(k):
        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
            for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

    labels = torch.arange(global_batch_size).long().cuda()
    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])

    # Scale the retrieval loss
    loss = loss * mpu.get_data_parallel_world_size()

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
        zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
    return loss, stats_dict


def train_valid_test_datasets_provider(train_val_test_num_samples):
@@ -129,6 +153,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=False,
        dataset_type='ict')
    print_rank_0("> finished creating BERT ICT datasets ...")
@@ -136,5 +161,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider,
             pretrain_ict_model_provider,
             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))

from megatron import print_rank_0
from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron
@@ -23,7 +26,7 @@ def main():
        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
    print_rank_0("Build and save indices: done!")

if __name__ == "__main__":
    main()
......
@@ -26,9 +26,9 @@ python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset. The code supports saving and loading fingerprints for recurrent deduplications (see the example invocations after this list).
```
python find_duplicates.py --inputs <pairwise list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
3. Based on the similarity measure defined inside the function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group, we should keep only one url and remove the rest.
```
@@ -44,3 +44,12 @@ python remove_group_duplicates.py <file containing similar documents> <cleaned d
shuf <cleaned deduped data file> -o train_data.json
```
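
The fingerprint options of `find_duplicates.py` (step 2) can be combined across runs. This is only a sketch: the flags come from the script itself, while the file names and keys are placeholders for illustration.
```
# First run: fingerprint one corpus, keep the fingerprints, and report its internal duplicates.
python find_duplicates.py --inputs openwebtext.json url --save-fingerprints openwebtext.pkl --output openwebtext_duplicates.json
# Later run: load the saved fingerprints, add a second corpus, and find duplicates across both.
python find_duplicates.py --load-fingerprints openwebtext.pkl --inputs news.json news_id --output combined_duplicates.json
```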
# Deduplicating ngrams

To deduplicate the downstream tasks from the training dataset, we run the following command.
```
python filter_ngrams.py <downstream task dataset> <training dataset to deduplicate> <output training dataset>
```
We use 13-grams for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from both sides of the 13-gram. We also remove any split document with fewer than 200 characters, as well as any document that got split more than 10 times.
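
For example, a possible invocation (the file names are placeholders; the three positional arguments are the ones read by `filter_ngrams.py` below):
```
python filter_ngrams.py lambada_test.json openwebtext_train.json openwebtext_train_filtered.json
```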
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Deduplicate downstream tasks from training dataset. 13-grams have been used.
All split documents with less than 200 characters got filtered. Any document
with more than 10 splits got filtered as well.
"""
from functools import partial
import json
import multiprocessing
import nltk
import re
import string
import sys
import time
def get_words(text):
# get all the lowercase words from text
words, positions = [], []
for match in re.finditer(r'\w+', text.lower()):
words.append(match.group(0))
positions.append(match.start())
return words, positions
def free_ngram(line, ngrams, ngram_size, filter_text_len,
splits_count, split_window_each_size):
# remove all the ngrams
try:
myjson = json.loads(line)
text_buf = [myjson['text']]
except Exception as e:
print("Error: {}".format(e), flush=True)
text_buf = []
text_buf_ngram_free = []
while len(text_buf) > 0:
# get the first one from the buffer
text = text_buf.pop(0)
words, positions = get_words(text)
not_ngram_free = True
punctuations = ".!?"
# find n-grams
for i in range(len(words) - ngram_size + 1):
seq = " ".join(words[i:i+ngram_size])
if seq in ngrams:
# splits the text
# first part of the text
pos = positions[i] - split_window_each_size
text_first = ""
while pos > 0 and not text[pos] in punctuations:
pos -= 1
if pos > 0:
text_first = text[0:pos+1]
pos = positions[i] + split_window_each_size
# last part of the text
text_second = ""
while pos < len(text) and not text[pos] in punctuations:
pos += 1
if pos + 1 < len(text):
text_second = text[pos+1:len(text)]
# first part of ngrams free
if len(text_first) > filter_text_len:
text_buf_ngram_free.append(text_first)
# add second part for further processing
if len(text_second) > filter_text_len:
text_buf.append(text_second)
not_ngram_free = False
break
# text is ngram free
if not_ngram_free:
text_buf_ngram_free.append(text)
return text_buf_ngram_free
if __name__ == '__main__':
print('finding possible duplicate content ...')
main_file = sys.argv[1] # lambada file
dedup_file = sys.argv[2] # Book corpus
output_file = sys.argv[3] #Filtered book corpus
ngrams = {}
id_prefix = "lambada"
# we use 13-grams; any text with fewer than 200 characters gets removed
# any text split more than 10 times gets removed as well
ngram_size = 13
filter_text_len = 200
splits_count = 10
split_window_each_size = 200
print('Reading file {} and computing ngrams'.format(main_file))
with open(main_file, 'r') as f:
for line in f:
try:
myjson = json.loads(line)
words, positions = get_words(myjson['text'])
for i in range(len(words) - ngram_size+1):
seq = " ".join(words[i:i+ngram_size])
if seq not in ngrams:
ngrams[seq] = positions[i]
except Exception as e:
print('Error:', e)
print("ngrams size {}".format(len(ngrams)))
print('Reading file {} and deduping n-grams'.format(dedup_file))
counter = 0
start_time = time.time()
out_f = open(output_file, 'wb')
splitted, ignored, split_mt_thld = 0, 0, 0
# Setup multi-processing.
num_workers = 40
fin = open(dedup_file, 'r', encoding='utf-8')
pool = multiprocessing.Pool(num_workers)
free_ngram_x=partial(free_ngram, ngrams=ngrams, ngram_size=ngram_size,
filter_text_len=filter_text_len, splits_count=splits_count,
split_window_each_size=split_window_each_size)
free_ngrams = pool.imap(free_ngram_x, fin, 25)
for text_buf_ngram_free in free_ngrams:
counter += 1
try:
if len(text_buf_ngram_free) > 1:
splitted += (len(text_buf_ngram_free) - 1)
if len(text_buf_ngram_free) == 0:
ignored += 1
# more than 10 splits ignored
if len(text_buf_ngram_free) > splits_count:
text_buf_ngram_free = []
split_mt_thld += 1
for i in range(len(text_buf_ngram_free)):
split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
+ '-{:010d}'.format(int(i))
outjson = json.dumps({"text":text_buf_ngram_free[i],
id_prefix+"_split_id":split_id_string},
ensure_ascii=False)
out_f.write(outjson.encode('utf-8'))
out_f.write('\n'.encode('utf-8'))
if counter % 1000 == 0:
print(' [search]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
except Exception as e:
print('Error:', e)
print("Deduped file written to: {}".format(output_file), flush=True)
print("Total docs {} splitted {} ignored {} docs with many splits {}".\
format(counter, splitted, ignored, split_mt_thld), flush=True)
print('done :-)')
@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import itertools
import json
from lsh import cache, minhash
import numpy as np
import time
import pickle
import sys


# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
@@ -38,36 +39,98 @@ def jaccard(set_a, set_b):
if __name__ == '__main__':
    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs='*', default=None,
                        help='Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs='*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                        ' with matching similarities')
    args = parser.parse_args()

    print('finding possible duplicate content ...')

    # set seed and get an array of seeds of 100 integers
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=100)

    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(bands=10, hasher=hasher)

    url_doc = {}

    # load fingerprints from pickle file if needed
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            fp = open(fp_file_name, "rb")
            if count_fp == 0:
                # assign directory for the first pkl
                lshcache = pickle.load(fp)
                url_doc = pickle.load(fp)
            else:
                # append these to lshcache and url_doc
                local_lshcache = pickle.load(fp)
                local_url_doc = pickle.load(fp)
                for url in local_lshcache.fingerprints.keys():
                    url_doc[url] = local_url_doc[url]
                    lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
            fp.close()

    counter = 0
    start_time = time.time()

    print("Computing fingerprints", flush=True)

    # compute finger prints of the inputs if any
    # input file and the key to use as id
    if args.inputs is not None:
        assert len(args.inputs) % 2 == 0
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                  flush=True)
            # traverse all the texts and add fingerprints
            with open(input_file, 'r') as f_input:
                for line in f_input:
                    try:
                        myjson = json.loads(line)
                        url = myjson[key]
                        text = myjson['text']
                        counter += 1
                        url_doc[url] = text
                        lshcache.add_fingerprint(hasher.fingerprint(text), url)
                    except Exception as e:
                        print('Error:', e)
                    if counter % 10000 == 0:
                        print(' [read]> processed {} documents in {:.2f} '
                              'seconds ...'.format(counter, time.time() - \
                              start_time), flush=True)

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    counter = 0
    start_time = time.time()
    deduped = 0

    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        f_out = open(args.output, 'wb')
        for b in lshcache.bins:
            for bucket_id in b:
                if len(b[bucket_id]) > 1:
@@ -94,7 +157,8 @@ if __name__ == '__main__':
                        if len(remove_urls) > 0:
                            myjson = json.dumps({main_url: remove_urls},
                                                ensure_ascii=False)
                            f_out.write(myjson.encode('utf-8'))
                            f_out.write('\n'.encode('utf-8'))
        f_out.close()

    print('done :-)')
@@ -19,7 +19,7 @@ import sys

def is_similar(jaccard_similarity):
    return (jaccard_similarity >= 0.7)


if __name__ == '__main__':
......