"docs/img/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "755ac5f07d7e9bf802b92855f69a29bce31a2c88"
Commit 45709d75 authored by thomwolf

model running with simple inputs

parent b407972e
@@ -17,6 +17,9 @@ from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHe
 from .modeling_gpt2 import (GPT2Config, GPT2Model,
                             GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2MultipleChoiceHead,
                             load_tf_weights_in_gpt2)
+from .modeling_xlnet import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig,
+                             XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
+                             load_tf_weights_in_xlnet)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
@@ -21,13 +21,13 @@ from __future__ import print_function
 import argparse
 import torch

-from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetRunConfig, XLNetModel, load_tf_weights_in_xlnet
+from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetRunConfig, XLNetLMHeadModel, load_tf_weights_in_xlnet

 def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
     # Initialise PyTorch model
     config = XLNetConfig.from_json_file(bert_config_file)
     print("Building PyTorch model from configuration: {}".format(str(config)))
-    model = XLNetModel(config)
+    model = XLNetLMHeadModel(config)

     # Load weights from tf checkpoint
     load_tf_weights_in_xlnet(model, tf_checkpoint_path)
@@ -867,7 +867,7 @@ class BertModel(BertPreTrainedModel):
         if head_mask is not None:
             if head_mask.dim() == 1:
                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1)
+                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
             elif head_mask.dim() == 2:
                 head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
             head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
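The `expand_as` to `expand` fix above (repeated for GPT-2 and OpenAI GPT below) matters because `Tensor.expand_as` takes a tensor whose shape is copied, while `Tensor.expand` takes the target sizes directly; passing integer sizes to `expand_as` raises a TypeError at runtime. A minimal sketch of the intended broadcast, with illustrative layer/head counts:

import torch

n_layer, n_head = 12, 12
head_mask = torch.ones(n_head)                     # one mask entry per head
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
# [1, 1, n_head, 1, 1] -> [n_layer, 1, n_head, 1, 1]: same mask for every layer
head_mask = head_mask.expand(n_layer, -1, -1, -1, -1)
print(head_mask.shape)  # torch.Size([12, 1, 12, 1, 1])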
@@ -722,7 +722,7 @@ class GPT2Model(GPT2PreTrainedModel):
         if head_mask is not None:
             if head_mask.dim() == 1:
                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand_as(self.config.n_layer, -1, -1, -1, -1)
+                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
             elif head_mask.dim() == 2:
                 head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
             head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
@@ -718,7 +718,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         if head_mask is not None:
             if head_mask.dim() == 1:
                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand_as(self.config.n_layer, -1, -1, -1, -1)
+                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
             elif head_mask.dim() == 2:
                 head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
             head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
@@ -29,6 +29,7 @@ from io import open
 import torch
 from torch import nn
+from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss

 from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
@@ -126,32 +127,27 @@ def swish(x):
 ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

-def positional_embedding(pos_seq, inv_freq, bsz=None):
-    sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
-    pos_emb = torch.cat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
-    pos_emb = pos_emb[:, None, :]
-    if bsz is not None:
-        pos_emb = pos_emb.expand(1, bsz, 1)
-    return pos_emb
-
 class XLNetBaseConfig(object):
     @classmethod
     def from_dict(cls, json_object):
-        """Constructs a `XLNetConfig` from a Python dictionary of parameters."""
-        config = XLNetConfig(vocab_size_or_config_json_file=-1)
+        """Constructs a `XLNetBaseConfig` from a Python dictionary of parameters."""
+        config = cls(vocab_size_or_config_json_file=-1)
         for key, value in json_object.items():
             config.__dict__[key] = value
         return config

     @classmethod
     def from_json_file(cls, json_file):
-        """Constructs a `XLNetConfig` from a json file of parameters."""
+        """Constructs a `XLNetBaseConfig` from a json file of parameters."""
         with open(json_file, "r", encoding='utf-8') as reader:
             text = reader.read()
         return cls.from_dict(json.loads(text))

+    def update(self, other):
+        dict_b = other.to_dict()
+        for key, value in dict_b.items():
+            self.__dict__[key] = value
+
     def __repr__(self):
         return str(self.to_json_string())
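The new `XLNetBaseConfig.update` shallow-merges another config's fields into `self`; the unit test below uses it to fold an `XLNetRunConfig` (mem_len, clamp_len, ...) into an `XLNetConfig`. A hedged sketch of the semantics using plain objects (the real classes also round-trip through `to_dict()`/JSON):

class Cfg:
    def to_dict(self):
        return dict(self.__dict__)

    def update(self, other):
        # copy every field of `other` onto self, overwriting collisions
        for key, value in other.to_dict().items():
            self.__dict__[key] = value

model_cfg, run_cfg = Cfg(), Cfg()
model_cfg.d_model, run_cfg.mem_len = 32, 30
model_cfg.update(run_cfg)
print(model_cfg.mem_len)  # 30 -- run-time fields now live on the model config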
@@ -181,6 +177,7 @@ class XLNetConfig(XLNetBaseConfig):
                  d_inner=4096,
                  ff_activation="gelu",
                  untie_r=True,
+                 attn_type="bi",
                  max_position_embeddings=512,
                  initializer_range=0.02,
@@ -198,6 +195,7 @@ class XLNetConfig(XLNetBaseConfig):
             ff_activation: The non-linear activation function (function or string) in the
                 encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
             untie_r: untie relative position biases
+            attn_type: 'bi' for XLNet, 'uni' for Transformer-XL
             dropout: The dropout probability for all fully connected
                 layers in the embeddings, encoder, and pooler.
@@ -226,6 +224,7 @@ class XLNetConfig(XLNetBaseConfig):
             self.ff_activation = ff_activation
             self.d_inner = d_inner
             self.untie_r = untie_r
+            self.attn_type = attn_type
             self.max_position_embeddings = max_position_embeddings
             self.initializer_range = initializer_range
             self.layer_norm_eps = layer_norm_eps
@@ -304,15 +303,15 @@ class XLNetRelativeAttention(nn.Module):
     def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(XLNetRelativeAttention, self).__init__()
         self.output_attentions = output_attentions
-        if config.d_model % config.num_attention_heads != 0:
+        if config.d_model % config.n_head != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
-                "heads (%d)" % (config.d_model, config.num_attention_heads))
+                "heads (%d)" % (config.d_model, config.n_head))
         self.output_attentions = output_attentions
         self.keep_multihead_output = keep_multihead_output
         self.multihead_output = None

-        self.n_head = config.num_attention_heads
+        self.n_head = config.n_head
         self.d_head = config.d_head
         self.d_model = config.d_model
         self.scale = 1 / (config.d_head ** 0.5)
@@ -326,7 +325,7 @@ class XLNetRelativeAttention(nn.Module):
         self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
         self.r_s_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
         self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
-        self.seg_embed = nn.Parameter(torch.Tensor(self.n_head, 2, self.d_head))
+        self.seg_embed = nn.Parameter(torch.Tensor(2, self.n_head, self.d_head))

         self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.dropout)
@@ -334,6 +333,18 @@ class XLNetRelativeAttention(nn.Module):
     def prune_heads(self, heads):
         raise NotImplementedError

+    @staticmethod
+    def rel_shift(x, klen=-1):
+        """perform relative shift to form the relative attention score."""
+        x_size = x.shape
+
+        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
+        x = x[1:, ...]
+        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
+        x = x[:, 0:klen, :, :]
+
+        return x
+
     def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None):
         """Core relative positional attention operations."""
@@ -342,7 +353,7 @@ class XLNetRelativeAttention(nn.Module):
         # position based attention score
         bd = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
-        bd = rel_shift(bd, klen=torch.shape(ac)[1])
+        bd = self.rel_shift(bd, klen=ac.shape[1])

         # segment based attention score
         if seg_mat is None:
@@ -426,7 +437,6 @@ class XLNetRelativeAttention(nn.Module):
             # post processing
             output_g = self.post_attention(g, attn_vec_g)
-
-            attention_output = output_h, output_g
         else:
             ###### Multi-head attention with relative positional encoding
             if mems is not None and mems.dim() > 1:
@@ -447,7 +457,8 @@ class XLNetRelativeAttention(nn.Module):
                 q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h)

             # post processing
-            attention_output = self.post_attention(h, attn_vec)
+            output_h = self.post_attention(h, attn_vec)
+            output_g = None

         # Mask heads if we want to
@@ -467,7 +478,7 @@ class XLNetRelativeAttention(nn.Module):
         # attentions, self_output = self_output
         # if self.output_attentions:
         #     return attentions, attention_output
-        return attention_output
+        return output_h, output_g

 class XLNetFeedForward(nn.Module):
     def __init__(self, config):
@@ -481,13 +492,15 @@ class XLNetFeedForward(nn.Module):
         else:
             self.activation_function = config.ff_activation

-    def forward(self, hidden_states, input_tensor):
-        hidden_states = self.layer_1(hidden_states)
-        hidden_states = self.activation_function(hidden_states)
-        hidden_states = self.layer_2(hidden_states)
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.LayerNorm(hidden_states + input_tensor)
-        return hidden_states
+    def forward(self, inp):
+        output = inp
+        output = self.layer_1(output)
+        output = self.activation_function(output)
+        output = self.dropout(output)
+        output = self.layer_2(output)
+        output = self.dropout(output)
+        output = self.LayerNorm(output + inp)
+        return output

 class XLNetLayer(nn.Module):
     def __init__(self, config, output_attentions=False, keep_multihead_output=False):
@@ -500,13 +513,13 @@ class XLNetLayer(nn.Module):
     def forward(self, output_h, output_g,
                 attn_mask_h, attn_mask_g,
                 r, seg_mat,
-                two_streams=False, mems=None, target_mapping=None, head_mask=None):
+                mems=None, target_mapping=None, head_mask=None):
         output_h, output_g = self.rel_attn(output_h, output_g,
                                            attn_mask_h, attn_mask_g,
                                            r, seg_mat,
                                            mems=mems, target_mapping=target_mapping, head_mask=head_mask)
-        if two_streams:
+        if output_g is not None:
             output_g = self.ff(output_g)
         output_h = self.ff(output_h)
@@ -520,9 +533,9 @@ class XLNetPreTrainedModel(nn.Module):
     """
     def __init__(self, config, *inputs, **kwargs):
         super(XLNetPreTrainedModel, self).__init__()
-        if not isinstance(config, XLNetConfig):
+        if not isinstance(config, XLNetBaseConfig):
             raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. "
+                "Parameter config in `{}(config)` should be an instance of class `XLNetBaseConfig`. "
                 "To create a model from a Google pretrained model use "
                 "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                     self.__class__.__name__, self.__class__.__name__
@@ -668,26 +681,41 @@ class XLNetPreTrainedModel(nn.Module):
 class XLNetModel(XLNetPreTrainedModel):
     def __init__(self, config, output_attentions=False, keep_multihead_output=False):
-        super(XLNetModel, self).__init__()
+        super(XLNetModel, self).__init__(config)
         self.output_attentions = output_attentions
         self.mem_len = config.mem_len
         self.reuse_len = config.reuse_len
+        self.d_model = config.d_model
+        self.same_length = config.same_length
+        self.attn_type = config.attn_type
+        self.bi_data = config.bi_data
+        self.clamp_len = config.clamp_len

         layer = XLNetLayer(config, output_attentions=output_attentions,
                            keep_multihead_output=keep_multihead_output)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
+        self.dropout = nn.Dropout(config.dropout)

-    @classmethod
-    def _create_mask(qlen, mlen, dtype=torch.float, same_length=False):
-        """create causal attention mask."""
-        attn_mask = torch.ones([qlen, qlen], dtype=dtype)
-        mask_u = tf.matrix_band_part(attn_mask, 0, -1)
-        mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
-        attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
-        ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
-        if same_length:
-            mask_l = tf.matrix_band_part(attn_mask, -1, 0)
-            ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
+    def create_mask(self, qlen, mlen):
+        """ create causal attention mask.
+        float mask where 1.0 indicates masked, 0.0 indicates not-masked.
+
+             same_length=False:      same_length=True:
+             <mlen > <  qlen >       <mlen > <  qlen >
+          ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
+            [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
+       qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
+            [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
+          v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]
+        """
+        attn_mask = torch.ones([qlen, qlen])
+        mask_up = torch.triu(attn_mask, diagonal=1)
+        attn_mask_pad = torch.zeros([qlen, mlen])
+        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
+        if self.same_length:
+            mask_lo = torch.tril(attn_mask, diagonal=-1)
+            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
+
+        ret = ret.to(next(self.parameters()))
         return ret

     def cache_mem(self, curr_out, prev_mem):
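The docstring diagram above can be reproduced directly. A minimal standalone sketch of `create_mask` (the dtype/device cast via `next(self.parameters())` is dropped here), using the diagram's dimensions qlen=5, mlen=4:

import torch

def create_mask(qlen, mlen, same_length=False):
    attn_mask = torch.ones([qlen, qlen])
    mask_up = torch.triu(attn_mask, diagonal=1)          # strict upper triangle: future tokens
    attn_mask_pad = torch.zeros([qlen, mlen])            # all memory is visible
    ret = torch.cat([attn_mask_pad, mask_up], dim=1)
    if same_length:
        # re-mask the oldest positions so every query sees the same context length
        mask_lo = torch.tril(attn_mask, diagonal=-1)
        ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
    return ret

print(create_mask(5, 4))                    # reproduces the left diagram
print(create_mask(5, 4, same_length=True))  # reproduces the right diagram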
@@ -705,10 +733,21 @@ class XLNetModel(XLNetPreTrainedModel):
         return new_mem.detach()

-    def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=torch.float):
+    @staticmethod
+    def positional_embedding(pos_seq, inv_freq, bsz=None):
+        sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
+        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
+        pos_emb = pos_emb[:, None, :]
+        if bsz is not None:
+            pos_emb = pos_emb.expand(-1, bsz, -1)
+        return pos_emb
+
+    def relative_positional_encoding(self, qlen, klen, bsz=None):
         """create relative positional encoding."""
-        freq_seq = torch.zrange(0, d_model, 2.0, dtype=dtype)
-        inv_freq = 1 / (10000 ** (freq_seq / self.config.d_model))
+        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
+        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))

         if self.attn_type == 'bi':
             # beg, end = klen - 1, -qlen
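A shape walk-through of the sinusoidal embedding may help here; this is a hedged, standalone recomputation (the `beg`/`end` values for the 'bi' case are assumed, since the diff only shows them commented out):

import torch

d_model, qlen, klen, bsz = 8, 3, 5, 2
freq_seq = torch.arange(0, d_model, 2.0)
inv_freq = 1 / (10000 ** (freq_seq / d_model))             # [d_model/2]
pos_seq = torch.arange(klen, -qlen, -1.0)                  # [klen+qlen] relative offsets
sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)  # [klen+qlen, d_model/2]
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
pos_emb = pos_emb[:, None, :].expand(-1, bsz, -1)          # [klen+qlen, bsz, d_model]
print(pos_emb.shape)  # torch.Size([8, 2, 8])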
@@ -720,51 +759,52 @@ class XLNetModel(XLNetPreTrainedModel):
             raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))

         if self.bi_data:
-            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
-            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=dtype)
+            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
+            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)

             if self.clamp_len > 0:
                 fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                 bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

             if bsz is not None:
-                fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
-                bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
+                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
+                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
             else:
-                fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
-                bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)
+                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
+                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

             pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
         else:
-            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
+            fwd_pos_seq = torch.arange(beg, end, -1.0)
             if self.clamp_len > 0:
                 fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
-            pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)
+            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

+        pos_emb = pos_emb.to(next(self.parameters()))
         return pos_emb

-    def forward(self, inp_k, seg_id=None, input_mask=None,
+    def forward(self, word_emb_k, seg_id=None, input_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                 output_all_encoded_layers=True, head_mask=None):
         """
         Args:
-            inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
+            word_emb_k: float32 Tensor in shape [len, bsz, d_model], the input token embeddings.
             seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
-            input_mask: float32 Tensor in shape [len, bsz], the input mask.
+            input_mask: [optional] float32 Tensor in shape [len, bsz], the input mask.
                 0 for real tokens and 1 for padding.
-            mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+            mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
                 from previous batches. The length of the list equals n_layer.
                 If None, no memory is used.
-            perm_mask: float32 Tensor in shape [len, len, bsz].
+            perm_mask: [optional] float32 Tensor in shape [len, len, bsz].
                 If perm_mask[i, j, k] = 0, i attend to j in batch k;
                 if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
                 If None, each position attends to all the others.
-            target_mapping: float32 Tensor in shape [num_predict, len, bsz].
+            target_mapping: [optional] float32 Tensor in shape [num_predict, len, bsz].
                 If target_mapping[i, j, k] = 1, the i-th predict in batch k is
                 on the j-th token.
                 Only used during pretraining for partial prediction.
                 Set to None during finetuning.
-            inp_q: float32 Tensor in shape [len, bsz].
+            inp_q: [optional] float32 Tensor in shape [len, bsz].
                 1 for tokens with losses and 0 for tokens without losses.
                 Only used during pretraining for two-stream attention.
                 Set to None during finetuning.
@@ -780,14 +820,16 @@ class XLNetModel(XLNetPreTrainedModel):
             summary_type: str, "last", "first", "mean", or "attn". The method
                 to pool the input to get a vector representation.
         """
-        qlen, bsz = inp_k.shape
+        qlen, bsz = word_emb_k.shape[0], word_emb_k.shape[1]
         mlen = mems[0].shape[0] if mems is not None else 0
         klen = mlen + qlen
+        dtype_float = word_emb_k.dtype
+        device = word_emb_k.device

         ##### Attention mask
         # causal attention mask
         if self.attn_type == 'uni':
-            attn_mask = _create_mask(qlen, mlen, inp_k.dtype, self.same_length)
+            attn_mask = self.create_mask(qlen, mlen)
             attn_mask = attn_mask[:, :, None, None]
         elif self.attn_type == 'bi':
             attn_mask = None
@@ -806,7 +848,7 @@ class XLNetModel(XLNetPreTrainedModel):
         if data_mask is not None:
             # all mems can be attended to
-            mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz], dtype=data_mask.dtype, device=data_mask.device)
+            mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
             data_mask = torch.cat([mems_mask, data_mask], dim=1)

         if attn_mask is None:
             attn_mask = data_mask[:, :, :, None]
@@ -814,23 +856,20 @@ class XLNetModel(XLNetPreTrainedModel):
             attn_mask += data_mask[:, :, :, None]

         if attn_mask is not None:
-            attn_mask = (attn_mask > 0).float()
+            attn_mask = (attn_mask > 0).to(dtype_float)

         if attn_mask is not None:
-            non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
-            non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float),
-                                      non_tgt_mask], axis=-1)
-            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
-                                   dtype=tf_float)
+            non_tgt_mask = -torch.eye(qlen).to(attn_mask)
+            non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
+            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
         else:
             non_tgt_mask = None

-        ##### Word embedding
-        word_emb_k = self.word_embedding(inp_k)
+        ##### Process Word embeddings and prepare h & g hidden states
         output_h = self.dropout(word_emb_k)

         if inp_q is not None:
             if target_mapping is not None:
-                word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, 1)
+                word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, -1)
             else:
                 inp_q_ext = inp_q[:, :, None]
                 word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
@@ -841,33 +880,33 @@ class XLNetModel(XLNetPreTrainedModel):

         ##### Segment embedding
         if seg_id is not None:
             # Convert `seg_id` to one-hot `seg_mat`
-            mem_pad = torch.zeros([mlen, bsz], dtype=torch.long)
+            mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
             cat_ids = torch.cat([mem_pad, seg_id], dim=0)

             # `1` indicates not in the same segment [qlen x klen x bsz]
             seg_mat = (seg_id[:, None] != cat_ids[None, :]).long()
-            # seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
+            seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
         else:
             seg_mat = None

         ##### Positional encoding
-        pos_emb = relative_positional_encoding(qlen, klen, bsz=bsz, dtype=inp_k.dtype)
+        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
         pos_emb = self.dropout(pos_emb)

         ##### Head mask if needed (for bertology/pruning)
         # 1.0 in head_mask indicates we keep the head
         # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        # input head_mask has shape [num_heads] or [n_layer x num_heads]
+        # and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
         if head_mask is not None:
             if head_mask.dim() == 1:
                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1)
+                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
             elif head_mask.dim() == 2:
                 head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
             head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
         else:
-            head_mask = [None] * self.config.num_hidden_layers
+            head_mask = [None] * self.config.n_layer

         new_mems = []
         if mems is None:
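The new `F.one_hot` call replaces the commented-out TensorFlow equivalent. A hedged sketch of the segment-matrix construction from the hunk above, with toy sizes:

import torch
import torch.nn.functional as F

mlen, qlen, bsz = 2, 3, 1
seg_id = torch.tensor([[0], [0], [1]])                   # [qlen, bsz]
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long)     # memory counts as segment 0
cat_ids = torch.cat([mem_pad, seg_id], dim=0)            # [klen, bsz]
seg_mat = (seg_id[:, None] != cat_ids[None, :]).long()   # [qlen, klen, bsz]: 1 = different segment
seg_mat = F.one_hot(seg_mat, num_classes=2).float()      # [qlen, klen, bsz, 2]
print(seg_mat.shape)  # torch.Size([3, 5, 1, 2])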
@@ -878,14 +917,14 @@ class XLNetModel(XLNetPreTrainedModel):
             new_mems.append(self.cache_mem(output_h, mems[i]))

             output_h, output_g = layer_module(output_h, output_g,
-                                              attn_mask_h, attn_mask_g,
-                                              r, seg_mat,
+                                              attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
+                                              r=pos_emb, seg_mat=seg_mat,
                                               mems=mems[i], target_mapping=target_mapping,
                                               head_mask=head_mask)

         output = self.dropout(output_g if output_g is not None else output_h)

-        return output
+        return output, new_mems

 class XLNetLMHeadModel(XLNetPreTrainedModel):
@@ -932,28 +971,27 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

         config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
-            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+            n_layer=12, num_attention_heads=12, intermediate_size=3072)

         model = modeling.XLNetModel(config=config)
         all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, run_config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(XLNetLMHeadModel, self).__init__(config)
         self.output_attentions = output_attentions
-        self.attn_type = run_config.attn_type
-        self.same_length = run_config.same_length
+        self.attn_type = config.attn_type
+        self.same_length = config.same_length

         self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
-        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, self.d_model))
-        self.transformer = XLNetModel(config,
-                                      output_attentions=output_attentions,
-                                      keep_multihead_output=keep_multihead_output)
+        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
+        self.transformer = XLNetModel(config, output_attentions=output_attentions,
+                                      keep_multihead_output=keep_multihead_output)
         self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
         self.dropout = nn.Dropout(config.dropout)

         # Tie weights
-        if config.tie_weight:
-            self.lm_loss.weight = self.word_embedding.weight
+        self.lm_loss.weight = self.word_embedding.weight

         self.apply(self.init_xlnet_weights)
@@ -972,7 +1010,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def forward(self, inp_k, seg_id=None, input_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                output_all_encoded_layers=True, head_mask=None):
+                target=None, output_all_encoded_layers=True, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
@@ -1007,13 +1045,21 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
             summary_type: str, "last", "first", "mean", or "attn". The method
                 to pool the input to get a vector representation.
         """
-        output, new_mems = self.transformer(output_h, non_tgt_mask, r, seg_mat,
-                                            output_g=output_g, attn_mask_g=attn_mask,
-                                            mems=mems, target_mapping=target_mapping,
-                                            head_mask=head_mask)
+        word_emb_k = self.word_embedding(inp_k)
+
+        output, new_mems = self.transformer(word_emb_k, seg_id, input_mask,
+                                            mems, perm_mask, target_mapping, inp_q,
+                                            output_all_encoded_layers, head_mask)

         logits = self.lm_loss(output)

+        if target is not None:
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss = loss_fct(logits.view(-1, logits.size(-1)),
+                            target.view(-1))
+            return loss, new_mems
+
         # if self.output_attentions:
         #     all_attentions, encoded_layers = encoded_layers
         # sequence_output = encoded_layers[-1]
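Putting the pieces together, a hedged usage sketch that mirrors the new unit test below (hyperparameters copied from it; other config fields are assumed to have workable defaults):

import torch
from pytorch_pretrained_bert import XLNetConfig, XLNetRunConfig, XLNetLMHeadModel

config = XLNetConfig(vocab_size_or_config_json_file=99, d_model=32, n_head=4,
                     d_inner=128, n_layer=5, untie_r=True, max_position_embeddings=10)
run_config = XLNetRunConfig(mem_len=30, clamp_len=15, same_length=False,
                            reuse_len=15, bi_data=False)
config.update(run_config)  # fold run-time fields into the model config

model = XLNetLMHeadModel(config)
model.eval()

input_ids = torch.randint(0, 99, (7, 13))  # [seq_len, batch], as in the test
target = torch.randint(0, 99, (7, 13))

logits, mems = model(input_ids)               # no target: token logits + new memory
loss, mems = model(input_ids, target=target)  # with target: LM loss + new memory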
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import (XLNetConfig, XLNetRunConfig, XLNetModel, XLNetLMHeadModel)
from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP


class XLNetModelTest(unittest.TestCase):
    class XLNetModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     mem_len=30,
                     clamp_len=15,
                     reuse_len=15,
                     is_training=True,
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
                     d_model=32,
                     n_head=4,
                     d_inner=128,
                     n_layer=5,
                     max_position_embeddings=10,
                     untie_r=True,
                     bi_data=False,
                     same_length=False,
                     seed=1,
                     type_vocab_size=2):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.d_model = d_model
            self.n_head = n_head
            self.d_inner = d_inner
            self.n_layer = n_layer
            self.max_position_embeddings = max_position_embeddings
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.same_length = same_length
            self.seed = seed
            self.type_vocab_size = type_vocab_size

        def prepare_config_and_inputs(self):
            input_ids_1 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
            input_ids_2 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
            segment_ids = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.type_vocab_size)

            lm_labels = None
            if self.use_labels:
                lm_labels = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)

            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                d_model=self.d_model,
                n_head=self.n_head,
                d_inner=self.d_inner,
                n_layer=self.n_layer,
                untie_r=self.untie_r,
                max_position_embeddings=self.max_position_embeddings)
            run_config = XLNetRunConfig(
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data)
            config.update(run_config)

            return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)

        def set_seed(self):
            random.seed(self.seed)
            torch.manual_seed(self.seed)

        def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
            model = XLNetLMHeadModel(config)
            model.eval()

            hidden_states_1, mems_1 = model(input_ids_1, seg_id=segment_ids)
            hidden_states_2, mems_2 = model(input_ids_2, seg_id=segment_ids, mems=mems_1)

            outputs = {
                "hidden_states_1": hidden_states_1,
                "mems_1": mems_1,
                "hidden_states_2": hidden_states_2,
                "mems_2": mems_2,
            }
            return outputs

        def check_transfo_xl_model_output(self, result):
            self.parent.assertListEqual(
                list(result["hidden_states_1"].size()),
                [self.seq_length, self.batch_size, self.d_model])
            self.parent.assertListEqual(
                list(result["hidden_states_2"].size()),
                [self.seq_length, self.batch_size, self.d_model])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)

        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
            model = XLNetLMHeadModel(config)
            model.eval()

            loss_1, mems_1a = model(input_ids_1, target=lm_labels)
            lm_logits_1, mems_1b = model(input_ids_1)

            loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
            lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)

            outputs = {
                "loss_1": loss_1,
                "mems_1a": mems_1a,
                "lm_logits_1": lm_logits_1,
                "mems_1b": mems_1b,
                "loss_2": loss_2,
                "mems_2a": mems_2a,
                "lm_logits_2": lm_logits_2,
                "mems_2b": mems_2b,
            }
            return outputs

        def check_transfo_xl_lm_head_output(self, result):
            self.parent.assertListEqual(
                list(result["loss_1"].size()),
                [self.seq_length, self.batch_size])
            self.parent.assertListEqual(
                list(result["lm_logits_1"].size()),
                [self.seq_length, self.batch_size, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1a"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1b"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))

            self.parent.assertListEqual(
                list(result["loss_2"].size()),
                [self.seq_length, self.batch_size])
            self.parent.assertListEqual(
                list(result["lm_logits_2"].size()),
                [self.seq_length, self.batch_size, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2a"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2b"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))

    def test_default(self):
        self.run_tester(XLNetModelTest.XLNetModelTester(self))

    def test_config_to_json_string(self):
        config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["n_token"], 96)
        self.assertEqual(obj["d_model"], 37)

    def test_config_to_json_file(self):
        config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
        json_file_path = "/tmp/config.json"
        config_first.to_json_file(json_file_path)
        config_second = XLNetConfig.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.assertEqual(config_second.to_dict(), config_first.to_dict())

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)

    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()

        tester.set_seed()
        output_result = tester.create_transfo_xl_model(*config_and_inputs)
        tester.check_transfo_xl_model_output(output_result)

        tester.set_seed()
        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
        tester.check_transfo_xl_lm_head_output(output_result)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a random int32 tensor of the shape within the vocab size."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


if __name__ == "__main__":
    unittest.main()