Commit 22c0e300 authored by Jared Casper

Merge branch 'staging' into 'master'

April 2020 Release

See merge request ADLR/megatron-lm!69
parents 70174ae3 acf8780b
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -72,7 +72,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=get_model_parallel_group())
        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits
...
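For reference, the quantity assembled in this hunk is the usual log-sum-exp cross entropy, built from pieces that were reduced across the model-parallel group. A minimal single-GPU sketch of the same identity (plain PyTorch, no model parallelism; the real kernel also subtracts the max logit first for numerical stability):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)           # [batch, vocab]
target = torch.randint(0, 10, (4,))   # [batch]

# Loss = log(sum(exp(logits))) - predicted-logit, as in the comment above.
sum_exp_logits = torch.exp(logits).sum(dim=-1)
predicted_logits = logits.gather(1, target.unsqueeze(1)).squeeze(1)
loss = torch.log(sum_exp_logits) - predicted_logits

# Matches the standard cross entropy up to numerical error.
assert torch.allclose(loss, F.cross_entropy(logits, target, reduction='none'),
                      atol=1e-5)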
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,10 +21,47 @@
import torch
from torch._six import inf
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C

from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank


+def l2_grad_clipper(parameters, max_norm):
+    """Efficient L2 norm gradient clipping."""
+    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    mp_rank_is_zero = (get_model_parallel_rank() == 0)
+    parameters = list(filter(lambda p: (p.grad is not None) and
+                             (p.model_parallel or mp_rank_is_zero),
+                             parameters))
+    norm, _ = multi_tensor_applier(
+        amp_C.multi_tensor_l2norm,
+        overflow_buf,
+        [parameters],
+        False  # no per-parameter norm
+    )
+    # Sum across all model parallel GPUs.
+    norm_2 = norm * norm
+    torch.distributed.all_reduce(norm_2,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=get_model_parallel_group())
+    total_norm = norm_2.item() ** 0.5
+    clip_coef = max_norm / (total_norm + 1e-6)
+    grads = [p.grad for p in parameters]
+    if clip_coef < 1:
+        multi_tensor_applier(
+            amp_C.multi_tensor_scale,
+            overflow_buf,
+            [grads, grads],
+            clip_coef)
+    return total_norm


def clip_grad_norm(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters.
@@ -55,6 +92,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
+        clip_coef = max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            for p in parameters:
+                p.grad.data.mul_(clip_coef)
+    elif norm_type == 2:
+        total_norm = l2_grad_clipper(parameters, max_norm)
    else:
        total_norm = 0
        for p in parameters:
@@ -67,8 +111,8 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=get_model_parallel_group())
        total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for p in parameters:
                p.grad.data.mul_(clip_coef)
    return total_norm
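The new l2_grad_clipper fuses the per-parameter norms with apex's multi_tensor_applier and sums the squared norms across model-parallel ranks before scaling. A rough single-process sketch of the same clipping rule, without apex or torch.distributed (the helper name is illustrative):

import torch

def clip_grad_l2_(parameters, max_norm):
    # Single-process stand-in for l2_grad_clipper: no fused apex kernels,
    # no all-reduce over the model-parallel group.
    parameters = [p for p in parameters if p.grad is not None]
    norm_2 = sum((p.grad.norm(2) ** 2 for p in parameters), torch.tensor(0.))
    total_norm = norm_2.item() ** 0.5
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.mul_(clip_coef)
    return total_norm

# Typical call site: after loss.backward(), before optimizer.step().
layer = torch.nn.Linear(8, 8)
layer(torch.randn(2, 8)).sum().backward()
print(clip_grad_l2_(layer.parameters(), max_norm=1.0))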
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -89,6 +89,7 @@ class VocabParallelEmbedding(torch.nn.Module):
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
@@ -108,7 +109,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                self.num_embeddings, get_model_parallel_rank(),
                get_model_parallel_world_size())
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index
        # Allocate weights.
        self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
@@ -147,6 +148,7 @@ class ParallelEmbedding(torch.nn.Module):
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_,
                 keep_master_weight_for_test=False):
@@ -205,6 +207,7 @@ class ColumnParallelLinear(torch.nn.Module):
            set to False. It returns the master weights
            used for initialization.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False):
@@ -279,6 +282,7 @@ class RowParallelLinear(torch.nn.Module):
            set to False. It returns the master weights
            used for initialization.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
@@ -327,4 +331,3 @@ class RowParallelLinear(torch.nn.Module):
        else:
            output = output_
        return output
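In the VocabParallelEmbedding hunk, each rank keeps only the embedding rows between vocab_start_index and vocab_end_index, so num_embeddings_per_partition is simply their difference. A small sketch of that range arithmetic (the helper name here is illustrative and assumes the padded vocabulary divides evenly across ranks, which Megatron arranges via its make-vocab-size-divisible-by option):

def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    # Assumes global_vocab_size is divisible by world_size.
    per_partition_vocab_size = global_vocab_size // world_size
    index_first = rank * per_partition_vocab_size
    index_last = index_first + per_partition_vocab_size
    return index_first, index_last

# e.g. a 50304-token (padded) vocabulary split over 4 model-parallel ranks:
for rank in range(4):
    start, end = vocab_range_from_global_vocab_size(50304, rank, 4)
    print(rank, start, end, end - start)   # each rank owns 12576 rows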
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -131,11 +131,14 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
def copy_to_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -73,6 +73,7 @@ class CudaRNGStatesTracker:
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from commons import set_random_seed
+from commons import IdentityLayer
+from commons import print_separator
+from commons import initialize_distributed
+from mpu.cross_entropy import vocab_parallel_cross_entropy
+import mpu
+import torch.nn.functional as F
+import torch
import random
import sys
sys.path.append("../..")
-import torch
-import torch.nn.functional as F
-import mpu
-from mpu.cross_entropy import vocab_parallel_cross_entropy
-from commons import initialize_distributed
-from commons import print_separator
-from commons import IdentityLayer
-from commons import set_random_seed


def torch_cross_entropy(batch_size, seq_length, vocab_size,
                        logits_scale, seed):
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from commons import print_separator
+from commons import initialize_distributed
+from mpu import data as data_utils
+import mpu
+import torch
import functools
import operator
import sys
sys.path.append("../..")
-import torch
-import mpu
-from mpu import data as data_utils
-from commons import initialize_distributed
-from commons import print_separator


def test_boradcast_data(model_parallel_size):
@@ -88,5 +86,3 @@ if __name__ == '__main__':
        print_separator('test test boradcast data')
        test_boradcast_data(model_parallel_size)
        model_parallel_size *= 2
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from commons import print_separator
+from commons import initialize_distributed
+import mpu
+import torch
import sys
sys.path.append("../..")
-import torch
-import mpu
-from commons import initialize_distributed
-from commons import print_separator


def test_initialize_model_parallel(model_parallel_size):
@@ -46,7 +44,6 @@ def test_initialize_model_parallel(model_parallel_size):
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from mpu import layers
+from commons import set_random_seed
+from commons import print_separator
+from commons import initialize_distributed
+import mpu
+from torch.nn.parameter import Parameter
+import torch.nn.init as init
+import torch
import random
import sys
sys.path.append("../..")
-import torch
-import torch.nn.init as init
-from torch.nn.parameter import Parameter
-import mpu
-from commons import initialize_distributed
-from commons import print_separator
-from commons import set_random_seed
-from mpu import layers


def test_parallel_embedding(model_parallel_size):
@@ -45,7 +43,7 @@ def test_parallel_embedding(model_parallel_size):
    set_random_seed(123)
    input_data = torch.LongTensor(
-        size=(batch_size,seq_length)).random_(0, vocab_size).cuda()
+        size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
@@ -57,7 +55,7 @@ def test_parallel_embedding(model_parallel_size):
    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()
@@ -176,10 +174,11 @@ def test_initialize_affine_weight(model_parallel_size):

class IdentityLayer2D(torch.nn.Module):
-    def __init__(self, m , n):
+    def __init__(self, m, n):
        super(IdentityLayer2D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight
@@ -317,10 +316,11 @@ def test_row_parallel_linear(model_parallel_size):

class IdentityLayer3D(torch.nn.Module):
-    def __init__(self, m , n, k):
+    def __init__(self, m, n, k):
        super(IdentityLayer3D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n, k))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight
@@ -335,14 +335,14 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
    set_random_seed(seed)
    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                    dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
@@ -366,17 +366,17 @@ def test_parallel_self_attention(model_parallel_size):
    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    dropout_prob = 0.0  # has to be zero
    batch_size = 5
    sequence_length = 13

    rank_1, hideen_size_1, model_parallel_size_1, loss_1, \
-        attention_layer_1, identity_layer_1 =parallel_self_attention(
+        attention_layer_1, identity_layer_1 = parallel_self_attention(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)

    rank, hidden_size, model_parallel_size, loss, \
-        attention_layer, identity_layer =parallel_self_attention(
+        attention_layer, identity_layer = parallel_self_attention(
            model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)

    assert hideen_size_1 == hidden_size
@@ -409,6 +409,7 @@ def test_parallel_self_attention(model_parallel_size):
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
@@ -419,7 +420,7 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
    set_random_seed(seed)
    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from commons import print_separator
+from commons import initialize_distributed
+import mpu
+import torch
import sys
sys.path.append("../..")
-import torch
-import mpu
-from commons import initialize_distributed
-from commons import print_separator


def test_set_cuda_rng_state(model_parallel_size):
@@ -204,4 +202,3 @@ if __name__ == '__main__':
        print_separator('test model parallel cuda manual seed')
        test_model_parallel_cuda_manual_seed(model_parallel_size)
        model_parallel_size *= 2
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,150 +13,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Sample Generate GPT2"""
+"""Utilities for generating text."""

-import os
-import random
-import json
import copy
-import numpy as np
+import json
-import torch
+import os
-import torch.nn.functional as F
-import argparse
import time
-from arguments import get_args
-from megatron.utils import Timers
-from megatron.utils import initialize_distributed
-from megatron.utils import set_random_seed
-from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import load_checkpoint
-from megatron.data_utils import make_tokenizer
-from configure_data import configure_data
-from megatron import mpu
-from megatron.fp16 import FP16_Module
-from megatron.model import GPT2Model
-from megatron.model import DistributedDataParallel as DDP
-from megatron.utils import print_rank_0
-
-
-def get_model(args):
-    """Build the model."""
-    print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_layers=args.num_layers,
-                      vocab_size=args.vocab_size,
-                      hidden_size=args.hidden_size,
-                      num_attention_heads=args.num_attention_heads,
-                      embedding_dropout_prob=args.hidden_dropout,
-                      attention_dropout_prob=args.attention_dropout,
-                      output_dropout_prob=args.hidden_dropout,
-                      max_sequence_length=args.max_position_embeddings,
-                      checkpoint_activations=args.checkpoint_activations,
-                      checkpoint_num_layers=args.checkpoint_num_layers,
-                      parallel_output=False)
+
+import torch
+import torch.nn.functional as F
-
-    if mpu.get_data_parallel_rank() == 0:
-        print(' > number of parameters on model parallel rank {}: {}'.format(
-            mpu.get_model_parallel_rank(),
-            sum([p.nelement() for p in model.parameters()])), flush=True)
-
-    # GPU allocation.
-    model.cuda(torch.cuda.current_device())
-
-    # Fp16 conversion.
-    if args.fp16:
-        model = FP16_Module(model)
-
-    # Wrap model for distributed training.
-    model = DDP(model)
-
-    return model
-
-
-def setup_model(args):
-    """Setup model and optimizer."""
-    model = get_model(args)
-
-    if args.load is not None:
-        _ = load_checkpoint(
-            model, None, None, args)
-
-    return model
+
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron.utils import get_ltor_masks_and_position_ids

-def get_batch(context_tokens, args):
-    tokens = context_tokens
-    tokens = tokens.view(args.batch_size, -1).contiguous()
-    device = args.device
-    tokens = tokens.to(device)
-    # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
-        tokens,
-        args.eod_token,
-        args.reset_position_ids,
-        args.reset_attention_mask,
-        False)
-    # Fp16 conversion.
-    if args.fp16:
-        attention_mask = attention_mask.half()
-    return tokens, attention_mask, position_ids
+def get_batch(context_tokens):
+    """Generate batch from context tokens."""
+    args = get_args()
+    tokenizer = get_tokenizer()
+    # Move to GPU.
+    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
+    # Get the attention mask and postition ids.
+    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
+        tokens,
+        tokenizer.eod,
+        args.reset_position_ids,
+        args.reset_attention_mask,
+        args.eod_mask_loss)
+    return tokens, attention_mask, position_ids

def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+    """ This function has been mostly taken from huggingface conversational
+     ai code at
+         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
+              conversational-ai-with-transfer-learning-2d818ac26313 """

    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
+        # Remove all tokens with a probability less than the
+        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
-        #convert to 1D
-        # logits=logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+        # Cconvert to 1D
+        sorted_logits, sorted_indices = torch.sort(
+            logits, descending=True, dim=-1)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
+                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        # Shift the indices to the right to keep also the first token
+        # above the threshold
+        sorted_indices_to_remove[..., 1:] \
+            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value
-        #going back to 2D
-        # logits=logits.view(1, -1).contiguous()

    return logits
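As a quick illustration of the top-k branch above, here is the same filtering rule applied to a toy logits tensor (a stand-alone sketch, not the repository's code):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0]])
top_k = 2
filter_value = -float('Inf')

# Keep the k largest logits, push everything else to -inf.
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
filtered = logits.masked_fill(indices_to_remove, filter_value)

# Sampling then only ever picks one of the surviving tokens.
probs = F.softmax(filtered, dim=-1)
print(probs)                                   # two non-zero entries
print(torch.multinomial(probs, num_samples=1))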
def generate_samples_input_from_file(model, tokenizer, args):
if args.sample_input_file == "": def generate_samples_input_from_file(model):
if mpu.get_model_parallel_rank() == 0:
print("args.sample_input_file CAN NOT BE empty!\n")
return
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r") fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines() all_raw_text = fname.readlines()
input_count = len(all_raw_text) input_count = len(all_raw_text)
input_pos = 0 input_pos = 0
if args.sample_output_file == "": if args.sample_output_file is None:
print("Argument: sample-output-file can't be empty, setting it to\n") sample_output_file = args.sample_input_file + ".out"
print("\t args.sample_input_file.out") print('could not find `sample-output-file`, setting '
args.sample_output_file = args.sample_input_file+".out" 'it to {}'.format(sample_output_file))
fname_out = open(args.sample_output_file, "w+") fname_out = open(sample_output_file, "w+")
context_count=0 context_count = 0
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
while True: while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group()) torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs=0 terminate_runs = 0
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos] raw_text = all_raw_text[input_pos]
...@@ -167,63 +115,62 @@ def generate_samples_input_from_file(model, tokenizer, args): ...@@ -167,63 +115,62 @@ def generate_samples_input_from_file(model, tokenizer, args):
if "stop" in raw_text: if "stop" in raw_text:
terminate_runs = 1 terminate_runs = 1
else: else:
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization context_tokens = tokenizer.tokenize(raw_text)
context_length = len(context_tokens) context_length = len(context_tokens)
if context_length >=args.seq_length//2: if context_length >= (args.seq_length // 2):
print("\nContext length", context_length, \ print("\nContext length", context_length,
"\nPlease give smaller context (half of the sequence length)!") "\nPlease give smaller context (half of the "
"sequence length)!", flush=True)
continue continue
else: else:
context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens) context_length = len(context_tokens)
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group()) torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item() terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1: if terminate_runs == 1:
return return
start_time = time.time() token_stream = get_token_stream(model, [context_tokens])
token_stream = get_token_stream(model, [context_tokens], tokenizer, args) for _, decode_tokens in enumerate(token_stream):
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist() decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True) print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):] trim_decode_tokens = tokenizer.detokenize(
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True) decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True) print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\nContext:") fname_out.write("\nContext:")
fname_out.write(raw_text) fname_out.write(raw_text)
fname_out.write("\n\nMegatron-LM:") fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens) fname_out.write(trim_decode_tokens)
#fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
fname_out.write("\n") fname_out.write("\n")
raw_text = None raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group()) torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1 context_count += 1
def generate_samples_interactive(model, tokenizer, args):
print_frequency = 24
context_count=0 def generate_samples_interactive(model, print_frequency=24):
args = get_args()
tokenizer = get_tokenizer()
context_count = 0
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
while True: while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group()) torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs=0 terminate_runs = 0
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
...@@ -231,85 +178,85 @@ def generate_samples_interactive(model, tokenizer, args): ...@@ -231,85 +178,85 @@ def generate_samples_interactive(model, tokenizer, args):
while not raw_text: while not raw_text:
print('Prompt should not be empty!') print('Prompt should not be empty!')
raw_text = input("\nContext prompt (stop to exit) >>> ") raw_text = input("\nContext prompt (stop to exit) >>> ")
if "stop" in raw_text: if "stop" in raw_text:
terminate_runs = 1 terminate_runs = 1
else: else:
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization context_tokens = tokenizer.tokenize(raw_text)
context_length = len(context_tokens) context_length = len(context_tokens)
if context_length >=args.seq_length//2: if context_length >= (args.seq_length // 2):
print("\nContext length", context_length, \ print("\nContext length", context_length,
"\nPlease give smaller context (half of the sequence length)!") "\nPlease give smaller context (half of the "
"sequence length)!", flush=True)
continue continue
else: else:
context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens) context_length = len(context_tokens)
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group()) torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item() terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1: if terminate_runs == 1:
return return
start_time = time.time() token_stream = get_token_stream(model, [context_tokens])
token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
for counter, decode_tokens in enumerate(token_stream): for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist() decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0 and counter % print_frequency == 0: if mpu.get_model_parallel_rank() == 0 and \
counter % print_frequency == 0:
os.system('clear') os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True) print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):] trim_decode_tokens = tokenizer.detokenize(
#print("\nGPT2:", trim_decode_tokens, flush=True) decode_tokens)[len(raw_text):]
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True) print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True) print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):] trim_decode_tokens = tokenizer.detokenize(
#print("\nGPT2:", trim_decode_tokens, flush=True) decode_tokens)[len(raw_text):]
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True) print("\nMegatron-LM:", trim_decode_tokens, flush=True)
raw_text = None raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group()) torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1 context_count += 1
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
input("\nPress any key to continue >>>") input("\nPress any key to continue >>>")
def generate_samples_unconditional(model, tokenizer, args):
def generate_samples_unconditional(model):
args = get_args()
tokenizer = get_tokenizer()
num_samples = args.num_samples num_samples = args.num_samples
context_tokens = [[tokenizer.get_command('pad').Id] for _ in range(args.batch_size)] context_tokens = [[tokenizer.eod]
samples = [] for _ in range(args.batch_size)]
# with open(args.genfile, 'w') as f:
ctr = 0 ctr = 0
while True: while True:
start_time = time.time() start_time = time.time()
for token_stream in get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args): for token_stream in get_token_stream(model,
copy.deepcopy(context_tokens)):
pass pass
# token_stream = list(get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args)) if ctr % args.log_interval == 0:
if ctr%args.log_interval == 0: print('Avg s/batch:',
print('Avg s/batch:', (time.time()-start_time)/min(args.log_interval, ctr+1)) (time.time() - start_time) / min(args.log_interval, ctr + 1))
start_time = time.time() start_time = time.time()
length = len(token_stream) length = len(token_stream)
token_batch = token_stream[0].cpu().numpy().tolist() token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist() length_batch = token_stream[1].cpu().numpy().tolist()
for tokens, length in zip(token_batch, length_batch): for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length-1] tokens = tokens[1:length - 1]
text = tokenizer.DecodeIds(tokens) text = tokenizer.detokenize(tokens)
is_finished = length < args.seq_length - 1 is_finished = length < args.seq_length - 1
datum = {'text': text, 'length': length-1, 'finished': is_finished} datum = {'text': text, 'length': length - 1, 'finished': is_finished}
yield datum yield datum
ctr += 1 ctr += 1
if ctr >= num_samples: if ctr >= num_samples:
...@@ -317,65 +264,73 @@ def generate_samples_unconditional(model, tokenizer, args): ...@@ -317,65 +264,73 @@ def generate_samples_unconditional(model, tokenizer, args):
if ctr >= num_samples: if ctr >= num_samples:
break break
def write_and_generate_samples_unconditional(model, tokenizer, args):
def generate_and_write_samples_unconditional(model):
args = get_args()
assert args.genfile is not None assert args.genfile is not None
with open(args.genfile, 'w') as f: with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model, tokenizer, args): for datum in generate_samples_unconditional(model):
f.write(json.dumps(datum)+'\n') f.write(json.dumps(datum) + '\n')
def pad_batch(batch, pad_id, args):
def pad_batch(batch, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id
context_lengths = [] context_lengths = []
for tokens in batch: for tokens in batch:
context_length = len(tokens) context_length = len(tokens)
if context_length < args.seq_length: if context_length < args.seq_length:
tokens.extend([pad_id]*(args.seq_length-context_length)) tokens.extend([pad_id] * (args.seq_length - context_length))
context_lengths.append(context_length) context_lengths.append(context_length)
return batch, context_lengths return batch, context_lengths
def get_token_stream(model, context_tokens, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id def get_token_stream(model, context_tokens):
# context_length = len(context_tokens)
# if context_length < args.seq_length: args = get_args()
# context_tokens = context_tokens + [pad_id] * (args.seq_length - context_length) tokenizer = get_tokenizer()
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)
context_tokens, context_lengths = pad_batch(context_tokens,
tokenizer.eod, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens) context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths) context_length_tensor = torch.cuda.LongTensor(context_lengths)
# context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group()) torch.distributed.broadcast(context_length_tensor,
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group()) mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
context_length = context_length_tensor.min().item() context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, args) tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
counter = 0 batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
org_context_length = context_length context_length_tensor,
attention_mask, position_ids)
layer_past = None
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, attention_mask, position_ids, tokenizer, args)
for tokens, lengths in batch_token_iterator: for tokens, lengths in batch_token_iterator:
context_length += 1 context_length += 1
yield tokens[:, :context_length], lengths yield tokens[:, :context_length], lengths
def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
-    return (1-boolean)*val1 + boolean*val2
+    return (1 - boolean) * val1 + boolean * val2
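switch() blends the previously stored token with the newly sampled one using arithmetic on a 0/1 mask rather than indexing, which also behaves well under fp16. A toy illustration:

import torch

def switch(val1, val2, boolean):
    # Same arithmetic as above: (1 - b) * val1 + b * val2.
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

current = torch.tensor([10., 20., 30.])       # tokens already in the buffer
sampled = torch.tensor([1., 2., 3.])          # freshly sampled tokens
started = torch.tensor([False, True, True])   # rows that are past their prompt
print(switch(current, sampled, started))      # tensor([10., 2., 3.])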
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
if isinstance(model, DDP): def sample_sequence_batch(model, context_tokens, context_lengths,
model = model.module attention_mask, position_ids,
if isinstance(model, FP16_Module): maxlen=None, type_ids=None):
model = model.module
original_output_parallel = model.parallel_output args = get_args()
model.parallel_output = False tokenizer = get_tokenizer()
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
context_length = context_lengths.min().item() context_length = context_lengths.min().item()
eos_id = tokenizer.get_command('eos').Id eos_id = tokenizer.eod
counter = 0 counter = 0
org_context_length = context_length org_context_length = context_length
...@@ -389,12 +344,16 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask ...@@ -389,12 +344,16 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
if maxlen > (org_context_length + args.out_seq_length): if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda()*maxlen lengths = torch.ones([batch_size]).long().cuda() * maxlen
while context_length <= (maxlen): while context_length <= (maxlen):
if args.recompute: if args.recompute:
logits = model(tokens, position_ids, attention_mask, tokentype_ids=type_ids) logits = model(tokens,
position_ids,
attention_mask,
tokentype_ids=type_ids,
forward_method_parallel_output=False)
logits = logits[:, context_length - 1, :] logits = logits[:, context_length - 1, :]
else: else:
types2use = None types2use = None
...@@ -404,113 +363,48 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask ...@@ -404,113 +363,48 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
if type_ids is not None: if type_ids is not None:
types2use = type_ids[:, :context_length] types2use = type_ids[:, :context_length]
else: else:
tokens2use = tokens[:, context_length - 1].view(batch_size, -1) tokens2use = tokens[:, context_length - 1].view(
positions2use = position_ids[:, context_length - 1].view(batch_size, -1) batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
if type_ids is not None: if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(batch_size, -1) types2use = type_ids[:, context_length - 1].view(
logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_key_value=True, tokentype_ids=types2use) batch_size, -1)
logits = logits[:, -1].view(batch_size,-1).contiguous() logits, layer_past = model(tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if args.greedy: if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1) prev = torch.argmax(logits, dim=-1).view(-1)
else: else:
logits = logits.float() logits = logits.float()
logits /= args.temperature logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p) logits = top_k_logits(logits, top_k=args.top_k,
top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1) log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1) prev = torch.multinomial(log_probs, num_samples=1).view(-1)
print_logits = [] print_logits = []
for p in prev: for p in prev:
print_logits.append([logits[i, p].item() for i in range(batch_size)]) print_logits.append([logits[i, p].item()
for i in range(batch_size)])
started = context_lengths <= context_length started = context_lengths <= context_length
tokens[:, context_length] = switch(tokens[:, context_length].view(-1), prev, started) tokens[:, context_length] = switch(
tokens[:, context_length].view(-1), prev, started)
context_length += 1 context_length += 1
counter += 1 counter += 1
done_token = (prev == eos_id).byte() & started.byte() done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool() just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length lengths[just_finished.view(-1)] = context_length
was_done = is_done
is_done = is_done | done_token is_done = is_done | done_token
done = torch.all(is_done) done = torch.all(is_done)
yield tokens, lengths yield tokens, lengths
if done: if done:
break break
model.parallel_output = original_output_parallel
def prepare_tokenizer(args):
tokenizer_args = {
'tokenizer_type': args.tokenizer_type,
'corpus': None,
'model_path': args.tokenizer_path,
'vocab_size': args.vocab_size,
'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir}
tokenizer = make_tokenizer(**tokenizer_args)
args.tokenizer_num_tokens = tokenizer.num_tokens
args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
args.eod_token = tokenizer.get_command('eos').Id
after = tokenizer.num_tokens
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
if multiple != 0:
while (after % multiple) != 0:
after += 1
args.vocab_size = after
print("prepare tokenizer done", flush=True)
return tokenizer
def main():
"""Main training program."""
print('Generate Samples')
# Disable CuDNN.
torch.backends.cudnn.enabled = False
# Timer.
timers = Timers()
# Arguments.
args = get_args()
# Pytorch distributed.
initialize_distributed(args)
# Random seeds for reproducability.
set_random_seed(args.seed)
#get the tokenizer
tokenizer = prepare_tokenizer(args)
# Model, optimizer, and learning rate.
model = setup_model(args)
#setting default batch size to 1
# args.batch_size = 1
args.device = torch.cuda.current_device()
#generate samples
if args.num_samples == 0:
args.batch_size = 1
if args.sample_input_file != "":
generate_samples_input_from_file(model, tokenizer, args)
else:
generate_samples_interactive(model, tokenizer, args)
else:
write_and_generate_samples_unconditional(model, tokenizer, args)
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer import build_tokenizer
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
def vocab_size(self):
return len(self.vocab)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
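# Illustrative usage sketch (not part of the original module): how the classes
# above compose end to end. The vocab path is a placeholder assumption; any
# BERT-style wordpiece vocab file (one token per line) would do.
if __name__ == "__main__":
    # BasicTokenizer normalizes and splits on whitespace/punctuation, then
    # WordpieceTokenizer applies greedy longest-match-first subword splitting.
    tokenizer = FullTokenizer(vocab_file="bert-vocab.txt", do_lower_case=True)
    tokens = tokenizer.tokenize("Megatron makes unaffable text tokenizable.")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(tokens)
    print(ids)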
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import os
import regex as re
from io import open
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
Returns a dict mapping utf-8 bytes to unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
and we avoid mapping to whitespace/control characters that the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
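# Worked example (added comment): byte 0x20 (space) falls outside the printable
# ranges kept above, so it is remapped to chr(256 + 32) = 'Ġ' (U+0120); this is
# why GPT-2 BPE tokens that begin a new word start with the 'Ġ' character.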
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a GPT2Tokenizer from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
from .file_utils import cached_path
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer won't index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should have added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
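# Worked example (added comment, hypothetical merge ranks): with
# bpe_ranks = {('l','o'): 0, ('lo','w'): 1, ('e','r'): 2, ('low','er'): 3},
# bpe('lower') merges l+o -> 'lo', lo+w -> 'low', e+r -> 'er', low+er ->
# 'lower', always taking the lowest-ranked pair first, and returns the
# space-joined result (here the single symbol 'lower').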
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
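# Illustrative usage sketch (not part of the original module): a byte-level BPE
# round trip. The vocab/merges paths are placeholder assumptions; the standard
# gpt2 files referenced in the archive maps above would work as well.
if __name__ == "__main__":
    tok = GPT2Tokenizer(vocab_file="gpt2-vocab.json", merges_file="gpt2-merges.txt")
    ids = tok.encode("Hello world")
    text = tok.decode(ids)
    # Byte-level BPE is lossless for text covered by the vocab, so decode
    # inverts encode here.
    print(ids, text)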
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron tokenizers."""
from abc import ABC
from abc import abstractmethod
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
def build_tokenizer(args):
"""Initialize tokenizer."""
if args.rank == 0:
print('> building {} tokenizer ...'.format(args.tokenizer_type),
flush=True)
# Select and instantiate the tokenizer.
assert args.vocab_file is not None
if args.tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True)
elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(args.tokenizer_type))
# Add vocab size.
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
args)
return tokenizer
def _vocab_size_with_padding(orig_vocab_size, args):
"""Pad vocab size so it is divisible by model parallel size and
still having GPU friendly size."""
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \
args.model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
'(new size: {})'.format(
orig_vocab_size, after - orig_vocab_size, after), flush=True)
return after
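# Worked example (added comment): with the GPT-2 vocab size of 50257,
# make_vocab_size_divisible_by=128 and model_parallel_size=8, the multiple is
# 1024 and the vocab is padded with 943 dummy tokens to 51200.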
class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""
def __init__(self, name):
self.name = name
super().__init__()
@property
@abstractmethod
def vocab_size(self):
pass
@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass
@abstractmethod
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
class _BertWordPieceTokenizer(AbstractTokenizer):
"""Original BERT wordpiece tokenizer."""
def __init__(self, vocab_file, lower_case=True):
if lower_case:
name = 'BERT Lower Case'
else:
name = 'BERT Upper Case'
super().__init__(name)
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]']
@property
def vocab_size(self):
return self.tokenizer.vocab_size()
@property
def vocab(self):
return self.tokenizer.vocab
@property
def inv_vocab(self):
return self.tokenizer.inv_vocab
def tokenize(self, text):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
@property
def cls(self):
return self.cls_id
@property
def sep(self):
return self.sep_id
@property
def pad(self):
return self.pad_id
@property
def mask(self):
return self.mask_id
class _GPT2BPETokenizer(AbstractTokenizer):
"""Original GPT2 BPE tokenizer."""
def __init__(self, vocab_file, merge_file):
name = 'GPT2 BPE'
super().__init__(name)
self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
special_tokens=[], max_len=None)
self.eod_id = self.tokenizer.encoder['<|endoftext|>']
@property
def vocab_size(self):
return len(self.tokenizer.encoder)
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
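# Illustrative usage sketch (not part of the original module). build_tokenizer
# normally receives megatron's parsed command-line arguments; here a minimal
# namespace carrying only the attributes read above is assembled by hand, and
# the file paths are placeholder assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace
    args = SimpleNamespace(rank=0,
                           tokenizer_type='GPT2BPETokenizer',
                           vocab_file='gpt2-vocab.json',
                           merge_file='gpt2-merges.txt',
                           make_vocab_size_divisible_by=128,
                           model_parallel_size=1)
    tokenizer = build_tokenizer(args)
    ids = tokenizer.tokenize('Hello world')
    print(ids)
    print(tokenizer.detokenize(ids))
    print('padded vocab size:', args.padded_vocab_size)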