from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.modeling_outputs import BaseModelOutputWithPast

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
    r"""
    Phi3 causal-LM forward that can fuse the LM head with the cross-entropy loss.

    Runs the base decoder (``self.model``), keeps the hidden states selected by
    ``logits_to_keep``, and then either:

    * ``skip_logits`` true: computes the loss with ``LigerForCausalLMLoss``
      directly from the kept hidden states and ``self.lm_head.weight``. The
      logits tensor is never materialized, so ``logits`` in the result is
      ``None``. Token accuracy and predicted tokens may also be returned.
    * otherwise: computes ``logits = self.lm_head(kept_hidden_states)``. If
      ``labels`` or ``shift_labels`` is given, the loss comes from
      ``self.loss_function``.

    Args:
        input_ids, attention_mask, position_ids, past_key_values, inputs_embeds,
        use_cache, cache_position: Forwarded unchanged to ``self.model``.
        labels: Target token ids used for the loss.
        output_attentions, output_hidden_states: Whether the base model should
            return attentions / all hidden states. ``None`` falls back to the
            corresponding ``self.config`` value.
        return_dict: When false, a tuple is returned instead of
            ``LigerCausalLMOutputWithPast``. ``None`` falls back to
            ``self.config.use_return_dict``.
        logits_to_keep: If an int ``k``, keep only the last ``k`` positions
            (``0`` keeps all of them, since ``slice(-0, None)`` is the full
            range). Otherwise it is used directly as the index along the
            sequence dimension.
        skip_logits: Force (``True``) or disable (``False``) the fused-loss path.
            ``None`` enables it in training mode whenever labels or
            ``shift_labels`` are provided.
        **kwargs: Forwarded to ``self.model`` and to the loss function. An
            optional ``shift_labels`` entry is popped from here after the base
            model call and used as pre-shifted targets.

    Returns:
        ``LigerCausalLMOutputWithPast`` when ``return_dict`` is true. Otherwise
        a tuple ``(loss?, logits, *base_model_outputs[1:], token_accuracy?,
        predicted_tokens?)``, where entries marked ``?`` are omitted when
        ``None``.

    Raises:
        ValueError: If ``skip_logits`` is true but neither ``labels`` nor
            ``shift_labels`` is provided.

    Example (generated text is illustrative):

    ```python
    >>> from transformers import AutoTokenizer, Phi3ForCausalLM

    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Forward the resolved output flags so user requests for attentions / hidden
    # states actually reach the decoder. Always request a ModelOutput from the
    # base model because the code below uses attribute access
    # (`outputs.last_hidden_state`). The tuple form, if wanted, is assembled
    # at the end of this function.
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # Popped so it is passed to the loss explicitly rather than inside **kwargs.
    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None
    predicted_tokens = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    # Compute loss
    if skip_logits:
        # Fused path: the loss is computed from hidden states and the LM-head weight directly.
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        # Mirror the HF tuple convention: optional entries are omitted when None.
        output_tuple = (logits,) + outputs[1:]
        output = (loss,) + output_tuple if loss is not None else output_tuple
        output = output + (token_accuracy,) if token_accuracy is not None else output
        output = output + (predicted_tokens,) if predicted_tokens is not None else output
        return output

    # Return custom output class with token_accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
        predicted_tokens=predicted_tokens,
    )