Commit 1b95daa0 authored by thomwolf

model conversion WIP

parent da017ac9
@@ -105,7 +105,132 @@ class BertConfig(object):
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

class BERTLayerNorm(nn.Module):
    def __init__(self, config, variance_epsilon=1e-12):
        # PyTorch port of tf.contrib.layers.layer_norm(begin_norm_axis=-1): normalize
        # over the last axis with learned scale (gamma) and shift (beta) parameters.
        super(BERTLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(config.hidden_size))
        self.beta = nn.Parameter(torch.zeros(config.hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        return self.gamma * (x - u) / torch.sqrt(s + self.variance_epsilon) + self.beta

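# Hedged sketch, not part of the original commit: a quick way to check that this
# re-implementation matches torch's built-in layer norm (same biased variance and
# same epsilon-inside-the-square-root convention), assuming a `config` with
# hidden_size set:
#
#     ln = BERTLayerNorm(config)
#     ref = nn.LayerNorm(config.hidden_size, eps=1e-12)
#     x = torch.randn(2, 5, config.hidden_size)
#     assert torch.allclose(ln(x), ref(x), atol=1e-6)
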
class BERTEmbeddings(nn.Module):
    def __init__(self, config):
        super(BERTEmbeddings, self).__init__()
        # NOTE (WIP): the rest of the conversion uses config.hidden_size; the
        # embedding_size / token_type_vocab_size names here may still need reconciling.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.token_type_vocab_size, config.embedding_size)
        self.LayerNorm = BERTLayerNorm(config)  # Not snake-cased to fit with TF model variable name
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.initialize_weights(config.initializer_range)

    def initialize_weights(self, initializer_range):
        # The TF reference uses a truncated normal initializer; a plain normal init is
        # the closest one-liner here (see the sketch after this class).
        for embedding in (self.word_embeddings, self.position_embeddings, self.token_type_embeddings):
            embedding.weight.data.normal_(mean=0.0, std=initializer_range)

    def forward(self, input_ids, token_type_ids=None):
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        # Positions 0..seq_length-1, broadcast over the batch (torch.arange replaces
        # the deprecated torch.range).
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

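# Hedged sketch, not part of the original commit: the TF reference initializes weights
# with a truncated normal distribution, which the plain normal init above only
# approximates. On recent PyTorch releases, nn.init.trunc_normal_ gives a closer
# match; the helper name below is illustrative only.
def truncated_normal_init_(module, initializer_range):
    """Initialize Linear/Embedding weights TF-style with a truncated normal."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.trunc_normal_(module.weight, mean=0.0, std=initializer_range,
                              a=-2.0 * initializer_range, b=2.0 * initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)
# Example: model.apply(lambda m: truncated_normal_init_(m, config.initializer_range))
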
class BERTIntermediate(nn.Module):
    def __init__(self, config):
        super(BERTIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states):
        # TODO (WIP): apply the intermediate activation (gelu in the TF reference).
        hidden_states = self.dense(hidden_states)
        return hidden_states

class BERTOutput(nn.Module):
    def __init__(self, config):
        super(BERTOutput, self).__init__()
        # Projects the feed-forward activations back to hidden_size.
        # NOTE (WIP): BERTAttention reuses this module as well; that path will need a
        # hidden_size -> hidden_size projection instead.
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class BERTSelfAttention(nn.Module):
    def __init__(self, config):
        super(BERTSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

    def transpose_for_scores(self, x, k=False):
        # Split the last dimension into (num_attention_heads, attention_head_size).
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)  # in the TensorFlow implementation: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # keys: [batch, heads, head_size, seq]
        else:
            return x.permute(0, 2, 1, 3)  # queries/values: [batch, heads, seq, head_size]

    def forward(self, hidden_states):
        # Placeholder (WIP): the attention computation itself is not written yet;
        # a possible completion is sketched after this class.
        return hidden_states

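# Hedged sketch, not part of the original commit: one way the placeholder forward
# above could be completed, following the usual scaled dot-product attention of the
# TF BERT reference (attention mask and dropout omitted). The function name is
# illustrative only; it relies solely on the attributes defined in __init__ above.
def _self_attention_forward_sketch(attn, hidden_states):
    query_layer = attn.transpose_for_scores(attn.query(hidden_states))      # [batch, heads, seq, head_size]
    key_layer = attn.transpose_for_scores(attn.key(hidden_states), k=True)  # [batch, heads, head_size, seq]
    value_layer = attn.transpose_for_scores(attn.value(hidden_states))
    attention_scores = torch.matmul(query_layer, key_layer) / (attn.attention_head_size ** 0.5)
    attention_probs = torch.softmax(attention_scores, dim=-1)
    context_layer = torch.matmul(attention_probs, value_layer)              # [batch, heads, seq, head_size]
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_shape = context_layer.size()[:-2] + (attn.all_head_size,)
    return context_layer.view(*new_shape)                                   # [batch, seq, all_head_size]
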
class BERTAttention(nn.Module):
    def __init__(self, config):
        super(BERTAttention, self).__init__()
        self.self = BERTSelfAttention(config)
        self.output = BERTOutput(config)

    def forward(self, hidden_states):
        hidden_states = self.self(hidden_states)
        hidden_states = self.output(hidden_states)
        return hidden_states


class BERTLayer(nn.Module):
    def __init__(self, config):
        super(BERTLayer, self).__init__()
        self.attention = BERTAttention(config)
        self.intermediate = BERTIntermediate(config)
        self.output = BERTOutput(config)

    def forward(self, hidden_states):
        hidden_states = self.attention(hidden_states)
        hidden_states = self.intermediate(hidden_states)
        hidden_states = self.output(hidden_states)
        return hidden_states

class BERTEncoder(nn.Module):
    def __init__(self, config):
        super(BERTEncoder, self).__init__()
        # num_hidden_layers identical layers (deep-copied; needs `import copy` at the top of the file).
        layer = BERTLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
        Return:
            float Tensor of shape [batch_size, seq_length, hidden_size]
        """
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)
        return hidden_states

class BertModel(nn.Module):
@@ -132,28 +257,11 @@ class BertModel(nn.Module):
    ```
    """
    # The TF-style constructor arguments (is_training, input_ids, input_mask,
    # token_type_ids, use_one_hot_embeddings, scope) are dropped: dropout is driven
    # by train()/eval(), and the input tensors are passed to forward() instead.
    def __init__(self, config: BertConfig):
        """Constructor for BertModel.

        Args:
            config: `BertConfig` instance.

        Raises:
            ValueError: The config is invalid or one of the input tensor shapes
@@ -168,15 +276,20 @@ class BertModel(nn.Module):
        self.embeddings = BERTEmbeddings(config)
        self.encoder = BERTEncoder(config)

    def forward(self, input_ids, token_type_ids=None, input_mask=None):
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        if input_mask is None:
            input_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
        if token_type_ids is None:
            token_type_ids = torch.zeros((batch_size, seq_length), dtype=torch.long)
        # Replaces the TF-style lookup:
        #   (self.embedding_output, self.embedding_table) = embedding_lookup(input_ids=input_ids, ...)
        # NOTE (WIP): input_mask is not passed on yet; BERTEmbeddings.forward only
        # takes (input_ids, token_type_ids).
        hidden_states = self.embeddings(input_ids, token_type_ids)
        hidden_states = self.encoder(hidden_states)
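
# Hedged sketch, not part of the original commit: once the WIP pieces above are
# finished (the self-attention forward, the activation in BERTIntermediate, and a
# separate hidden_size -> hidden_size output projection for BERTAttention), the
# converted model could be exercised roughly like this, assuming BertConfig accepts
# these keyword arguments:
#
#     config = BertConfig(vocab_size=32000, hidden_size=768, num_hidden_layers=12,
#                         num_attention_heads=12, intermediate_size=3072)
#     model = BertModel(config)
#     input_ids = torch.randint(0, config.vocab_size, (2, 128), dtype=torch.long)
#     hidden_states = model(input_ids)  # [2, 128, 768] once forward returns its output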