Commit abe36e2e authored by Raul Puri's avatar Raul Puri
Browse files

large update including model parallelism and gpt2


Co-authored-by: default avatarshoeybi <shoeybim@gmail.com>
Co-authored-by: default avatarraulpuric <raulpuric@berkeley.edu>
Co-authored-by: default avatarjaredcasper <jaredcasper@gmail.com>
Co-authored-by: default avatarmpatwary <mostofa.patwary@gmail.com>
Co-authored-by: default avatarplegresl <plegresl@gmail.com>
parent 0399d32c
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import sys
sys.path.append("../..")
import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter
import mpu
from commons import initialize_distributed
from commons import print_separator
from commons import set_random_seed
from mpu import layers
def test_parallel_embedding(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
batch_size = 17
seq_length = 23
vocab_size = 48
hidden_size = 16
seed = 1236
set_random_seed(123)
input_data = torch.LongTensor(
size=(batch_size,seq_length)).random_(0, vocab_size).cuda()
loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
set_random_seed(seed)
embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
output = embedding_original(input_data)
loss_original = torch.mul(output, loss_weight).sum()
loss_original.backward()
set_random_seed(seed)
embedding_parallel = layers.ParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_parallel(input_data)
loss_parallel = torch.mul(output, loss_weight).sum()
loss_parallel.backward()
set_random_seed(seed)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_vocab_parallel(input_data)
loss_vocab_parallel = torch.mul(output, loss_weight).sum()
loss_vocab_parallel.backward()
torch.distributed.barrier()
error = loss_parallel.sub(loss_original).abs()
print(' error in loss (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
torch.distributed.barrier()
error = loss_vocab_parallel.sub(loss_original).abs()
print(' error in loss (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // model_parallel_size,
1)[mpu.get_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // model_parallel_size,
0)[mpu.get_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_initialize_affine_weight(model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
# ---------------
# Column parallel
# ---------------
weight = torch.empty(output_size_coeff, input_size)
set_random_seed(seed)
layers._initialize_affine_weight(weight, output_size, input_size,
output_size_coeff, 0,
torch.nn.init.normal_)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' column parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# ------------
# Row parallel
# ------------
weight = torch.empty(output_size, input_size_coeff)
set_random_seed(seed)
mpu.layers._initialize_affine_weight(weight, output_size, input_size,
input_size_coeff, 1,
torch.nn.init.normal_)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' row parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
class IdentityLayer2D(torch.nn.Module):
def __init__(self, m , n):
super(IdentityLayer2D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def test_column_parallel_linear(model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
linear_layer = mpu.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output = linear_layer(input_)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
my_dLdb = torch.split(dLdb, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_row_parallel_linear(model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
linear_layer = mpu.RowParallelLinear(
input_size, output_size, keep_master_weight_for_test=True).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output = linear_layer(input_)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
class IdentityLayer3D(torch.nn.Module):
def __init__(self, m , n, k):
super(IdentityLayer3D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n, k))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length):
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
dropout_prob).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = attention_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = mpu.get_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \
attention_layer, identity_layer
def test_parallel_self_attention(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
dropout_prob = 0.0 # has to be zero
batch_size = 5
sequence_length = 13
rank_1, hideen_size_1, model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 =parallel_self_attention(
1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \
attention_layer, identity_layer =parallel_self_attention(
model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hideen_size_1 == hidden_size
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad,
hidden_size // model_parallel_size, 0)[rank::model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max()
torch.distributed.barrier()
print(' weight gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length):
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
intermediate_size = 4 * hidden_size
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
transformer_layer = mpu.BertParallelTransformerLayer(
hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
torch.nn.functional.relu, 1.0e-5).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = transformer_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = mpu.get_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \
transformer_layer, identity_layer
def test_parallel_transformer_layer(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer(
model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
print_separator('test initialize affine weight')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_initialize_affine_weight(model_parallel_size)
model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test parallel embedding')
test_parallel_embedding(model_parallel_size)
model_parallel_size *= 2
print_separator('test column-parallel linear')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_column_parallel_linear(model_parallel_size)
model_parallel_size *= 2
print_separator('test row-parallel linear')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_row_parallel_linear(model_parallel_size)
model_parallel_size *= 2
print_separator('test parallel self-attention')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_parallel_self_attention(model_parallel_size)
model_parallel_size *= 2
print_separator('test parallel transformer')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_parallel_transformer_layer(model_parallel_size)
model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../..")
import torch
import mpu
from commons import initialize_distributed
from commons import print_separator
def test_set_cuda_rng_state(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
size = 123
seed = 1234
torch.cuda.manual_seed(1234)
tensor = torch.cuda.FloatTensor(size)
# Get the state
rng_state = torch.cuda.get_rng_state()
rng_state_copy = rng_state.clone()
# Do some stuff.
for _ in range(5):
torch.randn(size, out=tensor)
result_1 = tensor.clone()
assert rng_state.sub(rng_state_copy).max() == 0
assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0
# State should be different.
new_rng_state = torch.cuda.get_rng_state()
max_diff = new_rng_state.sub(rng_state).max()
print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), max_diff))
assert max_diff > 0
# Reset the rng state and do the same stuff.
mpu.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
mpu.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
result_2 = tensor.clone()
# Results should be the same
error = result_2.sub(result_1).abs().max()
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Input state should have remained intact.
error = rng_state.sub(rng_state_copy).max()
print(' max error in rng state (should be zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), error))
assert error == 0
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_cuda_rng_tracker(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
seed_1 = 1234
seed_2 = 4321
size = [12, 21]
tensor = torch.cuda.FloatTensor(size)
# Set to seed_1 and generate two tensors.
torch.cuda.manual_seed(seed_1)
torch.randn(size, out=tensor)
target_11 = tensor.clone()
torch.randn(size, out=tensor)
target_12 = tensor.clone()
# Set to seed_2 and generate two tensors.
torch.cuda.manual_seed(seed_2)
torch.randn(size, out=tensor)
target_21 = tensor.clone()
torch.randn(size, out=tensor)
target_22 = tensor.clone()
# Now if we interleave seed_1 and seed_2,
# we should still get the same tensors
torch.cuda.manual_seed(seed_1)
mpu.get_cuda_rng_tracker().add('test', seed_2)
torch.randn(size, out=tensor)
result_11 = tensor.clone()
with mpu.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_21 = tensor.clone()
torch.randn(size, out=tensor)
result_12 = tensor.clone()
with mpu.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_22 = tensor.clone()
diff = result_11.sub(result_21).abs().max()
diff = min(diff, result_12.sub(result_22).abs().max())
print(' max diff in generated tensors (should be non-zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
assert diff > 1.0e-6
error = max(result_11.sub(target_11).abs().max(),
result_12.sub(target_12).abs().max())
error = max(error, result_21.sub(target_21).abs().max())
error = max(error, result_22.sub(target_22).abs().max())
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset the tracker
mpu.get_cuda_rng_tracker().reset()
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_model_parallel_cuda_manual_seed(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing model parallel cuda manual seed with size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345
with mpu.get_cuda_rng_tracker().fork():
assert torch.cuda.initial_seed() == (12345 + 2718 +
mpu.get_model_parallel_rank())
# Reset the tracker
mpu.get_cuda_rng_tracker().reset()
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test set rng state')
test_set_cuda_rng_state(model_parallel_size)
model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test cuda rng tracker')
test_cuda_rng_tracker(model_parallel_size)
model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(model_parallel_size)
model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import torch
import torch.nn.init as init
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
from .initialize import get_model_parallel_world_size
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .mappings import gather_from_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
class GPT2ParallelSelfAttention(torch.nn.Module):
"""Parallel self-attention layer for GPT2.
Self-attention layer takes input with size [b, s, h] where b is
the batch size, s is the sequence lenght, and h is the hidden size
and creates output of the same size.
Arguments:
hidden_size: total hidden size of the layer (h).
num_attention_heads: number of attention heads (n). Note that we
require n to be divisible by number of GPUs
used to parallelize the model. Also, we
require hidden size to be divisible by n.
dropout_prob: dropout probability for the attention scores.
init_method: weight initialization.
output_layer_init_method: output layer initialization. If None, use
`init_method`.
We use the following notation:
h: hidden_size
n: num_attention_heads
p: number of partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
"""
def __init__(self, hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, output_layer_init_method=None):
super(GPT2ParallelSelfAttention, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size_per_partition = divide(hidden_size, world_size)
self.hidden_size_per_attention_head = divide(hidden_size,
num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads,
world_size)
# Strided linear layer.
self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
# Output.
self.dense = RowParallelLinear(hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, ltor_mask):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
# Attention heads. [b, s, hp]
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
# Raw attention scores. [b, np, s, s]
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.hidden_size_per_attention_head)
# Apply the left to right attention mask.
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [b, s, h]
output = self.dense(context_layer)
output = self.output_dropout(output)
return output
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
return gelu_impl(x)
class GPT2ParallelMLP(torch.nn.Module):
"""MLP for GPT2.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform gelu transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
Arguments:
hidden_size: The hidden size of the self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weight are initialized to one.
output_layer_init_method: output layer initialization. If None,
use `init_method`.
"""
def __init__(self, hidden_size, output_dropout_prob, init_method,
output_layer_init_method=None):
super(GPT2ParallelMLP, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
# Project to 4h.
self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4*hidden_size,
gather_output=False,
init_method=init_method)
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
4*hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.dropout = torch.nn.Dropout(output_dropout_prob)
def forward(self, hidden_states):
# [b, s, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = gelu(intermediate_parallel)
# [b, s, h]
output = self.dense_4h_to_h(intermediate_parallel)
output = self.dropout(output)
return output
class GPT2ParallelTransformerLayer(torch.nn.Module):
"""A single layer transformer for GPT2.
We use the following notation:
h: hidden size
n: number of attention heads
b: batch size
s: sequence length
Transformore layer takes input with size [b, s, h] and returns an
output of the same size.
Arguments:
hidden_size: The hidden size of the self attention.
num_attention_heads: number of attention head in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weight are initialized to one.
output_layer_init_method: output layers (attention output and
mlp output) initialization. If None,
use `init_method`.
"""
def __init__(self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
init_method,
output_layer_init_method=None):
super(GPT2ParallelTransformerLayer, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
# Layernorm on the input data.
self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
# Self attention.
self.attention = GPT2ParallelSelfAttention(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
init_method,
output_layer_init_method=output_layer_init_method)
# Layernorm on the input data.
self.post_attention_layernorm = LayerNorm(hidden_size,
eps=layernorm_epsilon)
# MLP
self.mlp = GPT2ParallelMLP(
hidden_size,
output_dropout_prob,
init_method,
output_layer_init_method=output_layer_init_method)
def forward(self, hidden_states, ltor_mask):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
# Layer norm at the begining of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.attention(layernorm_output, ltor_mask)
# Residual connection.
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
output = layernorm_input + mlp_output
return output
def unscaled_init_method(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
class GPT2ParallelTransformer(torch.nn.Module):
"""GPT-2 transformer.
This module takes input from embedding layer and it's output can
be used directly by a logit layer. It consists of L (num-layers)
blocks of:
layer norm
self attention
residual connection
layer norm
mlp
residual connection
followed by a final layer norm.
Arguments:
num_layers: Number of transformer layers.
hidden_size: The hidden size of the self attention.
num_attention_heads: number of attention head in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
checkpoint_activations: if True, checkpoint activations.
checkpoint_num_layers: number of layers to checkpoint. This
is basically the chunk size in checkpoitning.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method_std: standard deviation of the init method which has
the form N(0, std).
use_scaled_init_for_output_weights: If Ture use 1/sqrt(2*num_layers)
scaling for the output weights (
output of self attention and mlp).
"""
def __init__(self,
num_layers,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
checkpoint_activations,
checkpoint_num_layers=1,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
use_scaled_init_for_output_weights=True):
super(GPT2ParallelTransformer, self).__init__()
# Store activation checkpoiting flag.
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
output_layer_init_method = None
if use_scaled_init_for_output_weights:
output_layer_init_method = scaled_init_method(init_method_std,
num_layers)
def get_layer():
return GPT2ParallelTransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
unscaled_init_method(init_method_std),
output_layer_init_method=output_layer_init_method)
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer() for _ in range(num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def forward(self, hidden_states, attention_mask):
def custom(start, end):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_ = inputs[0]
for layer in layers_:
x_ = layer(x_, inputs[1])
return x_
return custom_forward
if self.checkpoint_activations:
l = 0
num_layers = len(self.layers)
chunk_length = self.checkpoint_num_layers
while l < num_layers:
hidden_states = checkpoint(custom(l, l+chunk_length),
hidden_states, attention_mask)
l += chunk_length
else:
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask)
# Final layer norm.
output = self.final_layernorm(hidden_states)
return output
class BertParallelSelfAttention(torch.nn.Module):
"""Parallel self-attention layer for BERT.
Self-attention layer takes input with size [b, s, h] where b is
the batch size, s is the sequence lenght, and h is the hidden size
and creates output of the same size.
Arguments:
hidden_size: total hidden size of the layer (h).
num_attention_heads: number of attention heads (n). Note that we
require n to be divisible by number of GPUs
used to parallelize the model. Also, we
require hidden size be divisible by n.
dropout_prob: dropout probability for the attention scores.
output_parallel: If true, no all-gather is done on the output and
the output values will be per partition.
We use the following notation:
h: hidden_size
n: num_attention_heads
p: number of partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
"""
def __init__(self, hidden_size, num_attention_heads,
dropout_prob, output_parallel=False,
init_method=init.xavier_normal_):
super(BertParallelSelfAttention, self).__init__()
# Input configuration.
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.dropout_prob = dropout_prob
self.output_parallel = output_parallel
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size_per_partition = divide(hidden_size, world_size)
self.hidden_size_per_attention_head = divide(hidden_size,
num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads,
world_size)
# Strided linear layer.
self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.dropout = torch.nn.Dropout(dropout_prob)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
# Attention heads. [b, s, hp]
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
# Raw attention scores. [b, np, s, s]
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.hidden_size_per_attention_head)
# Apply the attention mask.
attention_scores += attention_mask
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with get_cuda_rng_tracker().fork():
attention_probs = self.dropout(attention_probs)
# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [b, s, h]
if self.output_parallel:
output = context_layer
else:
output = gather_from_model_parallel_region(context_layer)
return output
class BertParallelTransformerOutput(torch.nn.Module):
"""The output layer used after self attention and intermediate
parts of transformer layer."""
def __init__(self, input_size, output_size, dropout_prob,
layernorm_epsilon=1.0e-12, input_is_parallel=False,
init_method=init.xavier_normal_):
super(BertParallelTransformerOutput, self).__init__()
# Components.
self.dense = RowParallelLinear(input_size,
output_size,
input_is_parallel=input_is_parallel,
init_method=init_method)
self.dropout = torch.nn.Dropout(dropout_prob)
self.layernorm = LayerNorm(output_size, eps=layernorm_epsilon)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
layernorm_input = hidden_states + input_tensor
hidden_states = self.layernorm(layernorm_input)
return hidden_states
class BertParallelTransformerLayer(torch.nn.Module):
"""A single layer transformer for Bert.
We use the following notation:
h: hidden size
n: number of attention heads
b: batch size
s: sequence length
Transformore layer takes input with size [b, s, h] and returns an
output of the same size.
Arguments:
hidden_size: The hidden size of the self attention.
intermediate_size: size of the intermediate state after
self attention. In both BERT and GPT
this is set to be 4 times the hidden
size.
num_attention_heads: number of attention head in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
intermediate_activation_fn: activation function for output
of intermediate.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weight are initialized to one.
"""
def __init__(self,
hidden_size,
intermediate_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
intermediate_activation_fn,
layernorm_epsilon,
init_method=init.xavier_normal_):
super(BertParallelTransformerLayer, self).__init__()
# Self attention.
self.attention = BertParallelSelfAttention(hidden_size,
num_attention_heads,
attention_dropout_prob,
output_parallel=True,
init_method=init_method)
# Self attention output.
self.self_output = BertParallelTransformerOutput(
hidden_size, hidden_size, output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
input_is_parallel=True,
init_method=init_method)
# Intermediate.
self.intermediate = ColumnParallelLinear(hidden_size, intermediate_size,
gather_output=False,
init_method=init_method)
self.intermediate_activation_fn = intermediate_activation_fn
# Output.
self.output = BertParallelTransformerOutput(
intermediate_size, hidden_size, output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
input_is_parallel=True,
init_method=init_method)
def forward(self, hidden_states, attention_mask):
# [b, s, hp]
attention_output_parallel = self.attention(hidden_states,
attention_mask)
# [b, s, h]
attention_self_output = self.self_output(attention_output_parallel,
hidden_states)
# [b, s, ip]
intermediate_output_parallel = self.intermediate(attention_self_output)
intermediate_output_parallel = self.intermediate_activation_fn(
intermediate_output_parallel)
# [b, s, h]
layer_output = self.output(intermediate_output_parallel,
attention_self_output)
return layer_output
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(
numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(tensor, num_partitions,
contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indecies in [fist, last)"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank, world_size):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size)
The following steps show how to prepare training dataset to train the mode.
# Libraries to install
```
pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
git clone https://github.com/mattilyra/LSH
cd LSH
python setup.py install
```
# Download the dataset
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
2. Remove blacklisted URLs.
```
python blacklist_urls.py <path to the dowloaded deduplicated URLs> <filename for clean urls. e.g. clean_urls.txt>
```
3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
4. Merge the contents into one loose json file with 1 json per newline of the format `{'text': text, 'url': unique_url}`. It is important for the url to be unique.
# Prepare the data for GPT-2 training:
1. Perform ftfy, english detection and remove documents with less than 128 tokens. This step can be sharded and run on shards.
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
2. Using LSH, find possible duplicates and store then in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for OpenWebText dataset.
```
python find_duplicates.py <input cleaned data file> <output possible duplicate urls filename>
```
3. Based on similarity measure defind inside function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group, only one url we should keep and remove the rest.
```
python group_duplicate_urls.py <possible duplicate urls file> <output file containing similar urls>
```
4. Remove similar documents that were detected in the last step.
```
python remove_group_duplicates.py <file containing simialr documents> <cleaned data file> <outputfile containing deduplicate data>
```
5. Shuffle the dataset.
```
shuf <cleaned deduped data file> -o train_data.json
```
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import re
import time
import tldextract
import sys
# List of the domains to blacklist.
domain_blacklist = set([
'500px',
'aapks',
'akamaihd',
'amazon',
'apple',
'artifactfire',
'artstation',
'awwni',
'bandcamp',
'battleforthenet',
'coinscalendar',
'dailymotion',
'deviantart',
'discord',
'discordapp',
'dlapkandroid',
'dropbox',
'e621',
'ebay',
'edealinfo',
'erome',
'eroshare',
'explosm',
'facebook',
'fbcdn',
'flickr',
'furaffinity',
'futhead',
'gatopardo',
'gfycat',
'gifsound',
'gifsoup',
'giphy',
'github',
'google',
'gunprime',
'gyazo',
'hotdealstar',
'imagefap',
'imageshack',
'imgflip',
'imgur',
'instagram',
'karmadecay',
'kryptocal',
'kym-cdn',
'liveleak',
'livememe',
'lmgtfy',
'magaimg',
'memegenerator',
'minorplanetcenter',
'minus',
'mobafire',
'morejpeg',
'nocookie',
'pcpartpicker',
'photobucket',
'pinimg',
'pinterest',
'pixiv',
'pornhub',
'prntscr',
'puu',
'qkme',
'quickmeme',
'radd',
'redd',
'reddit',
'reddit-stream',
'redditlog',
'redditmedia',
'reddituploads',
'redtube',
'reupp',
'reverb',
'roanoke',
'rollingstone',
'sli',
'soundcloud',
'soundgasm',
'spankbang',
'spotify',
'strawpoll',
'streamable',
'timeanddate',
'tinypic',
'touhouradio',
'tumblr',
'twimg',
'twitch',
'twitter',
'vid',
'vimeo',
'vine',
'vkaao',
'vocaroo',
'voyagefusion',
'walmart',
'wciu',
'wikimedia',
'wikipedia',
'xhamster',
'xkcd',
'xvideos',
'youtu',
'youtube',
'youtubedoubler',
'ytimg',
'zillexplorer',
])
def domain_is_in_blacklist(url):
domain = tldextract.extract(url).domain
return domain in domain_blacklist
# List of extentions to blacklist.
extentions_blacklist = (
'.3gp',
'.7z'
'.ai',
'.aif',
'.apk',
'.app',
'.avi',
'.bin',
'.bmp',
'.bz2',
'.css',
'.csv',
'.dat',
'.deb',
'.dmg',
'.doc',
'.docx',
'.exe',
'.gif',
'.gifv',
'.gz',
'.iso',
'.jar',
'.jpeg',
'.jpg',
'.js',
'.log',
'.mid',
'.midi',
'.mkv',
'.mov',
'.mp3',
'.mp4',
'.mpeg',
'.mpg',
'.ogg',
'.ogv',
'.otf',
'.pdf',
'.pkg',
'.png',
'.pps',
'.ppt',
'.pptx',
'.psd',
'.py',
'.qt',
'.ram',
'.rar',
'.sql',
'.svg',
'.swf',
'.tar.gz',
'.tar',
'.tgz',
'.tiff',
'.ttf',
'.txt',
'.wav',
'.webm',
'.wma',
'.wmv',
'.xls',
'.xlsx',
'.xml',
'.xz',
'.zip',
)
def extention_is_in_blacklist(url):
if url.split('?')[0].lower().endswith(extentions_blacklist):
return True
return False
# Malformed urls.
# This function is adapted from:
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
url_regex = re.compile(
r'^(?:http)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
def url_is_malformed(url):
return re.match(url_regex, url) is None
def print_progress(prefix, start_time, urls_counter,
domain_blacklist_counter,
extention_blacklist_counter,
short_url_counter, malformed_url_counter,
duplicate_url_counter):
string = prefix + ' | '
string += 'time elapsed (s): {:.2f} | '.format(time.time() - start_time)
string += 'number of urls: {} | '.format(urls_counter)
string += 'domain blacklisted: {} | '.format(domain_blacklist_counter)
string += 'extention blacklisted: {} | '.format(extention_blacklist_counter)
string += 'short urls (<=8): {} | '.format(short_url_counter)
string += 'malformed urls: {} | '.format(malformed_url_counter)
string += 'duplicate urls: {}'.format(duplicate_url_counter)
print(string, flush=True)
if __name__ == '__main__':
print('remove blacklisted urls ..')
# Path to the url files.
path = sys.argv[1]
# Output url file.
output = sys.argv[2]
# Get the list of url files.
files = glob.glob(path + '/*.txt')
print('> found {} files'.format(len(files)))
urls = set()
urls_counter = 0
domain_blacklist_counter = 0
extention_blacklist_counter = 0
short_url_counter = 0
malformed_url_counter = 0
duplicate_url_counter = 0
start_time = time.time()
for filename in files:
with open(filename, 'r') as f:
for line in f:
url = line.strip()
urls_counter += 1
if domain_is_in_blacklist(url):
print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
domain_blacklist_counter += 1
elif extention_is_in_blacklist(url):
print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True)
extention_blacklist_counter += 1
elif len(url) <= 8:
print('[SHORT URL]: {}'.format(url), flush=True)
short_url_counter += 1
elif url_is_malformed(url):
print('[MALFORMED URL]: {}'.format(url), flush=True)
malformed_url_counter += 1
elif url in urls:
print('[DUPLICATE URL]: {}'.format(url), flush=True)
duplicate_url_counter += 1
else:
urls.add(url)
if urls_counter % 100000 == 0:
print_progress('PROGRESS', start_time, urls_counter,
domain_blacklist_counter,
extention_blacklist_counter,
short_url_counter, malformed_url_counter,
duplicate_url_counter)
print_progress('FINAL', start_time, urls_counter,
domain_blacklist_counter,
extention_blacklist_counter,
short_url_counter, malformed_url_counter,
duplicate_url_counter)
# Write the final set of urls.
print('> writing cleaned up url list to {}'.format(output))
with open(output, 'w') as f:
for url in urls:
f.write(url + '\n')
print('done :-)')
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ftfy
import json
from langdetect import detect
import numpy as np
import time
import os
import sys
from tokenizer import Tokenizer
MIN_DOCUMENT_LENGHT = 128
def print_progress(prefix, start_time, num_docs, num_fixed_text,
num_non_english_docs, chars_non_english_docs,
num_small_docs, chars_small_docs):
string = prefix + ' | '
string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
string += 'documents: {} | '.format(num_docs)
string += 'fixed text: {} | '.format(num_fixed_text)
string += 'non-english: {} | '.format(num_non_english_docs)
string += 'non-english chars: {} | '.format(chars_non_english_docs)
string += 'small docs: {} | '.format(num_small_docs)
string += 'small docs chars: {}'.format(chars_small_docs)
print(string, flush=True)
def filter_corpus(filename, out_filename, print_interval=10000):
print(' > filtering {}'.format(filename))
tokenizer = Tokenizer(cache_dir='./cache')
num_docs = 0
num_written_docs = 0
num_small_docs = 0
num_fixed_text = 0
num_non_english_docs = 0
chars_non_english_docs = 0
chars_small_docs = 0
start_time = time.time()
with open(out_filename, 'wb') as f:
with open(filename, 'r') as fin:
for line in fin:
try:
num_docs += 1
myjson = json.loads(line)
# Fix text
text = ftfy.fix_text(myjson['text'])
if text != myjson['text']:
num_fixed_text += 1
myjson['text'] = text
# Detect language.
if detect(text) != 'en':
print('[non-english text]', myjson)
num_non_english_docs += 1
chars_non_english_docs += len(text)
continue
# On average each token is 5 characters so 8 is an
# upper bound.
if len(text) < (8 * MIN_DOCUMENT_LENGHT):
tokens = tokenizer.tokenize_document(text)
if len(tokens) < MIN_DOCUMENT_LENGHT:
print('[small document, skipping]:', myjson)
num_small_docs += 1
chars_small_docs += len(text)
continue
myjson = json.dumps(myjson, ensure_ascii=False)
f.write(myjson.encode('utf-8'))
f.write('\n'.encode('utf-8'))
num_written_docs += 1
if num_docs % print_interval == 0:
print_progress('[PROGRESS]', start_time, num_docs,
num_fixed_text, num_non_english_docs,
chars_non_english_docs,
num_small_docs, chars_small_docs)
except Exception as e:
print(' skipping ', line, e)
print_progress('[FINAL]', start_time, num_docs,
num_fixed_text, num_non_english_docs,
chars_non_english_docs,
num_small_docs, chars_small_docs)
if __name__ == '__main__':
print('building gpt2 dataset ...')
input_filename = sys.argv[1]
output_filename = sys.argv[2]
print('will be reading {}'.format(input_filename))
print('and will write the results to {}'.format(output_filename))
filter_corpus(input_filename, output_filename)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import json
from lsh import cache, minhash
import time
import sys
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
return set(text[head:head + char_ngram]
for head in range(0, len(text) - char_ngram))
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b):
intersection = set_a & set_b
union = set_a | set_b
return len(intersection) / len(union)
if __name__ == '__main__':
print('finding possible duplicate content ...')
input = sys.argv[1]
output = sys.argv[2]
hasher = minhash.MinHasher(seeds=100, char_ngram=5, hashbytes=4)
lshcache = cache.Cache(bands=10, hasher=hasher)
counter = 0
url_doc = {}
start_time = time.time()
with open(input, 'r') as f:
for line in f:
try:
myjson = json.loads(line)
url = myjson['url']
text = myjson['text']
counter += 1
url_doc[url] = text
lshcache.add_fingerprint(hasher.fingerprint(text), url)
except Exception as e:
print('Error:', e)
if counter % 10000 == 0:
print(' [read]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
counter = 0
start_time = time.time()
deduped = 0
with open(output, 'wb') as f:
for b in lshcache.bins:
for bucket_id in b:
if len(b[bucket_id]) > 1:
items = list(b[bucket_id])
main_url = items[0]
main_dhingles = shingles(url_doc[main_url])
remove_urls = []
for i in range(1, len(items)):
counter += 1
other_url= items[i]
other_shingles = shingles(url_doc[other_url])
try:
jaccard_sim = jaccard(main_dhingles, other_shingles)
except Exception as e:
print('Error:', e)
if jaccard_sim > 0.5:
remove_urls.append({other_url: jaccard_sim})
deduped += 1
if counter % 10000 == 0:
print(' [write]> processed {} documents in {:.2f} '
'seoncds and deduped {} documents ...'.
format(counter, time.time() - start_time,
deduped), flush=True)
if len(remove_urls) > 0:
myjson = json.dumps({main_url: remove_urls},
ensure_ascii=False)
f.write(myjson.encode('utf-8'))
f.write('\n'.encode('utf-8'))
print('done :-)')
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
import sys
def is_similar(jaccard_similarity):
return (js >= 0.9)
if __name__ == '__main__':
print('grouping duplicate urls ...')
input = sys.argv[1]
output = sys.argv[2]
url_to_index = {}
index_to_urls = []
counter = 0
start_time = time.time()
with open(input, 'r') as f:
for line in f:
counter += 1
myjson = json.loads(line)
urls = []
for main_url in myjson.keys():
urls.append(main_url)
for value in myjson[main_url]:
for other_url, js in value.items():
if is_similar(js):
urls.append(other_url)
current_index = -1
other_indices = set()
for url in urls:
if url in url_to_index:
if current_index == -1:
current_index = url_to_index[url]
elif current_index != url_to_index[url]:
other_indices.add(url_to_index[url])
if current_index == -1:
current_index = len(index_to_urls)
index_to_urls.append(set())
for url in urls:
url_to_index[url] = current_index
index_to_urls[current_index].add(url)
for index in other_indices:
for url in index_to_urls[index]:
index_to_urls[current_index].add(url)
url_to_index[url] = current_index
index_to_urls[index] = None
if counter % 100000 == 0:
print(' > processed {} lines in {} seconds ...'.format(
counter, time.time() - start_time))
total_remove = 0
total_remain = 0
for urls in index_to_urls:
if urls is not None:
if len(urls) > 1:
total_remove += (len(urls) - 1)
total_remain += 1
print('out of {} urls, only {} are unique and {} should be removed'.format(
total_remove+total_remain, total_remain, total_remove))
with open(output, 'wb') as f:
for i, urls in enumerate(index_to_urls):
if urls is not None:
if len(urls) > 1:
myjson = json.dumps({str(i): list(urls)},
ensure_ascii=False)
f.write(myjson.encode('utf-8'))
f.write('\n'.encode('utf-8'))
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
import time
import os
import sys
from tokenizer import Tokenizer
def tokenize_corpus(filename, np_filename, print_interval=10000):
print(' > tokenizing {}'.format(filename))
tokenizer = Tokenizer(cache_dir='./cache')
tokenized_docs = []
num_docs = 0
num_tokens = 0
start_time = time.time()
with open(filename, 'r') as f:
for line in f:
try:
myjson = json.loads(line)
url = myjson['url']
sample = myjson['text']
tokens = tokenizer.tokenize_document(sample)
tokenized_docs.append(np.array(tokens, dtype=np.uint16))
num_docs += 1
num_tokens += len(tokens)
if num_docs % print_interval == 0:
print(' processed {:9d} documents in {:.2f} (s) so far'.
format(num_docs, time.time() - start_time),
flush=True)
except Exception as e:
print(' skipping ', line, e)
print(' >> processed {} document with total of {} tokens ...'.format(
num_docs, num_tokens))
tokenized_docs = np.array(tokenized_docs, dtype=object)
np.save(np_filename, tokenized_docs, allow_pickle=True)
print(' >> saved the tokenzed document to {} ...'.format(np_filename))
if __name__ == '__main__':
print('building gpt2 dataset ...')
path = sys.argv[1]
shard = sys.argv[2]
input_filename = os.path.join(path,
'shards/shard_{:04d}'.format(int(shard)))
output_filename = os.path.join(path,
'npys/shard_{:04d}.npy'.format(int(shard)))
print('will be reading {}'.format(input_filename))
print('and will write the results to {}'.format(output_filename))
tokenize_corpus(input_filename, output_filename)
import glob
import json
import os
import time
import sys
import numpy as np
if __name__ == '__main__':
print('building the shard sizes ...')
path = sys.argv[1]
print('> reading numpy files from {}'.format(path))
npy_files = glob.glob(path + '/*.npy')
npy_files.sort()
print(' found {} numpy files'.format(len(npy_files)))
size_dict = {}
counter = 0
start_time = time.time()
for filename in npy_files:
data = np.load(filename, allow_pickle=True)
size = np.hstack(data).size
np_filename = os.path.basename(filename)
size_dict[np_filename] = size
counter += 1
if counter % 10 == 0:
print(' processed {} files in {:.2f} seconds'.format(
counter, time.time() - start_time))
output_filename = os.path.join(path, 'sizes.txt')
with open(output_filename, 'w') as f:
json.dump(size_dict, f)
print('> wrote sizes to {}'.format(output_filename))
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import sys
import json
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--json_path", type=str, default=".",
help="path where all the json files are located")
parser.add_argument("--output_file", type=str, default="merged_output.json",
help="filename where the merged json should go")
args = parser.parse_args()
json_path = args.json_path
out_file = args.output_file
json_files = glob.glob(json_path + '/*.json')
counter = 0
with open(out_file, 'w') as outfile:
for fname in json_files:
counter += 1
if counter % 1024 == 0:
print("Merging at ", counter, flush=True)
with open(fname, 'r') as infile:
for row in infile:
each_row = json.loads(row)
outfile.write(row)
print("Merged file", out_file, flush=True)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
import sys
if __name__ == '__main__':
url_filename = sys.argv[1]
data_filename = sys.argv[2]
output_filename = sys.argv[3]
urls = set()
with open(url_filename, 'r') as f:
for line in f:
myjson = json.loads(line)
for key in myjson:
this_urls = myjson[key]
for i in range(1, len(this_urls)):
urls.add(this_urls[i])
print('will be removing {} urls'.format(len(urls)), flush=True)
written_docs = 0
removed_docs = 0
removed_chars = 0
start_time = time.time()
with open(output_filename, 'wb') as fout:
with open(data_filename, 'r') as fin:
for line in fin:
try:
myjson = json.loads(line)
url = myjson['url']
if url in urls:
print('removing', myjson)
removed_docs += 1
removed_chars += len(myjson['text'])
continue
myjson = json.dumps(myjson, ensure_ascii=False)
fout.write(myjson.encode('utf-8'))
fout.write('\n'.encode('utf-8'))
written_docs += 1
if written_docs % 10000 == 0:
print(' [PROCESSED] time (s): {:.2f} | written: {} '
'| removed: {} (char: {})'.format(
time.time() - start_time,
written_docs, removed_docs, removed_chars))
except Exception as e:
print('[SKIPPING]', line, e)
print(' [PROCESSED] time (s): {:.2f} | written: {} '
'| removed: {} (char: {})'.format(
time.time() - start_time,
written_docs, removed_docs, removed_chars))
print('done :-)')
#!/bin/bash
echo "processing gpt2 data ..."
DIR="/raid/mpatwary/redownload_v0/0-21"
for thread in {0..3}; do
echo " launching thread "$thread && python make_gpt2_dataset.py $DIR $thread > $DIR/logs/shard_$thread.log 2>&1 &
done
......@@ -13,14 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim import SGD
from torch.optim import Adadelta
from torch.optim import Adagrad
from torch.optim import SparseAdam
from torch.optim import Adamax
from torch.optim import SGD
from torch.optim import Rprop
from torch.optim import RMSprop
from torch.optim import Optimizer
from torch.optim import LBFGS
from .adam import Adam
import sys
sys.path.append('..')
from data_utils.tokenization_gpt2 import GPT2Tokenizer
class Tokenizer:
def __init__(self, cache_dir=None):
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
cache_dir=cache_dir)
self.tokenizer.max_len = int(1e12)
self.eod_token = self.tokenizer.encoder['<|endoftext|>']
assert self.eod_token < 65535, 'vocab size will not fit in uint16'
print('> GPT2 tokenizer with {} vocab size and eod token {} ...'.format(
len(self.tokenizer.encoder), self.eod_token))
def tokenize_document(self, document):
tokens = self.tokenizer.encode(document)
tokens.append(self.eod_token)
return tokens
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch.optim import Optimizer
class Adam(Optimizer):
r"""Implements Adam algorithm.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(Adam, self).__init__(params, defaults)
def __setstate__(self, state):
super(Adam, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr']# * math.sqrt(bias_correction2) / bias_correction1
if group['weight_decay'] != 0:
p.data.add_(-step_size * group['weight_decay'], p.data)
p.data.addcdiv_(-step_size, exp_avg, denom)
return loss
......@@ -15,10 +15,15 @@
"""Pretrain BERT"""
# Flag to use Pytorch ddp which uses overlapping communication and computation.
USE_TORCH_DDP = False
from datetime import datetime
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
from arguments import get_args
from configure_data import configure_data
......@@ -27,20 +32,32 @@ from fp16 import FP16_Optimizer
from learning_rates import AnnealingLR
from model import BertModel
from model import get_params_for_weight_decay_optimization
from model import DistributedDataParallel as DDP
from optim import Adam
from model import gpt2_get_params_for_weight_decay_optimization
if USE_TORCH_DDP:
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
else:
from model import DistributedDataParallel as DDP
import mpu
from apex.optimizers import FusedAdam as Adam
from utils import Timers
from utils import save_checkpoint
from utils import load_checkpoint
from utils import report_memory
from utils import print_args
from utils import print_params_min_max_norm
from utils import print_rank_0
def get_model(tokenizer, args):
def get_model(args):
"""Build the model."""
print('building BERT model ...')
model = BertModel(tokenizer, args)
print(' > number of parameters: {}'.format(
sum([p.nelement() for p in model.parameters()])), flush=True)
print_rank_0('building BERT model ...')
model = BertModel(args)
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation.
model.cuda(torch.cuda.current_device())
......@@ -60,7 +77,11 @@ def get_model(tokenizer, args):
_module.float()
# Wrap model for distributed training.
if args.world_size > 1:
if USE_TORCH_DDP:
i = torch.cuda.current_device()
model = DDP(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
else:
model = DDP(model)
return model
......@@ -86,6 +107,12 @@ def get_optimizer(model, args):
lmheads.transform))
param_groups[1]['params'].append(lmheads.bias)
# Add model parallel attribute if it is not set.
for param_group in param_groups:
for param in param_group['params']:
if not hasattr(param, 'model_parallel'):
param.model_parallel = False
# Use Adam.
optimizer = Adam(param_groups,
lr=args.lr, weight_decay=args.weight_decay)
......@@ -110,7 +137,7 @@ def get_learning_rate_scheduler(optimizer, args):
if args.lr_decay_iters is not None:
num_iters = args.lr_decay_iters
else:
num_iters = args.train_iters * args.epochs
num_iters = args.train_iters
init_step = -1
warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR(optimizer,
......@@ -123,26 +150,22 @@ def get_learning_rate_scheduler(optimizer, args):
return lr_scheduler
def setup_model_and_optimizer(args, tokenizer):
def setup_model_and_optimizer(args):
"""Setup model and optimizer."""
model = get_model(tokenizer, args)
model = get_model(args)
optimizer = get_optimizer(model, args)
lr_scheduler = get_learning_rate_scheduler(optimizer, args)
criterion = torch.nn.CrossEntropyLoss(reduce=False, ignore_index=-1)
if args.load is not None:
epoch, i, total_iters = load_checkpoint(model, optimizer,
lr_scheduler, args)
if args.resume_dataloader:
args.epoch = epoch
args.mid_epoch_iters = i
args.total_iters = total_iters
args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
else:
args.iteration = 0
return model, optimizer, lr_scheduler, criterion
return model, optimizer, lr_scheduler
def get_batch(data):
def get_batch(data_iterator, timers):
''' get_batch subdivides the source data into chunks of
length args.seq_length. If source is equal to the example
output of the data loading example, with a seq_length limit
......@@ -155,40 +178,52 @@ def get_batch(data):
to the seq_len dimension in the LSTM. A Variable representing an appropriate
shard reset mask of the same dimensions is also returned.
'''
tokens = torch.autograd.Variable(data['text'].long())
types = torch.autograd.Variable(data['types'].long())
next_sentence = torch.autograd.Variable(data['is_random'].long())
loss_mask = torch.autograd.Variable(data['mask'].float())
lm_labels = torch.autograd.Variable(data['mask_labels'].long())
padding_mask = torch.autograd.Variable(data['pad_mask'].byte())
# Move to cuda
tokens = tokens.cuda()
types = types.cuda()
next_sentence = next_sentence.cuda()
loss_mask = loss_mask.cuda()
lm_labels = lm_labels.cuda()
padding_mask = padding_mask.cuda()
# Items and their type.
keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']
datatype = torch.int64
# Broadcast data.
timers('data loader').start()
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['text'].long()
types = data_b['types'].long()
next_sentence = data_b['is_random'].long()
loss_mask = data_b['mask'].float()
lm_labels = data_b['mask_labels'].long()
padding_mask = data_b['pad_mask'].byte()
return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask
def forward_step(data, model, criterion, args):
def forward_step(data_iterator, model, args, timers):
"""Forward step."""
# Get the batch.
timers('batch generator').start()
tokens, types, next_sentence, loss_mask, lm_labels, \
padding_mask = get_batch(data)
padding_mask = get_batch(data_iterator, timers)
timers('batch generator').stop()
# Forward model.
output, nsp = model(tokens, types, 1-padding_mask,
checkpoint_activations=args.checkpoint_activations)
nsp_loss = criterion(nsp.view(-1, 2).contiguous().float(),
next_sentence.view(-1).contiguous()).mean()
losses = criterion(output.view(-1, args.data_size).contiguous().float(),
lm_labels.contiguous().view(-1).contiguous())
nsp_loss = F.cross_entropy(nsp.view(-1, 2).contiguous().float(),
next_sentence.view(-1).contiguous(),
ignore_index=-1)
losses = mpu.vocab_parallel_cross_entropy(
output.contiguous().float(), lm_labels.contiguous())
loss_mask = loss_mask.contiguous()
loss_mask = loss_mask.view(-1)
lm_loss = torch.sum(
losses * loss_mask.view(-1).float()) / loss_mask.sum()
losses.view(-1) * loss_mask.view(-1).float()) / loss_mask.sum()
return lm_loss, nsp_loss
......@@ -209,14 +244,15 @@ def backward_step(optimizer, model, lm_loss, nsp_loss, args):
# Reduce across processes.
lm_loss_reduced = lm_loss
nsp_loss_reduced = nsp_loss
if args.world_size > 1:
reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
torch.distributed.all_reduce(reduced_losses.data)
reduced_losses.data = reduced_losses.data / args.world_size
reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
torch.distributed.all_reduce(reduced_losses.data)
reduced_losses.data = reduced_losses.data / args.world_size
if not USE_TORCH_DDP:
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
lm_loss_reduced = reduced_losses[0]
nsp_loss_reduced = reduced_losses[1]
lm_loss_reduced = reduced_losses[0]
nsp_loss_reduced = reduced_losses[1]
# Update master gradients.
if args.fp16:
......@@ -225,25 +261,33 @@ def backward_step(optimizer, model, lm_loss, nsp_loss, args):
# Clipping gradients helps prevent the exploding gradient.
if args.clip_grad > 0:
if not args.fp16:
torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)
mpu.clip_grad_norm(model.parameters(), args.clip_grad)
else:
optimizer.clip_master_grads(args.clip_grad)
return lm_loss_reduced, nsp_loss_reduced
def train_step(input_data, model, criterion, optimizer, lr_scheduler, args):
def train_step(data_iterator, model, optimizer, lr_scheduler,
args, timers):
"""Single training step."""
# Forward model for one step.
lm_loss, nsp_loss = forward_step(input_data, model, criterion, args)
timers('forward').start()
lm_loss, nsp_loss = forward_step(data_iterator, model,
args, timers)
timers('forward').stop()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
lm_loss_reduced, nsp_loss_reduced = backward_step(optimizer, model, lm_loss,
nsp_loss, args)
timers('backward').stop()
# Update parameters.
timers('optimizer').start()
optimizer.step()
timers('optimizer').stop()
# Update learning rate.
skipped_iter = 0
......@@ -255,9 +299,9 @@ def train_step(input_data, model, criterion, optimizer, lr_scheduler, args):
return lm_loss_reduced, nsp_loss_reduced, skipped_iter
def train_epoch(epoch, model, optimizer, train_data,
lr_scheduler, criterion, timers, args):
"""Train one full epoch."""
def train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args):
"""Train the model."""
# Turn on training mode which enables dropout.
model.train()
......@@ -267,25 +311,18 @@ def train_epoch(epoch, model, optimizer, train_data,
total_nsp_loss = 0.0
# Iterations.
max_iters = args.train_iters
iteration = 0
iteration = args.iteration
skipped_iters = 0
if args.resume_dataloader:
iteration = args.mid_epoch_iters
args.resume_dataloader = False
# Data iterator.
data_iterator = iter(train_data)
timers('interval time').start()
while iteration < max_iters:
report_memory_flag = True
while iteration < args.train_iters:
lm_loss, nsp_loss, skipped_iter = train_step(next(data_iterator),
lm_loss, nsp_loss, skipped_iter = train_step(train_data_iterator,
model,
criterion,
optimizer,
lr_scheduler,
args)
args, timers)
skipped_iters += skipped_iter
iteration += 1
......@@ -299,32 +336,47 @@ def train_epoch(epoch, model, optimizer, train_data,
avg_nsp_loss = total_nsp_loss.item() / args.log_interval
avg_lm_loss = total_lm_loss.item() / args.log_interval
elapsed_time = timers('interval time').elapsed()
log_string = ' epoch{:2d} |'.format(epoch)
log_string += ' iteration {:8d}/{:8d} |'.format(iteration,
max_iters)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval)
log_string += ' learning rate {:.3E} |'.format(learning_rate)
log_string += ' lm loss {:.3E} |'.format(avg_lm_loss)
log_string += ' nsp loss {:.3E} |'.format(avg_nsp_loss)
log_string += ' lm loss {:.6E} |'.format(avg_lm_loss)
log_string += ' nsp loss {:.6E} |'.format(avg_nsp_loss)
if args.fp16:
log_string += ' loss scale {:.1f} |'.format(
optimizer.loss_scale)
print(log_string, flush=True)
print_rank_0(log_string)
total_nsp_loss = 0.0
total_lm_loss = 0.0
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
timers.log(['forward', 'backward', 'optimizer', 'batch generator',
'data loader'],
normalizer=args.log_interval)
# Checkpointing
if args.save and args.save_iters and iteration % args.save_iters == 0:
total_iters = args.train_iters * (epoch-1) + iteration
model_suffix = 'model/%d.pt' % (total_iters)
save_checkpoint(model_suffix, epoch, iteration, model, optimizer,
lr_scheduler, args)
if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(
prefix, val_data_iterator, model, args, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print('rank: {} | time: {} | exiting the program at iteration {}'.
format(rank, time_str, iteration), flush=True)
exit()
return iteration, skipped_iters
def evaluate(data_source, model, criterion, args):
def evaluate(data_iterator, model, args, timers, verbose = False):
"""Evaluation."""
# Turn on evaluation mode which disables dropout.
......@@ -332,15 +384,16 @@ def evaluate(data_source, model, criterion, args):
total_lm_loss = 0
total_nsp_loss = 0
max_iters = args.eval_iters
with torch.no_grad():
data_iterator = iter(data_source)
iteration = 0
while iteration < max_iters:
while iteration < args.eval_iters:
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))
# Forward evaluation.
lm_loss, nsp_loss = forward_step(next(data_iterator), model,
criterion, args)
lm_loss, nsp_loss = forward_step(data_iterator, model,
args, timers)
# Reduce across processes.
if isinstance(model, DDP):
reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
......@@ -351,16 +404,34 @@ def evaluate(data_source, model, criterion, args):
total_lm_loss += lm_loss.data.detach().float().item()
total_nsp_loss += nsp_loss.data.detach().float().item()
iteration += 1
# Move model back to the train mode.
model.train()
total_lm_loss /= max_iters
total_nsp_loss /= max_iters
total_lm_loss /= args.eval_iters
total_nsp_loss /= args.eval_iters
return total_lm_loss, total_nsp_loss
def evaluate_and_print_results(prefix, data_iterator, model,
args, timers, verbose=False):
"""Helper function to evaluate and dump results on screen."""
lm_loss, nsp_loss = evaluate(data_iterator, model,
args, timers, verbose)
val_loss = lm_loss + nsp_loss
print_rank_0('-' * 100)
string = ' validation loss at {} | '.format(prefix)
string += 'LM loss: {:.6E} | '.format(lm_loss)
string += 'NSP loss: {:.6E} | '.format(nsp_loss)
string += 'total loss: {:.6E}'.format(val_loss)
length = len(string) + 1
print_rank_0('-' * length)
print_rank_0(string)
print_rank_0('-' * length)
return val_loss
def initialize_distributed(args):
"""Initialize torch.distributed."""
......@@ -370,15 +441,17 @@ def initialize_distributed(args):
device = args.local_rank
torch.cuda.set_device(device)
# Call the init process
if args.world_size > 1:
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
def set_random_seed(seed):
......@@ -388,14 +461,51 @@ def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_train_val_test_data(args):
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
(train_data, val_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
data_config = configure_data()
data_config.set_defaults(data_set_type='BERT', transpose=False)
(train_data, val_data, test_data), tokenizer = data_config.apply(args)
before = tokenizer.num_tokens
after = before
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
while (after % multiple) != 0:
after += 1
print_rank_0('> padded vocab (size: {}) with {} dummy '
'tokens (new size: {})'.format(
before, after - before, after))
# Need to broadcast num_tokens and num_type_tokens.
token_counts = torch.cuda.LongTensor([after,
tokenizer.num_type_tokens,
int(args.do_train), int(args.do_valid), int(args.do_test)])
else:
token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(token_counts,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
num_tokens = token_counts[0].item()
num_type_tokens = token_counts[1].item()
args.do_train = token_counts[2].item()
args.do_valid = token_counts[3].item()
args.do_test = token_counts[4].item()
return train_data, val_data, test_data, num_tokens, num_type_tokens
def main():
"""Main training program."""
print('Pretrain BERT model')
# Disable CuDNN.
torch.backends.cudnn.enabled = False
......@@ -407,85 +517,65 @@ def main():
# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print('Pretrain BERT model')
print_args(args)
# Random seeds for reproducability.
set_random_seed(args.seed)
# Data stuff.
data_config = configure_data()
data_config.set_defaults(data_set_type='BERT', transpose=False)
(train_data, val_data, test_data), tokenizer = data_config.apply(args)
args.data_size = tokenizer.num_tokens
train_data, val_data, test_data, args.tokenizer_num_tokens, \
args.tokenizer_num_type_tokens = get_train_val_test_data(args)
# Model, optimizer, and learning rate.
model, optimizer, lr_scheduler, criterion = setup_model_and_optimizer(
args, tokenizer)
# At any point you can hit Ctrl + C to break out of training early.
try:
total_iters = 0
skipped_iters = 0
start_epoch = 1
best_val_loss = float('inf')
# Resume data loader if necessary.
if args.resume_dataloader:
start_epoch = args.epoch
total_iters = args.total_iters
train_data.batch_sampler.start_iter = total_iters % len(train_data)
# For all epochs.
for epoch in range(start_epoch, args.epochs+1):
if args.shuffle:
train_data.batch_sampler.sampler.set_epoch(epoch+args.seed)
timers('epoch time').start()
iteration, skipped = train_epoch(epoch, model, optimizer,
train_data, lr_scheduler,
criterion, timers, args)
elapsed_time = timers('epoch time').elapsed()
total_iters += iteration
skipped_iters += skipped
lm_loss, nsp_loss = evaluate(val_data, model, criterion, args)
val_loss = lm_loss + nsp_loss
print('-' * 100)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:.4E} | '
'valid LM Loss {:.4E} | valid NSP Loss {:.4E}'.format(
epoch, elapsed_time, val_loss, lm_loss, nsp_loss))
print('-' * 100)
if val_loss < best_val_loss:
best_val_loss = val_loss
if args.save:
best_path = 'best/model.pt'
print('saving best model to:',
os.path.join(args.save, best_path))
save_checkpoint(best_path, epoch+1, total_iters, model,
optimizer, lr_scheduler, args)
except KeyboardInterrupt:
print('-' * 100)
print('Exiting from training early')
if args.save:
cur_path = 'current/model.pt'
print('saving current model to:',
os.path.join(args.save, cur_path))
save_checkpoint(cur_path, epoch, total_iters, model, optimizer,
lr_scheduler, args)
exit()
if args.save:
final_path = 'final/model.pt'
print('saving final model to:', os.path.join(args.save, final_path))
save_checkpoint(final_path, args.epochs, total_iters, model, optimizer,
lr_scheduler, args)
model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
if args.resume_dataloader:
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \
len(train_data)
if val_data is not None:
start_iter_val = (args.train_iters // args.save_interval) * \
args.eval_interval
val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data)
if train_data is not None:
train_data_iterator = iter(train_data)
else:
train_data_iterator = None
if val_data is not None:
val_data_iterator = iter(val_data)
else:
val_data_iterator = None
iteration = 0
if args.train_iters > 0:
if args.do_train:
iteration, skipped = train(model, optimizer,
lr_scheduler,
train_data_iterator,
val_data_iterator,
timers, args)
if args.do_valid:
prefix = 'the end of training for val data'
val_loss = evaluate_and_print_results(prefix, val_data_iterator,
model, args, timers, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
if test_data is not None:
test_data_iterator = iter(test_data)
else:
test_data_iterator = None
if args.do_test:
# Run on test data.
print('entering test')
lm_loss, nsp_loss = evaluate(test_data, model, criterion, args)
test_loss = lm_loss + nsp_loss
print('=' * 100)
print('| End of training | test loss {:5.4f} | valid LM Loss {:.4E} |'
' valid NSP Loss {:.4E}'.format(test_loss, lm_loss, nsp_loss))
print('=' * 100)
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, test_data_iterator,
model, args, timers, True)
if __name__ == "__main__":
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT2"""
# Flag to use Pytorch ddp which uses overlapping communication and computation.
USE_TORCH_DDP = False
from datetime import datetime
import os
import random
import math
import numpy as np
import torch
from arguments import get_args
from configure_data import configure_data
from fp16 import FP16_Module
from fp16 import FP16_Optimizer
from learning_rates import AnnealingLR
from model import GPT2Model
from model import gpt2_get_params_for_weight_decay_optimization
if USE_TORCH_DDP:
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
else:
from model import DistributedDataParallel as DDP
import mpu
from apex.optimizers import FusedAdam as Adam
from utils import Timers
from utils import save_checkpoint
from utils import load_checkpoint
from utils import report_memory
from utils import print_args
from utils import print_params_min_max_norm
from utils import print_rank_0
from gpt2_data_loader import make_gpt2_dataloaders
def get_model(args):
"""Build the model."""
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_layers=args.num_layers,
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
attention_dropout_prob=args.attention_dropout,
output_dropout_prob=args.hidden_dropout,
max_sequence_length=args.max_position_embeddings,
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
parallel_output=True)
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation.
model.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16:
model = FP16_Module(model)
# Wrap model for distributed training.
if USE_TORCH_DDP:
i = torch.cuda.current_device()
model = DDP(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
else:
model = DDP(model)
return model
def get_optimizer(model, args):
"""Set up the optimizer."""
# Build parameter groups (weight decay and non-decay).
while isinstance(model, (DDP, FP16_Module)):
model = model.module
param_groups = gpt2_get_params_for_weight_decay_optimization(model)
# Add model parallel attribute if it is not set.
for param_group in param_groups:
for param in param_group['params']:
if not hasattr(param, 'model_parallel'):
param.model_parallel = False
# Use Adam.
optimizer = Adam(param_groups,
lr=args.lr, weight_decay=args.weight_decay)
# Wrap into fp16 optimizer.
if args.fp16:
optimizer = FP16_Optimizer(optimizer,
static_loss_scale=args.loss_scale,
dynamic_loss_scale=args.dynamic_loss_scale,
dynamic_loss_args={
'scale_window': args.loss_scale_window,
'min_scale':args.min_scale,
'delayed_shift': args.hysteresis})
return optimizer
def get_learning_rate_scheduler(optimizer, args):
"""Build the learning rate scheduler."""
# Add linear learning rate scheduler.
if args.lr_decay_iters is not None:
num_iters = args.lr_decay_iters
else:
num_iters = args.train_iters
num_iters = max(1, num_iters)
init_step = -1
warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR(optimizer,
start_lr=args.lr,
warmup_iter=warmup_iter,
num_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step)
return lr_scheduler
def setup_model_and_optimizer(args):
"""Setup model and optimizer."""
model = get_model(args)
optimizer = get_optimizer(model, args)
lr_scheduler = get_learning_rate_scheduler(optimizer, args)
if args.load is not None:
args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
else:
args.iteration = 0
return model, optimizer, lr_scheduler
def get_masks_and_position_ids(data,
eod_token,
reset_position_ids,
reset_attention_mask):
# Extract batch size and sequence length.
batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
(att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modifed based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indecies:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i+1):, :(i+1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i+1):] -= (i + 1 - prev_index)
prev_index = i + 1
return attention_mask, loss_mask, position_ids
def get_batch(data_iterator, args, timers):
''' get_batch subdivides the source data into chunks of
length args.seq_length. If source is equal to the example
output of the data loading example, with a seq_length limit
of 2, we'd get the following two Variables for i = 0:
┌ a g m s ┐ ┌ b h n t ┐
└ b h n t ┘ └ c i o u ┘
Note that despite the name of the function, the subdivison of data is not
done along the batch dimension (i.e. dimension 1), since that was handled
by the data loader. The chunks are along dimension 0, corresponding
to the seq_len dimension in the LSTM. A Variable representing an appropriate
shard reset mask of the same dimensions is also returned.
'''
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
timers('data loader').start()
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
tokens,
args.eod_token,
args.reset_position_ids,
args.reset_attention_mask)
# Convert
if args.fp16:
attention_mask = attention_mask.half()
return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model, args, timers):
"""Forward step."""
# Get the batch.
timers('batch generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator, args, timers)
timers('batch generator').stop()
# Forward model.
output = model(tokens, position_ids, attention_mask)
losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
labels)
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
return loss
def backward_step(optimizer, model, lm_loss, args, timers):
"""Backward step."""
# Total loss.
loss = lm_loss
# Backward pass.
optimizer.zero_grad()
if args.fp16:
optimizer.backward(loss, update_master_grads=False)
else:
loss.backward()
# Reduce across processes.
lm_loss_reduced = lm_loss
reduced_losses = lm_loss.view(1)
torch.distributed.all_reduce(reduced_losses.data)
reduced_losses.data = reduced_losses.data / args.world_size
if not USE_TORCH_DDP:
timers('allreduce').start()
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
timers('allreduce').stop()
lm_loss_reduced = reduced_losses
# Update master gradients.
if args.fp16:
optimizer.update_master_grads()
# Clipping gradients helps prevent the exploding gradient.
if args.clip_grad > 0:
if not args.fp16:
mpu.clip_grad_norm(model.parameters(), args.clip_grad)
else:
optimizer.clip_master_grads(args.clip_grad)
return lm_loss_reduced
def train_step(data_iterator, model, optimizer, lr_scheduler,
args, timers):
"""Single training step."""
# Forward model for one step.
timers('forward').start()
lm_loss = forward_step(data_iterator, model, args, timers)
timers('forward').stop()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
lm_loss_reduced = backward_step(optimizer, model, lm_loss, args, timers)
timers('backward').stop()
# Update parameters.
timers('optimizer').start()
optimizer.step()
timers('optimizer').stop()
# Update learning rate.
skipped_iter = 0
if not (args.fp16 and optimizer.overflow):
lr_scheduler.step()
else:
skipped_iter = 1
return lm_loss_reduced, skipped_iter
def train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args):
"""Train the model."""
# Turn on training mode which enables dropout.
model.train()
# Tracking loss.
total_lm_loss = 0.0
# Iterations.
iteration = args.iteration
skipped_iters = 0
timers('interval time').start()
report_memory_flag = True
while iteration < args.train_iters:
lm_loss, skipped_iter = train_step(train_data_iterator,
model,
optimizer,
lr_scheduler,
args, timers)
skipped_iters += skipped_iter
iteration += 1
# Update losses.
total_lm_loss += lm_loss.data.detach().float()
# Logging.
if iteration % args.log_interval == 0:
learning_rate = optimizer.param_groups[0]['lr']
avg_lm_loss = total_lm_loss.item() / args.log_interval
elapsed_time = timers('interval time').elapsed()
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval)
log_string += ' learning rate {:.3E} |'.format(learning_rate)
log_string += ' lm loss {:.6E} |'.format(avg_lm_loss)
if args.fp16:
log_string += ' loss scale {:.1f} |'.format(
optimizer.loss_scale)
print_rank_0(log_string)
total_lm_loss = 0.0
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
if USE_TORCH_DDP:
timers.log(['forward', 'backward', 'optimizer',
'batch generator', 'data loader'],
normalizer=args.log_interval)
else:
timers.log(['forward', 'backward', 'allreduce', 'optimizer',
'batch generator', 'data loader'],
normalizer=args.log_interval)
# Checkpointing
if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(
prefix, val_data_iterator, model, args, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print('rank: {} | time: {} | exiting the program at iteration {}'.
format(rank, time_str, iteration), flush=True)
exit()
return iteration, skipped_iters
def evaluate(data_iterator, model, args, timers, verbose=False):
"""Evaluation."""
# Turn on evaluation mode which disables dropout.
model.eval()
total_lm_loss = 0
with torch.no_grad():
iteration = 0
while iteration < args.eval_iters:
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))
# Forward evaluation.
lm_loss = forward_step(data_iterator, model, args, timers)
# Reduce across processes.
if isinstance(model, DDP):
torch.distributed.all_reduce(lm_loss.data)
lm_loss.data = lm_loss.data / args.world_size
total_lm_loss += lm_loss.data.detach().float().item()
# Move model back to the train mode.
model.train()
total_lm_loss /= args.eval_iters
return total_lm_loss
def evaluate_and_print_results(prefix, data_iterator, model,
args, timers, verbose=False):
"""Helper function to evaluate and dump results on screen."""
lm_loss = evaluate(data_iterator, model, args, timers, verbose)
lm_ppl = math.exp(min(20, lm_loss))
print_rank_0('-' * 100)
string = ' validation loss at {} | '.format(prefix)
string += 'LM loss: {:.6E} | '.format(lm_loss)
string += 'LM PPL: {:.6E}'.format(lm_ppl)
length = len(string) + 1
print_rank_0('-' * length)
print_rank_0(string)
print_rank_0('-' * length)
return lm_loss
def initialize_distributed(args):
"""Initialize torch.distributed."""
# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
device = args.local_rank
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_train_val_test_data(args):
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
(train_data, val_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
if args.use_npy_data_loader:
(train_data, val_data, test_data), num_tokens, \
eod_token = make_gpt2_dataloaders(args)
else:
data_config = configure_data()
data_config.set_defaults(data_set_type='GPT2', transpose=False)
(train_data, val_data, test_data), tokenizer = data_config.apply(
args)
num_tokens = tokenizer.num_tokens
eod_token = tokenizer.get_command('eos').Id
assert eod_token == tokenizer.get_command('pad').Id
before = num_tokens
after = before
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
while (after % multiple) != 0:
after += 1
print_rank_0('> padded vocab (size: {}) with {} dummy '
'tokens (new size: {})'.format(
before, after - before, after))
print_rank_0('> found end-of-document token: {}'.format(eod_token))
token_counts = torch.cuda.LongTensor([after, eod_token, int(args.do_train), int(args.do_valid), int(args.do_test)])
else:
token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(token_counts,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
num_tokens = token_counts[0].item()
eod_token = token_counts[1].item()
args.do_train = token_counts[2].item()
args.do_valid = token_counts[3].item()
args.do_test = token_counts[4].item()
return train_data, val_data, test_data, num_tokens, eod_token
def main():
"""Main training program."""
# Disable CuDNN.
torch.backends.cudnn.enabled = False
# Timer.
timers = Timers()
# Arguments.
args = get_args()
# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print('Pretrain GPT2 model')
print_args(args)
# Random seeds for reproducability.
set_random_seed(args.seed)
# Data stuff.
train_data, val_data, test_data, args.vocab_size, \
args.eod_token = get_train_val_test_data(args)
# Model, optimizer, and learning rate.
model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
# Resume data loader if necessary.
if args.resume_dataloader:
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \
len(train_data)
if val_data is not None:
start_iter_val = (args.train_iters // args.save_interval) * \
args.eval_interval
val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data)
if train_data is not None:
train_data_iterator = iter(train_data)
else:
train_data_iterator = None
if val_data is not None:
val_data_iterator = iter(val_data)
else:
val_data_iterator = None
#TODO: figure out how to properly set this especially when resuming training
iteration = 0
if args.train_iters > 0:
if args.do_train:
iteration, skipped = train(model, optimizer,
lr_scheduler,
train_data_iterator,
val_data_iterator,
timers, args)
if args.do_valid:
prefix = 'the end of training for val data'
val_loss = evaluate_and_print_results(prefix, val_data_iterator,
model, args, timers, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer,
lr_scheduler, args)
if test_data is not None:
test_data_iterator = iter(test_data)
else:
test_data_iterator = None
if args.do_test:
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, test_data_iterator,
model, args, timers, True)
if __name__ == "__main__":
main()
#!/bin/bash
CHECKPOINT_PATH=/path/to/checkpoint
MPSIZE=1
NLAYERS=24
NHIDDEN=1024
NATT=16
MAXSEQLEN=1024
#SAMPLING ARGS
TEMP=0.9
#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p
TOPK=0
TOPP=0
python generate_samples.py \
--model-parallel-size $MPSIZE \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--load $CHECKPOINT_PATH \
--num-attention-heads $NATT \
--max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer \
--fp16 \
--cache-dir cache \
--out-seq-length $MAXSEQLEN \
--temperature $TEMP \
--top_k $TOPK \
--top_p $TOPP
......@@ -4,35 +4,31 @@ RANK=0
WORLD_SIZE=1
python pretrain_bert.py \
--batch-size 4 \
--tokenizer-type BertWordPieceTokenizer \
--cache-dir cache_dir \
--tokenizer-model-type bert-large-uncased \
--vocab-size 30522 \
--train-data wikipedia \
--presplit-sentences \
--loose-json \
--text-key text \
--split 1000,1,1 \
--lazy-loader \
--max-preds-per-seq 80 \
--seq-length 512 \
--max-position-embeddings 512 \
--num-layers 24 \
--hidden-size 1024 \
--intermediate-size 4096 \
--num-attention-heads 16 \
--hidden-dropout 0.1 \
--attention-dropout 0.1 \
--train-iters 1000000 \
--lr 0.0001 \
--lr-decay-style linear \
--lr-decay-iters 990000 \
--warmup .01 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--fp16 \
--fp32-layernorm \
--fp32-embedding \
--hysteresis 2 \
--num-workers 2
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 4 \
--seq-length 512 \
--max-preds-per-seq 80 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--save checkpoints/bert_345m \
--load checkpoints/bert_345m \
--resume-dataloader \
--train-data wikipedia \
--lazy-loader \
--tokenizer-type BertWordPieceTokenizer \
--tokenizer-model-type bert-large-uncased \
--presplit-sentences \
--cache-dir cache \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--lr-decay-iters 990000 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--fp16 \
--fp32-layernorm \
--fp32-embedding
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment