Commit 5ff0c605 authored by thomwolf

language update

parent 210d4072
@@ -237,17 +237,17 @@ class Attention(nn.Module):
         else:
             return x.permute(0, 2, 1, 3)
 
-    def forward(self, x, past=None):
+    def forward(self, x, layer_past=None):
         x = self.c_attn(x)
         query, key, value = x.split(self.split_size, dim=2)
         query = self.split_heads(query)
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
-        present = key, value
-        if past is not None:
-            past_key, past_value = past
+        if layer_past is not None:
+            past_key, past_value = layer_past[0], layer_past[1]
             key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)
+        present = torch.stack((key, value))
         a = self._attn(query, key, value)
         a = self.merge_heads(a)
         a = self.c_proj(a)
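The reworked Attention.forward above concatenates the cached keys and values from layer_past along the sequence dimension before computing attention, then packs the updated pair into a single stacked present tensor. The sketch below walks through that shape bookkeeping with illustrative sizes; for clarity it treats both key and value as [batch, heads, seq, head_dim] tensors, ignoring the extra transpose that split_heads(k=True) applies to keys.

```python
import torch

# Illustrative sizes (assumptions, not taken from the diff): batch 1, 12 heads, 64-dim heads.
batch, n_head, head_dim = 1, 12, 64
past_len, new_len = 5, 1

# Cached key/value from a previous forward pass, already split into heads.
past_key = torch.randn(batch, n_head, past_len, head_dim)
past_value = torch.randn(batch, n_head, past_len, head_dim)
layer_past = torch.stack((past_key, past_value))   # [2, batch, n_head, past_len, head_dim]

# Key/value computed for the newly fed token(s).
key = torch.randn(batch, n_head, new_len, head_dim)
value = torch.randn(batch, n_head, new_len, head_dim)

# Same bookkeeping as the updated forward: unpack, concatenate along the sequence axis, re-stack.
past_key, past_value = layer_past[0], layer_past[1]
key = torch.cat((past_key, key), dim=-2)            # sequence length grows to past_len + new_len
value = torch.cat((past_value, value), dim=-2)
present = torch.stack((key, value))                 # [2, batch, n_head, past_len + new_len, head_dim]

print(present.shape)  # torch.Size([2, 1, 12, 6, 64])
```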
@@ -277,8 +277,8 @@ class Block(nn.Module):
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
 
-    def forward(self, x, past=None):
-        a, present = self.attn(self.ln_1(x), past=past)
+    def forward(self, x, layer_past=None):
+        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m
@@ -346,7 +346,7 @@ class GPT2PreTrainedModel(nn.Module):
         )
         self.config = config
 
-    def set_tied():
+    def set_tied(self):
         pass
 
     def init_weights(self, module):
@@ -526,12 +526,12 @@ class GPT2Model(GPT2PreTrainedModel):
         self.apply(self.init_weights)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, pasts=None):
-        if pasts is None:
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
+        if past is None:
             past_length = 0
-            pasts = [None] * len(self.h)
+            past = [None] * len(self.h)
         else:
-            past_length = pasts[0][0].size(-2)
+            past_length = past[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
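When a cache is supplied, GPT2Model.forward above derives past_length from it so that the position ids of the new tokens continue where the cached sequence left off. A minimal sketch of that offset logic, with hypothetical tensor sizes:

```python
import torch

# Hypothetical cached state: two layers, each present shaped [2, batch, n_head, seq, head_dim].
batch, n_head, cached_len, head_dim = 1, 12, 5, 64
past = [torch.randn(2, batch, n_head, cached_len, head_dim) for _ in range(2)]

# One new token is fed on this step (arbitrary id, for illustration only).
input_ids = torch.tensor([[50256]])

# Mirrors the forward pass above: offset position ids by the cached sequence length.
past_length = past[0][0].size(-2)   # 5
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                            dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

print(position_ids)  # tensor([[5]])
```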
@@ -549,8 +549,8 @@ class GPT2Model(GPT2PreTrainedModel):
             token_type_embeds = 0
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
         presents = []
-        for block, past in zip(self.h, pasts):
-            hidden_states, present = block(hidden_states, past)
+        for block, layer_past in zip(self.h, past):
+            hidden_states, present = block(hidden_states, layer_past)
             presents.append(present)
         hidden_states = self.ln_f(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
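The loop above threads one cache entry per block: each Block returns its present, and the collected presents list is returned alongside the hidden states so the caller can feed it back as past on the next call. A stripped-down sketch of that plumbing (run_blocks and block are hypothetical stand-ins for self.h and Block.forward):

```python
def run_blocks(blocks, hidden_states, past=None):
    # With no cache, every layer starts from scratch.
    if past is None:
        past = [None] * len(blocks)
    presents = []
    # Pair each transformer block with its own cached key/value entry.
    for block, layer_past in zip(blocks, past):
        hidden_states, present = block(hidden_states, layer_past)
        presents.append(present)
    # presents has one entry per layer and becomes next call's `past`.
    return hidden_states, presents
```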
@@ -607,8 +607,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         """
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, pasts=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, pasts)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
+        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
@@ -673,8 +673,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         """
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
 
-    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, pasts=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, pasts)
+    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
+        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
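Taken together, the renamed past argument and the returned presents enable incremental decoding: feed the full prompt once, then feed only the newest token along with the cache. A rough usage sketch, assuming the repo's usual from_pretrained loading and import paths, and assuming GPT2LMHeadModel.forward returns (lm_logits, presents) when lm_labels is None, as the code above suggests:

```python
import torch
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

# Illustrative only: checkpoint name, prompt, and step count are arbitrary choices.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

generated = torch.tensor([tokenizer.encode("The quick brown fox")])
past = None
with torch.no_grad():
    for _ in range(20):
        # After the first step, only the newest token is fed; cached keys/values cover the rest.
        input_slice = generated if past is None else generated[:, -1:]
        lm_logits, past = model(input_slice, past=past)
        next_token = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat((generated, next_token), dim=-1)

print(tokenizer.decode(generated[0].tolist()))
```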