Unverified Commit 901e9b8e authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

Remove the else branch adding 0 to the hidden state if token_type_embeds is None. (#7977)


Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>
parent f34372a9
...@@ -607,11 +607,12 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -607,11 +607,12 @@ class GPT2Model(GPT2PreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.wte(input_ids) inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids) position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None: if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids) token_type_embeds = self.wte(token_type_ids)
else: hidden_states = hidden_states + token_type_embeds
token_type_embeds = 0
hidden_states = inputs_embeds + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states) hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment