"...text-generation-inference.git" did not exist on "f5d43414c20810dfd797d64f186e35580487883d"
Unverified commit 521de6ca, authored by OlivierDehaene, committed by GitHub

fix(server): fix OPT implementation (#2061)

parent 376a0b7a
@@ -792,7 +792,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
             return_dict=return_dict,
         )
-        logits, speculative_logits = self.lm_head(outputs)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
         loss = None
@@ -85,5 +85,4 @@ class GPTNeoxSharded(CausalLM):
             use_cache=True,
         )
-        logits = outputs.logits
-        return logits, speculative_logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -75,11 +75,11 @@ class OPTSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=True,
         )
-        return outputs.logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -71,11 +71,13 @@ class RW(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ):
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
+            use_cache=True,
         )
-        return outputs.logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values
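For context, a minimal sketch of the common forward contract these CausalLM wrappers (GPTNeoxSharded, OPTSharded, RW) converge on after this commit: the wrapped model is assumed to return an (outputs, speculative_logits) pair, and the wrapper passes the speculative logits through alongside the logits and KV cache. The class name and type annotations below are illustrative assumptions, not the repository's exact code.

from typing import List, Optional, Tuple

import torch


class ShardedCausalLMSketch:
    """Hypothetical wrapper for illustration only; not the repository's class."""

    def __init__(self, model):
        # `model` is assumed to be an underlying transformer whose forward
        # returns a (ModelOutput, speculative_logits) pair.
        self.model = model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Unpack both values instead of treating the result as a single ModelOutput.
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        # Hand logits, speculative logits, and the KV cache back to the caller
        # in the same order for every wrapper.
        return outputs.logits, speculative_logits, outputs.past_key_values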