"...git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "7daad468ec945aabfbf3f502c6c059bfc818014d"
Commit 965f172d authored by thomwolf

Output all hidden layers' states in GPT/GPT-2

parent f12007e4
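
In short: GPT2Model and OpenAIGPTModel now return the hidden states of all layers (the embedding output plus one entry per transformer block) instead of only the final layer, and the head models index into that list to preserve their previous behaviour. A minimal usage sketch of the new GPT2Model contract, assuming the pytorch_pretrained_bert package exports and the 'gpt2' pretrained shortcut as they existed at this commit (the input text is arbitrary):

    import torch
    from pytorch_pretrained_bert import GPT2Tokenizer, GPT2Model

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2Model.from_pretrained('gpt2')
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
    with torch.no_grad():
        hidden_states, presents = model(input_ids)

    # Before this commit, hidden_states was a single (batch, seq_len, n_embd)
    # tensor; it is now a list of n_layer + 1 such tensors.
    assert len(hidden_states) == model.config.n_layer + 1
    last_layer_states = hidden_states[-1]  # what the old API returned
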
@@ -720,9 +720,13 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
         hidden_states = self.drop(hidden_states)
+        output_shape = input_shape + (hidden_states.size(-1),)
         presents = []
         all_attentions = []
+        all_hidden_states = []
         for block, layer_past in zip(self.h, past):
+            all_hidden_states.append(hidden_states.view(*output_shape))
             outputs = block(hidden_states, layer_past, head_mask)
             if self.output_attentions:
                 attentions, hidden_states, present = outputs
@@ -731,10 +735,11 @@ class GPT2Model(GPT2PreTrainedModel):
                 hidden_states, present = outputs
             presents.append(present)
         hidden_states = self.ln_f(hidden_states)
-        output_shape = input_shape + (hidden_states.size(-1),)
+        all_hidden_states.append(hidden_states.view(*output_shape))
         if self.output_attentions:
-            return all_attentions, hidden_states.view(*output_shape), presents
-        return hidden_states.view(*output_shape), presents
+            return all_attentions, all_hidden_states, presents
+        return all_hidden_states, presents
@@ -802,6 +807,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             all_attentions, hidden_states, presents = transformer_output
         else:
             hidden_states, presents = transformer_output
+        hidden_states = hidden_states[-1]
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
@@ -889,6 +896,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
             all_attentions, hidden_states, presents = transformer_output
         else:
             hidden_states, presents = transformer_output
+        hidden_states = hidden_states[-1]
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
...
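
Note the ordering in GPT2Model above: all_hidden_states is appended at the top of the loop, so entry 0 is the embedding output, entry i is the input to block i, and the entry added after the loop is the last block's output passed through the final LayerNorm ln_f. The head models then take hidden_states[-1] to recover the old single-tensor behaviour; a standalone sketch of that head-side pattern follows (all sizes and tensors are made up for illustration, and the plain Linear is a stand-in for the tied lm_head):

    import torch
    import torch.nn as nn

    n_layer, n_embd, vocab_size = 12, 768, 50257
    lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # stand-in for self.lm_head

    # Pretend transformer output: n_layer + 1 per-layer hidden-state tensors.
    hidden_states = [torch.randn(2, 5, n_embd) for _ in range(n_layer + 1)]

    final_states = hidden_states[-1]   # last entry, i.e. the ln_f output
    lm_logits = lm_head(final_states)  # (batch, seq_len, vocab_size)
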
@@ -716,7 +716,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
         hidden_states = self.drop(hidden_states)
+        output_shape = input_shape + (hidden_states.size(-1),)
         all_attentions = []
+        all_hidden_states = [hidden_states.view(*output_shape)]
         for block in self.h:
             outputs = block(hidden_states, head_mask)
             if self.output_attentions:
@@ -724,10 +727,11 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
                 all_attentions.append(attentions)
             else:
                 hidden_states = outputs
-        output_shape = input_shape + (hidden_states.size(-1),)
+            all_hidden_states.append(hidden_states.view(*output_shape))
         if self.output_attentions:
-            return all_attentions, hidden_states.view(*output_shape)
-        return hidden_states.view(*output_shape)
+            return all_attentions, all_hidden_states
+        return all_hidden_states
@@ -805,6 +809,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states = hidden_states
+        hidden_states = hidden_states[-1]
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
@@ -902,6 +908,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states = hidden_states
+        hidden_states = hidden_states[-1]
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
...
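
OpenAIGPTModel builds the same kind of list but seeds it with the embedding output before the loop and appends after each block; since the original GPT has no final LayerNorm, the last entry is simply the last block's output. One convenient way to consume either model's list is to stack it into a single tensor, sketched below (the helper name is hypothetical; torch.stack requires all entries to share a shape, which the .view(*output_shape) calls guarantee):

    import torch

    def stack_all_layers(all_hidden_states):
        # all_hidden_states: list of n_layer + 1 tensors, each shaped
        # input_shape + (n_embd,), as returned by GPT/GPT-2 at this commit.
        # Result: one (n_layer + 1, *input_shape, n_embd) tensor.
        return torch.stack(all_hidden_states, dim=0)
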
@@ -115,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase):
             return outputs

         def check_gpt2_model_output(self, result):
+            self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
             self.parent.assertListEqual(
-                list(result["hidden_states"].size()),
+                list(result["hidden_states"][0].size()),
                 [self.batch_size, self.n_choices, self.seq_length, self.n_embd])
@@ -222,7 +223,10 @@ class GPT2ModelTest(unittest.TestCase):
                 else:
                     output = model(input_ids, head_mask=head_mask)

-                output = sum(t.sum() for t in output[:-1])
+                if isinstance(model, GPT2Model):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output[:-1])
                 output = output.sum()
                 output.backward()
                 multihead_outputs = (model if isinstance(model, GPT2Model) else model.transformer).get_multihead_outputs()
@@ -256,7 +260,10 @@ class GPT2ModelTest(unittest.TestCase):
                 else:
                     output = model(input_ids)

-                output = sum(t.sum() for t in output[:-1])
+                if isinstance(model, GPT2Model):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output[:-1])
                 output = output.sum()
                 output.backward()
                 multihead_outputs = transformer.get_multihead_outputs()
...
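
The updated test asserts both the list length (embedding output plus one entry per block, hence n_layer + 1) and the per-entry shape, which keeps the n_choices axis because the tester feeds (batch, n_choices, seq) inputs. A self-contained toy version of the same invariant, with all sizes hypothetical:

    import torch

    n_layer, batch_size, n_choices, seq_length, n_embd = 4, 3, 2, 7, 32

    # Shape of what GPT2Model would return for such an input at this commit.
    hidden_states = [torch.zeros(batch_size, n_choices, seq_length, n_embd)
                     for _ in range(n_layer + 1)]

    assert len(hidden_states) == n_layer + 1
    assert list(hidden_states[0].size()) == [batch_size, n_choices, seq_length, n_embd]
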
@@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
             return outputs

         def check_openai_model_output(self, result):
+            self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
             self.parent.assertListEqual(
-                list(result["hidden_states"].size()),
+                list(result["hidden_states"][0].size()),
                 [self.batch_size, self.n_choices, self.seq_length, self.n_embd])
@@ -195,7 +196,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 else:
                     output = model(input_ids, head_mask=head_mask)

-                output = sum(t.sum() for t in output[:-1])
+                if isinstance(model, OpenAIGPTModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
                 output = output.sum()
                 output.backward()
                 multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
@@ -229,7 +233,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 else:
                     output = model(input_ids)

-                output = sum(t.sum() for t in output[:-1])
+                if isinstance(model, OpenAIGPTModel):
+                    output = sum(t.sum() for t in output[0])
+                elif isinstance(output, (list, tuple)):
+                    output = sum(t.sum() for t in output)
                 output = output.sum()
                 output.backward()
                 multihead_outputs = transformer.get_multihead_outputs()
...
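
Both head-masking and head-pruning tests build a scalar from the model outputs before calling backward(); since the base models now return a list of hidden states as their first element, the tests sum over output[0] for base models and fall back to summing the output tuple otherwise. A standalone sketch of that dispatch (the helper name and toy outputs are hypothetical):

    import torch

    def scalar_from_outputs(output, is_base_model):
        # Base models now return (all_hidden_states, ...): sum the list of
        # per-layer tensors in output[0]. Head models still return a tuple
        # of tensors that can be summed directly.
        if is_base_model:
            total = sum(t.sum() for t in output[0])
        elif isinstance(output, (list, tuple)):
            total = sum(t.sum() for t in output)
        else:
            total = output
        return total.sum()

    # Toy usage: a "base model" output whose first element is a list.
    toy = ([torch.randn(2, 3, requires_grad=True) for _ in range(3)], None)
    scalar_from_outputs(toy, is_base_model=True).backward()
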