"examples/industrial_data_pretraining/campplus_sv/demo.py" did not exist on "8c2527766639f8cd7a86400a154e0a2fe44c90e6"
Commit 9b0e3a30 authored by cmx's avatar cmx
Browse files

first commit

parent fe5cd1fc
Pipeline #3450 failed with stages
in 0 seconds
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Olmo2ForCausalLM
>>> model = Olmo2ForCausalLM.from_pretrained("allenai/Olmo2-1B-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo2-1B-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)
# Compute loss
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return custom output class with token_accuracy field
return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
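# A small runnable sketch (not part of the model code) of the `logits_to_keep`
# slicing performed above; the tensor shapes here are made up for illustration.
if __name__ == "__main__":
    hidden = torch.randn(2, 8, 16)  # (batch, seq_len, hidden_size)

    # int 0 -> slice(-0, None) == slice(0, None): keep every position (the special case).
    assert hidden[:, slice(-0, None), :].shape == (2, 8, 16)

    # int 3 -> slice(-3, None): keep only the last 3 positions, e.g. for
    # generation, where only the final token's logits are needed.
    assert hidden[:, slice(-3, None), :].shape == (2, 3, 16)

    # 1D index tensor: keep explicit positions, useful for packed batches.
    idx = torch.tensor([0, 4, 7])
    assert hidden[:, idx, :].shape == (2, 3, 16)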
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Olmo3ForCausalLM
>>> model = Olmo3ForCausalLM.from_pretrained("allenai/Olmo-3-7B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Instruct")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)
# Compute loss
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return custom output class with token_accuracy field
return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
"""
Custom output classes for Liger-Kernel that extend transformers' ModelOutput classes
with optional token accuracy and predicted token fields.
"""
from dataclasses import dataclass
from typing import Optional
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
# The following model-specific outputs are optional and depend on the installed
# transformers version. Guard their imports so our module remains importable
# even when those models are not available in the environment.
try:
from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast as _Gemma3CausalLMOutputWithPast
except Exception:
_Gemma3CausalLMOutputWithPast = None
try:
from transformers.models.glm4v_moe.modeling_glm4v_moe import (
Glm4vMoeCausalLMOutputWithPast as _Glm4vMoeCausalLMOutputWithPast,
)
except Exception:
_Glm4vMoeCausalLMOutputWithPast = None
try:
from transformers.models.internvl.modeling_internvl import (
InternVLCausalLMOutputWithPast as _InternVLCausalLMOutputWithPast,
)
except Exception:
_InternVLCausalLMOutputWithPast = None
try:
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast as _LlavaCausalLMOutputWithPast
except Exception:
_LlavaCausalLMOutputWithPast = None
try:
from transformers.models.paligemma.modeling_paligemma import (
PaliGemmaCausalLMOutputWithPast as _PaliGemmaCausalLMOutputWithPast,
)
except Exception:
_PaliGemmaCausalLMOutputWithPast = None
try:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLCausalLMOutputWithPast as _Qwen2_5_VLCausalLMOutputWithPast,
)
except Exception:
_Qwen2_5_VLCausalLMOutputWithPast = None
try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLCausalLMOutputWithPast as _Qwen2VLCausalLMOutputWithPast,
)
except Exception:
_Qwen2VLCausalLMOutputWithPast = None
try:
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
Qwen3VLCausalLMOutputWithPast as _Qwen3VLCausalLMOutputWithPast,
)
except Exception:
_Qwen3VLCausalLMOutputWithPast = None
try:
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeCausalLMOutputWithPast as _Qwen3VLMoeCausalLMOutputWithPast,
)
except Exception:
_Qwen3VLMoeCausalLMOutputWithPast = None
try:
from transformers.models.qwen3_5.modeling_qwen3_5 import (
Qwen3_5CausalLMOutputWithPast as _Qwen3_5CausalLMOutputWithPast,
)
except Exception:
_Qwen3_5CausalLMOutputWithPast = None
@dataclass
class LigerCausalLMOutputWithPast(CausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
@dataclass
class LigerMoeCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _Gemma3CausalLMOutputWithPast is not None:
@dataclass
class LigerGemma3CausalLMOutputWithPast(_Gemma3CausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _Glm4vMoeCausalLMOutputWithPast is not None:
@dataclass
class LigerGlm4vMoeCausalLMOutputWithPast(_Glm4vMoeCausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _LlavaCausalLMOutputWithPast is not None:
@dataclass
class LigerLlavaCausalLMOutputWithPast(_LlavaCausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _InternVLCausalLMOutputWithPast is not None:
@dataclass
class LigerInternVLCausalLMOutputWithPast(_InternVLCausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _PaliGemmaCausalLMOutputWithPast is not None:
@dataclass
class LigerPaliGemmaCausalLMOutputWithPast(_PaliGemmaCausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _Qwen2_5_VLCausalLMOutputWithPast is not None:
@dataclass
class LigerQwen2_5_VLCausalLMOutputWithPast(_Qwen2_5_VLCausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _Qwen2VLCausalLMOutputWithPast is not None:
@dataclass
class LigerQwen2VLCausalLMOutputWithPast(_Qwen2VLCausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _Qwen3VLCausalLMOutputWithPast is not None:
@dataclass
class LigerQwen3VLCausalLMOutputWithPast(_Qwen3VLCausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _Qwen3VLMoeCausalLMOutputWithPast is not None:
@dataclass
class LigerQwen3VLMoeCausalLMOutputWithPast(_Qwen3VLMoeCausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
if _Qwen3_5CausalLMOutputWithPast is not None:
@dataclass
class LigerQwen3_5CausalLMOutputWithPast(_Qwen3_5CausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None
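# A small runnable sketch (not part of the classes above): because these
# outputs extend transformers' ModelOutput dataclasses, the extra fields
# default to None and code written against the base classes keeps working.
if __name__ == "__main__":
    out = LigerCausalLMOutputWithPast(
        loss=torch.tensor(1.23),
        logits=torch.randn(1, 4, 32),
        token_accuracy=torch.tensor(0.75),
    )
    print(out.loss, out["logits"].shape)  # base-class access works as usual
    print(out.token_accuracy)             # Liger-specific field
    print(out.predicted_tokens)           # unset extras simply stay None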
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache
from transformers.utils import is_torchdynamo_compiling
from transformers.utils import logging
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerPaliGemmaCausalLMOutputWithPast
logger = logging.get_logger(__name__)
def lce_forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**lm_kwargs,
) -> Union[Tuple, LigerPaliGemmaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
>>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach"
```"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs = self.language_model.model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
shift_labels = lm_kwargs.pop("shift_labels", None)
hidden_states = outputs[0]
loss = None
logits = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None:
raise ValueError("skip_logits is True, but labels is None")
if skip_logits is None:
skip_logits = self.training and (labels is not None)
if skip_logits:
shift_hidden_states = hidden_states[..., :-1, :]
shift_labels = labels[..., 1:]
hidden_device = shift_hidden_states.device
if attention_mask is not None:
# we use the input attention mask to shift the hidden_states and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
shift_hidden_states = shift_hidden_states.contiguous()
shift_labels = shift_labels.contiguous()
# Flatten hidden state
shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
shift_labels = shift_labels.view(-1).to(hidden_device)
# Use LigerForCausalLMLoss with accuracy support and pass already shifted labels
result = LigerForCausalLMLoss(
hidden_states=shift_hidden_states,
lm_head_weight=self.language_model.lm_head.weight,
labels=None,
shift_labels=shift_labels,
hidden_size=self.config.text_config.hidden_size,
**lm_kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.language_model.lm_head(hidden_states)
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
elif shift_labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
shift_logits = logits[..., :-1, :]
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
output = (loss,) + output if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return PaliGemma output with token_accuracy field
return LigerPaliGemmaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
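# A small runnable sketch (not part of the model code) of the mask-based shift
# used in the loss paths above, with made-up toy shapes: hidden states and
# labels are shifted by one token, then positions where the cropped attention
# mask is 0 are dropped before the loss.
if __name__ == "__main__":
    B, T, H = 2, 5, 4
    hidden = torch.randn(B, T, H)
    labels = torch.randint(0, 10, (B, T))
    attn = torch.tensor([[1, 1, 1, 1, 1],
                         [0, 0, 1, 1, 1]])          # left-padded second row

    shift_hidden = hidden[..., :-1, :]               # position t predicts token t+1
    shift_labels = labels[..., 1:]
    shift_mask = attn[:, -shift_hidden.shape[1]:]    # crop mask to shifted length
    kept_hidden = shift_hidden[shift_mask != 0]      # (num_kept, H)
    kept_labels = shift_labels[shift_mask != 0]      # (num_kept,)
    print(kept_hidden.shape, kept_labels.shape)      # -> torch.Size([7, 4]) torch.Size([7])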
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
r"""
Example:
```python
>>> from transformers import AutoTokenizer, Phi3ForCausalLM
>>> model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)
# Compute loss
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output_tuple = (logits,) + outputs[1:]
output = (loss,) + output_tuple if loss is not None else output_tuple
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return custom output class with token_accuracy field
return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
# Pixtral vision encoder does not require a custom forward function.
# The Liger kernel optimizations for Pixtral (RMSNorm, SwiGLU, RoPE) are applied
# via class/function-level monkey patching in monkey_patch.py, which is sufficient
# since the vision encoder has no cross-entropy loss to fuse.
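# For reference, a minimal sketch of that style of patching (illustrative only;
# the authoritative list of patched symbols lives in monkey_patch.py, and the
# `PixtralRMSNorm` attribute name below is an assumption about the transformers
# module layout):
#
#     from transformers.models.pixtral import modeling_pixtral
#     from liger_kernel.transformers.rms_norm import LigerRMSNorm
#
#     # Rebind the class before instantiating the model so every layer
#     # constructed afterwards uses the fused kernel.
#     modeling_pixtral.PixtralRMSNorm = LigerRMSNorm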
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
def lce_forward_deprecated(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
skip_logits: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Copy-paste of Qwen2's forward, but with torch cross entropy replaced by Liger fused linear cross entropy.
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
>>> model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-1.5B")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
loss = None
logits = None
if skip_logits and labels is None:
raise ValueError("skip_logits is True, but labels is None")
if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and labels is not None
if skip_logits:
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# flatten tokens
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
shift_labels = shift_labels.view(-1)
lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
logits = self.lm_head(hidden_states)
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)
# Compute loss
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output_tuple = (logits,) + outputs[1:]
output = (loss,) + output_tuple if loss is not None else output_tuple
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return custom output class with token accuracy field
return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
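# A small runnable sketch (not part of the model code) of the substitution both
# forwards above make: the fused loss computes CE(hidden @ W.T, labels) without
# materializing the (N, vocab_size) logits. It reuses the
# LigerFusedLinearCrossEntropyLoss imported at the top of this file and needs a
# CUDA device, since the kernel is written in Triton; the shapes are made up.
if __name__ == "__main__":
    N, H, V = 32, 64, 1000
    hidden = torch.randn(N, H, device="cuda")
    weight = torch.randn(V, H, device="cuda")  # same layout as lm_head.weight
    labels = torch.randint(0, V, (N,), device="cuda")

    fused = LigerFusedLinearCrossEntropyLoss()(weight, hidden, labels)
    reference = CrossEntropyLoss()(hidden @ weight.T, labels)
    print(fused.item(), reference.item())  # the two should agree closely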
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from packaging import version
from transformers import __version__ as transformers_version
from transformers.utils import can_return_tuple
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerQwen2_5_VLCausalLMOutputWithPast
_TRANSFORMERS_V5_OR_LATER = version.parse(transformers_version) >= version.parse("5.0.0")
def _get_hidden_size(config) -> int:
"""Get hidden_size from Qwen2.5VLConfig in a version-aware manner."""
if _TRANSFORMERS_V5_OR_LATER:
return config.text_config.hidden_size
return config.hidden_size
def _get_vocab_size(config) -> int:
"""Get vocab_size from Qwen2.5VLConfig in a version-aware manner."""
if _TRANSFORMERS_V5_OR_LATER:
return config.text_config.vocab_size
return config.vocab_size
@can_return_tuple
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
mm_token_type_ids: Optional[torch.IntTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerQwen2_5_VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)`):
The tensors corresponding to the input videos. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
[`Qwen2_5_VLImageProcessor`] for processing videos.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
mm_token_type_ids=mm_token_type_ids,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
shift_labels = kwargs.pop("shift_labels", None)
loss = None
logits = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
skip_logits = self.training and (labels is not None or shift_labels is not None)
# Compute loss
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=_get_hidden_size(self.config),
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=_get_vocab_size(self.config),
)
if not return_dict:
output_tuple = (logits,) + outputs[1:]
output = (loss,) + output_tuple if loss is not None else output_tuple
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return Qwen2.5-VL output with token accuracy
return LigerQwen2_5_VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=outputs.rope_deltas,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from packaging import version
from transformers import __version__ as transformers_version
from transformers.utils import can_return_tuple
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerQwen2VLCausalLMOutputWithPast
_TRANSFORMERS_V5_OR_LATER = version.parse(transformers_version) >= version.parse("5.0.0")
def _get_hidden_size(config) -> int:
"""Get hidden_size from Qwen2VLConfig in a version-aware manner."""
if _TRANSFORMERS_V5_OR_LATER:
return config.text_config.hidden_size
return config.hidden_size
def _get_vocab_size(config) -> int:
"""Get vocab_size from Qwen2VLConfig in a version-aware manner."""
if _TRANSFORMERS_V5_OR_LATER:
return config.text_config.vocab_size
return config.vocab_size
@can_return_tuple
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
mm_token_type_ids: Optional[torch.IntTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerQwen2VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)`):
The tensors corresponding to the input videos. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
[`Qwen2VLImageProcessor`] for processing videos.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
mm_token_type_ids=mm_token_type_ids,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
shift_labels = kwargs.pop("shift_labels", None)
loss = None
logits = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
skip_logits = self.training and (labels is not None or shift_labels is not None)
# Compute loss
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=_get_hidden_size(self.config),
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=_get_vocab_size(self.config),
)
if not return_dict:
output_tuple = (logits,) + outputs[1:]
output = (loss,) + output_tuple if loss is not None else output_tuple
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return Qwen2VL output with token accuracy
return LigerQwen2VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=outputs.rope_deltas,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
from typing import List
from typing import Optional
from typing import Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> LigerCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen3ForCausalLM
>>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
# Remove output-control parameters that shouldn't be passed to loss functions
kwargs.pop("return_dict", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)
# Compute loss
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return custom output class with accuracy field
return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
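# --- Illustrative sketch (not part of the patched forward) -------------------
# A minimal, standalone demonstration of the `logits_to_keep` slicing used
# above; the tensor shapes are toy values chosen purely for this example.
import torch
_hidden = torch.randn(2, 16, 8)  # (batch, seq_len, hidden_size)
_logits_to_keep = 4
_slice = slice(-_logits_to_keep, None) if isinstance(_logits_to_keep, int) else _logits_to_keep
assert _hidden[:, _slice, :].shape == (2, 4, 8)  # only the last 4 positions are kept
# 0 is the documented special case: slice(-0, None) == slice(0, None), keeping every position.
assert _hidden[:, slice(-0, None), :].shape == _hidden.shape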
from typing import List
from typing import Optional
from typing import Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
from liger_kernel.transformers.model.output_classes import LigerQwen3_5CausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[tuple, LigerCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen3_5ForCausalLM
>>> model = Qwen3_5ForCausalLM.from_pretrained("Qwen/Qwen3.5-9B")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-9B")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
return_dict = kwargs.pop("return_dict", None)
if return_dict is None:
return_dict = self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
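# --- Illustrative sketch (not part of the patched forward) -------------------
# How `skip_logits` is resolved when the caller leaves it as None; the helper
# simply mirrors the branch inside the forward for readability.
def _resolve_skip_logits(skip_logits, training, labels, shift_labels):
    if skip_logits is None:
        return training and (labels is not None or shift_labels is not None)
    return skip_logits
assert _resolve_skip_logits(None, True, "labels", None) is True  # training with labels: fused loss, no logits
assert _resolve_skip_logits(None, False, "labels", None) is False  # eval: logits are materialized
assert _resolve_skip_logits(None, True, None, None) is False  # generation: logits are required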
def lce_forward_for_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
mm_token_type_ids: Optional[torch.IntTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[tuple, LigerQwen3_5CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
Example:
```python
>>> from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration
>>> model = Qwen3_5ForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {
        ...                 "type": "image",
        ...                 "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
        ...             },
        ...             {"type": "text", "text": "Describe the image."},
        ...         ],
        ...     }
        ... ]
        >>> inputs = processor.apply_chat_template(
        ...     messages,
        ...     tokenize=True,
        ...     add_generation_prompt=True,
        ...     return_dict=True,
        ...     return_tensors="pt"
        ... )
>>> # Generate
>>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
>>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
>>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
>>> print(output_text)
```
"""
return_dict = kwargs.pop("return_dict", None)
if return_dict is None:
return_dict = self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
mm_token_type_ids=mm_token_type_ids,
**kwargs,
)
hidden_states = outputs[0]
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.text_config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.text_config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
return LigerQwen3_5CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=outputs.rope_deltas,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
from typing import List
from typing import Optional
from typing import Union
import torch
from transformers.modeling_outputs import MoeModelOutputWithPast
# Imported at module scope (not under TYPE_CHECKING) because it is called at
# runtime below to compute the auxiliary load-balancing loss.
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import load_balancing_loss_func
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
mm_token_type_ids: Optional[torch.IntTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[tuple, LigerMoeCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct")
>>> prompt = "Give me a short introduction to large language model."
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
mm_token_type_ids=mm_token_type_ids,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
    else:  # in inference mode, materialize logits
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # ensure aux_loss is on the same device as loss
if not return_dict:
output = (logits,) + outputs[1:]
output = ((aux_loss,) + output) if aux_loss is not None else output
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
return LigerMoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
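# --- Illustrative sketch (not part of the patched forward) -------------------
# How the auxiliary load-balancing term is folded into the training loss above.
# The values and the coefficient are made up; load_balancing_loss_func itself
# is not called here because it needs real router logits.
import torch
_loss = torch.tensor(2.31)
_aux_loss = torch.tensor(0.05)
_router_aux_loss_coef = 0.001  # stands in for self.router_aux_loss_coef
_loss = _loss + _router_aux_loss_coef * _aux_loss.to(_loss.device)  # moved to loss's device first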
from typing import List
from typing import Optional
from typing import Union
import torch
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[tuple, LigerMoeCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
>>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)
# Compute loss
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
    else:  # in inference mode, materialize logits
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # ensure aux_loss is on the same device as loss
if not return_dict:
output = (logits,) + outputs[1:]
output = ((aux_loss,) + output) if aux_loss is not None else output
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return custom output class with accuracy field
return LigerMoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
from typing import List
from typing import Optional
from typing import Union
import torch
from transformers.modeling_outputs import MoeModelOutputWithPast
# Imported at module scope (not under TYPE_CHECKING) because it is called at
# runtime below to compute the auxiliary load-balancing loss.
from transformers.models.qwen3_next.modeling_qwen3_next import load_balancing_loss_func
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[tuple, LigerMoeCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
>>> prompt = "Give me a short introduction to large language model."
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
    else:  # in inference mode, materialize logits
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # ensure aux_loss is on the same device as loss
if not return_dict:
output = (logits,) + outputs[1:]
output = ((aux_loss,) + output) if aux_loss is not None else output
output = ((loss,) + output) if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
return LigerMoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from transformers.utils import can_return_tuple
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerQwen3VLCausalLMOutputWithPast
@can_return_tuple
def lce_forward(
self,
    input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
mm_token_type_ids: Optional[torch.IntTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerQwen3VLCausalLMOutputWithPast]:
"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)`):
The tensors corresponding to the input videos. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
[`Qwen2_5_VLImageProcessor`] for processing videos.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
>>> model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL")
        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {"type": "image"},
        ...             {"type": "text", "text": "What is shown in this image?"},
        ...         ],
        ...     },
        ... ]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
mm_token_type_ids=mm_token_type_ids,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
shift_labels = kwargs.pop("shift_labels", None)
loss = None
logits = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
skip_logits = self.training and (labels is not None or shift_labels is not None)
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.text_config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.text_config.vocab_size,
                **kwargs,
            )
if not return_dict:
output = (logits,) + outputs[1:]
output = (loss,) + output if loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
return LigerQwen3VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=outputs.rope_deltas,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import load_balancing_loss_func
from transformers.utils import can_return_tuple
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerQwen3VLMoeCausalLMOutputWithPast
@can_return_tuple
def lce_forward(
self,
    input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
mm_token_type_ids: Optional[torch.IntTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerQwen3VLMoeCausalLMOutputWithPast]:
"""
Qwen3-VL-MoE forward with fused linear cross entropy support mirroring Qwen3-VL behaviour.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
mm_token_type_ids=mm_token_type_ids,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
shift_labels = kwargs.pop("shift_labels", None)
loss = None
logits = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
skip_logits = self.training and (labels is not None or shift_labels is not None)
if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.text_config.hidden_size,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.text_config.vocab_size,
                **kwargs,
            )
# Compute auxiliary load-balancing loss for MoE when requested
aux_loss = None
if kwargs.get("output_router_logits", False):
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.config.text_config.num_experts,
self.config.text_config.num_experts_per_tok,
attention_mask,
)
# If we computed training loss, add the scaled aux loss to it
if loss is not None and aux_loss is not None:
loss = loss + self.config.text_config.router_aux_loss_coef * aux_loss.to(loss.device)
if not return_dict:
output = (logits,) + outputs[1:]
output = (loss,) + output if loss is not None else output
output = output + (aux_loss,) if aux_loss is not None else output
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
return LigerQwen3VLMoeCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=outputs.rope_deltas,
aux_loss=aux_loss,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
from typing import TYPE_CHECKING
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from torch.distributed.fsdp import FullyShardedDataParallel
from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
from liger_kernel.utils import PEFT_AVAILABLE
if TYPE_CHECKING:
from transformers.cache_utils import Cache
if PEFT_AVAILABLE:
from peft.utils.other import ModulesToSaveWrapper
def lce_forward(
self,
    input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Smollm3ForCausalLM
>>> model = Smollm3ForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")
>>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None
predicted_tokens = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)
# Compute loss
if skip_logits:
result = lce_maybe_trainable_lm_head(
self,
hidden_states=kept_hidden_states,
hidden_size=self.config.hidden_size,
labels=labels,
shift_labels=shift_labels,
**kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output_tuple = (logits,) + outputs[1:]
output = (loss,) + output_tuple if loss is not None else output_tuple
output = output + (token_accuracy,) if token_accuracy is not None else output
output = output + (predicted_tokens,) if predicted_tokens is not None else output
return output
# Return custom output class with token_accuracy field
return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
lm_head = self.lm_head
# Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
# i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
# from the unwrapped module.
# See https://huggingface.co/docs/peft/package_reference/lora for reference.
if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
lm_head = lm_head.modules_to_save.default
    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
    # reading the lm_head module weights and calling the kernel must be done within the FSDP
    # forward pass, so that the module's full parameters are summoned and kept in memory
    # during the kernel execution.
if isinstance(lm_head, FullyShardedDataParallel):
return _FSDPForwardRedirection()(
lm_head,
_liger_for_causal_lm_loss,
lm_head.module,
hidden_states,
hidden_size,
labels,
shift_labels,
**loss_kwargs,
)
    # FSDP is not used, so we can read the (possibly PEFT-unwrapped) lm_head
    # weights and call the kernel directly
    return _liger_for_causal_lm_loss(
        lm_head=lm_head,
hidden_states=hidden_states,
hidden_size=hidden_size,
labels=labels,
shift_labels=shift_labels,
**loss_kwargs,
)
def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
return LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=lm_head.weight,
labels=labels,
hidden_size=hidden_size,
shift_labels=shift_labels,
**loss_kwargs,
)
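# --- Illustrative sketch (not part of the helpers above) ---------------------
# Single-process illustration of the dispatch order in lce_maybe_trainable_lm_head.
# Neither PEFT nor FSDP is needed to run this: the wrapped branches are only
# taken when those wrappers are actually present, so a plain nn.Linear falls
# through to the direct weight read.
_plain_head = torch.nn.Linear(8, 32, bias=False)
assert not isinstance(_plain_head, FullyShardedDataParallel)
# This is the tensor the final branch hands to LigerForCausalLMLoss as
# lm_head_weight: the full, unsharded (out_features, in_features) matrix.
assert _plain_head.weight.shape == (32, 8)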
from typing import TYPE_CHECKING
from typing import Optional
from typing import Union
import torch
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMCausalLMOutputWithPast
from transformers.processing_utils import Unpack
from transformers.utils.generic import can_return_tuple
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
if TYPE_CHECKING:
from transformers.cache_utils import Cache
from transformers.utils.generic import TransformersKwargs
# Forward adapted to enable fused Linear + CE without materializing logits.
# Mirrors the pattern used for other multimodal models (e.g., InternVL, LLaVA).
@can_return_tuple
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional["Cache"] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_hidden_states: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None, # Added argument for liger-kernel
**lm_kwargs: Unpack["TransformersKwargs"], # renamed from kwargs
) -> Union[tuple, SmolVLMCausalLMOutputWithPast]:
r"""
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
Mask to avoid performing attention on padding pixel indices.
image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The hidden states of the image encoder after modality projection.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> import requests
>>> import torch
>>> from PIL import Image
>>> from io import BytesIO
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> from transformers.image_utils import load_image
>>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
>>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
>>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
>>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
>>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
>>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", dtype=torch.bfloat16, device_map="auto")
>>> # Create inputs
>>> messages = [
... {
... "role": "user",
... "content": [
... {"type": "video", "path": path/to/video},
... {"type": "text", "text": "What is happening in this video?"},
... ]
... }
... ]
        >>> inputs = processor.apply_chat_template(
        ...     [messages], add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
        ... ).to(model.device)
>>> # Generate
>>> generated_ids = model.generate(**inputs, max_new_tokens=256)
>>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> print(generated_texts)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_hidden_states=image_hidden_states,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=True,
**lm_kwargs,
)
# Copied from llava.py
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = lm_kwargs.pop("shift_labels", None)
logits = None
loss = None
if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")
if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)
if skip_logits:
loss = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.text_config.hidden_size,
**lm_kwargs,
)
else:
logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, shift_labels=shift_labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
            )
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return SmolVLMCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
import inspect
import logging
from functools import partial
from types import MethodType
from typing import Callable
from typing import Optional
import transformers
from packaging import version
from transformers import PreTrainedModel
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
from liger_kernel.transformers.swiglu import LigerExperts
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
try:
import peft
PEFT_AVAILABLE = True
except ImportError:
PEFT_AVAILABLE = False
transformer_version = version.parse(transformers.__version__)
logger = logging.getLogger(__name__)
MIN_SUPPORTED_TRANSFORMERS_VERSION = version.parse("4.52.0")
if transformer_version < MIN_SUPPORTED_TRANSFORMERS_VERSION:
raise ImportError(
f"liger-kernel requires transformers >= {MIN_SUPPORTED_TRANSFORMERS_VERSION}, got {transformers.__version__}. "
"Please install an older version of liger-kernel that is compatible with your transformers version."
)
IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0")
def _bind_method_to_module(module, method_name: str, new_method: Callable):
# Binds a new method to a module instance so that self is passed as the first argument
module.__dict__[method_name] = new_method.__get__(module, module.__class__)
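# --- Illustrative sketch (not part of the patching helpers) ------------------
# The descriptor trick behind _bind_method_to_module: calling __get__ on a
# plain function yields a method bound to one specific instance, so the class
# and every other instance remain untouched.
class _ToyModule:
    def forward(self):
        return "original"
def _patched_forward(self):
    return "patched"
_m1, _m2 = _ToyModule(), _ToyModule()
_m1.__dict__["forward"] = _patched_forward.__get__(_m1, _ToyModule)
assert _m1.forward() == "patched"  # the instance-level binding wins
assert _m2.forward() == "original"  # other instances keep the class method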
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
# Check if the module is a PEFT ModulesToSaveWrapper
    # If it is, we need to patch both modules_to_save.default and original_module
if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
module.modules_to_save.default.offset = offset
module.modules_to_save.default.casting_mode = casting_mode
module.modules_to_save.default.variance_epsilon = (
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
)
module.modules_to_save.default.in_place = in_place
module.modules_to_save.default.row_mode = row_mode
module.original_module.offset = offset
module.original_module.casting_mode = casting_mode
module.original_module.variance_epsilon = (
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
)
module.original_module.in_place = in_place
module.original_module.row_mode = row_mode
_bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
_bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
_bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
_bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
_bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
else:
module.offset = offset
module.casting_mode = casting_mode
module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
module.in_place = in_place
module.row_mode = row_mode
_bind_method_to_module(module, "forward", LigerRMSNorm.forward)
_bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
_bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
def _patch_layer_norm_module(module, eps=1e-6):
# Check if the module is a PEFT ModulesToSaveWrapper
    # If it is, we need to patch both modules_to_save.default and original_module
if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
module.hidden_size = module.normalized_shape
_bind_method_to_module(module, "forward", LigerLayerNorm.forward)
_bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
module.modules_to_save.default.variance_epsilon = (
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
)
        module.modules_to_save.default.hidden_size = getattr(module, "hidden_size", None) or getattr(
            module, "normalized_shape", None
        )
module.original_module.variance_epsilon = (
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
)
module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
module, "normalized_shape", None
)
_bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
_bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
_bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
_bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
_bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
else:
module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
_bind_method_to_module(module, "forward", LigerLayerNorm.forward)
_bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
_bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
def _patch_swiglu_module(module, liger_module):
_bind_method_to_module(module, "forward", liger_module.forward)
_bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
def _patch_geglu_module(module):
_bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
_bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
def apply_liger_kernel_to_granite(
rope: bool = True,
cross_entropy: bool = True,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
    Debugging notes:
        Since LigerSwiGLUMLP works for Llama, it should also work for Granite, but currently it does not.
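    Example (illustrative sketch; checkpoint name is an example only):
    ```python
    >>> # Patch the Granite modeling code before the model is instantiated.
    >>> from liger_kernel.transformers import apply_liger_kernel_to_granite
    >>> from transformers import AutoModelForCausalLM
    >>> apply_liger_kernel_to_granite()
    >>> model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.0-2b-instruct")
    ```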
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.granite import modeling_granite
from transformers.models.granite.modeling_granite import GraniteModel
if swiglu:
modeling_granite.GraniteMLP = LigerSwiGLUMLP
if rms_norm:
modeling_granite.GraniteRMSNorm = LigerRMSNorm
if rope:
modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")
# NOTE: Granite model `GraniteForCausalLM.forward` scales logits each
# call, so we can't sidestep logit materialization. A bit more work
# would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
# for the logit output.
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)
# get the base model from the model instance
base_model: GraniteModel = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_llama(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
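    Example (illustrative sketch; checkpoint name is an example only):
    ```python
    >>> # Patch an already-instantiated model in place.
    >>> from transformers import AutoModelForCausalLM
    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
    >>> apply_liger_kernel_to_llama(model=model)  # RoPE, RMSNorm, SwiGLU, fused linear CE
    ```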
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.llama import modeling_llama
from transformers.models.llama.modeling_llama import LlamaModel
if rope:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(llama_lce_forward, model)
else:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
# get the base model from the model instance
base_model: LlamaModel = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_smollm3(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.smollm3 import modeling_smollm3
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
if rope:
modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
if swiglu:
modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(smollm3_lce_forward, model)
else:
modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
# get the base model from the model instance
base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_llava(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
model: PreTrainedModel = None,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Llava models.
    Due to the characteristics of Llava, the model instance must be passed in so that Liger-Kernel's patches
    can also be applied to the models connected to Llava. However, if a language model not supported by
    Liger-Kernel is connected to Llava, unexpected side effects may occur.
NOTE: Llava is not available in transformers<4.36.0
Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
        **kwargs: Additional keyword arguments (e.g. `rope`, `rms_norm`, `swiglu`) forwarded to the patch
            functions of the language and vision models connected to Llava, when those functions accept them.
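    Example (illustrative sketch; checkpoint name is an example only):
    ```python
    >>> # The loaded model must be passed so the connected language and vision
    >>> # models can be patched as well; unsupported kwargs are ignored with a warning.
    >>> from transformers import LlavaForConditionalGeneration
    >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    >>> apply_liger_kernel_to_llava(model=model, rms_norm=True, swiglu=True)
    ```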
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.llava import modeling_llava
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(llava_lce_forward, model)
else:
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
if model is not None:
text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
if text_liger_fn:
accept_params = inspect.signature(text_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
if remain_params:
                logger.warning(
                    f"These parameters are not supported by {text_model_name} and will be ignored: {list(remain_params)}. "
                    f"Applying the supported parameters: {list(text_kwargs.keys())}.\n"
                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                )
text_kwargs["model"] = model.model.language_model
text_liger_fn(**text_kwargs)
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
if vision_liger_fn:
accept_params = inspect.signature(vision_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
if remain_params:
                logger.warning(
                    f"These parameters are not supported by {vision_model_name} and will be ignored: {list(remain_params)}. "
                    f"Applying the supported parameters: {list(vision_kwargs.keys())}.\n"
                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                )
vision_kwargs["model"] = model.model.vision_tower
vision_liger_fn(**vision_kwargs)
elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
def apply_liger_kernel_to_llama4(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
layer_norm: bool = True,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm to the vision model. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
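    Example (illustrative sketch):
    ```python
    >>> # Patch the Llama4 classes before loading; for Llama4ForConditionalGeneration
    >>> # instances passed via `model=`, the vision LayerNorms are patched as well.
    >>> apply_liger_kernel_to_llama4()
    ```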
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.llama4 import modeling_llama4
from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
from transformers.models.llama4.modeling_llama4 import Llama4TextModel
from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
if rope:
from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
apply_liger_llama4_rope_full(modeling_llama4)
if rms_norm:
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
if swiglu:
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
if cross_entropy:
modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, Llama4ForConditionalGeneration):
language_model: Llama4ForCausalLM = model.language_model
vision_model: Llama4VisionModel = model.vision_model
text_model: Llama4TextModel = language_model.model
elif isinstance(model, Llama4ForCausalLM):
text_model = model.model
vision_model = None
elif isinstance(model, Llama4TextModel):
text_model = model
vision_model = None
else:
raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
if text_model:
if rms_norm:
_patch_rms_norm_module(text_model.norm)
for decoder_layer in text_model.layers:
if swiglu:
if decoder_layer.is_moe_layer:
_patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
else:
_patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
if vision_model:
            if layer_norm:
                _patch_layer_norm_module(vision_model.layernorm_pre)
                _patch_layer_norm_module(vision_model.layernorm_post)
for layer in vision_model.model.layers:
if layer_norm:
_patch_layer_norm_module(layer.input_layernorm)
_patch_layer_norm_module(layer.post_attention_layernorm)
def apply_liger_kernel_to_mllama(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
layer_norm: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
NOTE: MLlama is not available in transformers<4.45.0
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.mllama import modeling_mllama
from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
from transformers.models.mllama.modeling_mllama import MllamaTextModel
from transformers.models.mllama.modeling_mllama import MllamaVisionModel
from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
if rope:
modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
if layer_norm and model is None:
modeling_mllama.nn.LayerNorm = LigerLayerNorm
if rms_norm:
modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
if swiglu:
modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(mllama_lce_forward, model)
else:
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, MllamaForConditionalGeneration):
language_model: MllamaForCausalLM = model.model.language_model
vision_model: MllamaVisionModel = model.model.vision_model
if isinstance(language_model, MllamaForCausalLM):
text_model: MllamaTextModel = language_model.model
else:
text_model = language_model
elif isinstance(model, MllamaForCausalLM):
text_model = model.model
vision_model = None
elif isinstance(model, MllamaTextModel):
text_model = model
vision_model = None
else:
raise ValueError(f"Unsupported Mllama model type: {type(model)}")
if text_model:
if rms_norm:
_patch_rms_norm_module(text_model.norm)
for decoder_layer in text_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
if vision_model:
            if layer_norm:
                _patch_layer_norm_module(vision_model.layernorm_pre)
                _patch_layer_norm_module(vision_model.layernorm_post)
for layer in vision_model.transformer.layers:
if layer_norm:
_patch_layer_norm_module(layer.input_layernorm)
_patch_layer_norm_module(layer.post_attention_layernorm)
for layer in vision_model.global_transformer.layers:
if layer_norm:
_patch_layer_norm_module(layer.input_layernorm)
_patch_layer_norm_module(layer.post_attention_layernorm)
def apply_liger_kernel_to_mistral(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Mistral models
Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.mistral import modeling_mistral
from transformers.models.mistral.modeling_mistral import MistralModel
if rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(mistral_lce_forward, model)
else:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
if swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: MistralModel = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_mixtral(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
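    Example (illustrative sketch):
    ```python
    >>> # On transformers >= 5.0 the fused experts module is patched; on older
    >>> # versions each MixtralBlockSparseTop2MLP expert is patched individually.
    >>> apply_liger_kernel_to_mixtral()  # patch classes before loading the model
    ```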
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.mixtral import modeling_mixtral
from transformers.models.mixtral.modeling_mixtral import MixtralModel
if rope:
modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(mixtral_lce_forward, model)
else:
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
if swiglu:
if IS_TRANSFORMERS_V5_OR_LATER:
modeling_mixtral.MixtralExperts = LigerExperts
else:
modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: MixtralModel = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
if IS_TRANSFORMERS_V5_OR_LATER:
_patch_swiglu_module(decoder_layer.mlp.experts, LigerExperts)
else:
for expert in decoder_layer.block_sparse_moe.experts:
_patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_pixtral(
rope: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Pixtral vision models.
Note: Pixtral's vision encoder does not have a cross-entropy loss, so there is no
`fused_linear_cross_entropy` or `cross_entropy` option. The language model side of
Pixtral uses Mistral, which can be patched separately via `apply_liger_kernel_to_mistral`.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model
has already been loaded. Default is None.
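    Example (illustrative sketch):
    ```python
    >>> # Pixtral covers only the vision encoder; pair it with the Mistral patch
    >>> # when using the full VLM.
    >>> apply_liger_kernel_to_pixtral()
    >>> apply_liger_kernel_to_mistral()
    ```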
"""
from transformers.models.pixtral import modeling_pixtral
from transformers.models.pixtral.modeling_pixtral import PixtralVisionModel
if rope:
modeling_pixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_pixtral.PixtralRMSNorm = LigerRMSNorm
if swiglu:
modeling_pixtral.PixtralMLP = LigerSwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules.
if isinstance(model, PixtralVisionModel):
transformer = model.transformer
else:
raise ValueError(f"Unsupported Pixtral model type: {type(model)}")
if rms_norm:
_patch_rms_norm_module(model.ln_pre, eps=1e-5)
for layer in transformer.layers:
if swiglu:
_patch_swiglu_module(layer.feed_forward, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(layer.attention_norm, eps=1e-5)
_patch_rms_norm_module(layer.ffn_norm, eps=1e-5)
def apply_liger_kernel_to_gemma(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Gemma
(Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2` ) to make GPU go burrr.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaModel
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
_patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
if rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if geglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(gemma_lce_forward, model)
else:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: GemmaModel = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module_for_gemma(base_model.norm)
for decoder_layer in base_model.layers:
if geglu:
_patch_geglu_module(decoder_layer.mlp)
if rms_norm:
_patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
_patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_gemma2(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Gemma2
(for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.gemma2 import modeling_gemma2
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
_patch_rms_norm_module_for_gemma2 = partial(
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
)
if rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(gemma2_lce_forward, model)
else:
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
if geglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module_for_gemma2(base_model.norm)
for decoder_layer in base_model.layers:
if geglu:
_patch_geglu_module(decoder_layer.mlp)
if rms_norm:
_patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
_patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
_patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
_patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
def apply_liger_kernel_to_gemma3_text(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Gemma3
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.gemma3 import modeling_gemma3
from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
from liger_kernel.transformers.model.gemma3 import causal_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
_patch_rms_norm_module_for_gemma3 = partial(
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
)
if rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
if geglu:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
# Handle loss function
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(causal_forward, model)
else:
modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
# get the base model from the model instance
base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
if rms_norm:
_patch_rms_norm_module_for_gemma3(base_model.norm)
for decoder_layer in base_model.layers:
decoder_layer: Gemma3DecoderLayer
if geglu:
                    _patch_geglu_module(decoder_layer.mlp)
if rms_norm:
_patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
_patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
_patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
_patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
_patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
_patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
else:
            raise TypeError("The model must be Gemma3ForCausalLM or Gemma3TextModel.")
def apply_liger_kernel_to_gemma3(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
layer_norm: bool = True,
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Gemma3
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
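    Example (illustrative sketch; checkpoint name is an example only):
    ```python
    >>> # Patches the Gemma3 language model via `apply_liger_kernel_to_gemma3_text`
    >>> # and the SigLIP vision tower's LayerNorms.
    >>> from transformers import Gemma3ForConditionalGeneration
    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
    >>> apply_liger_kernel_to_gemma3(model=model)
    ```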
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.gemma3 import modeling_gemma3
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
from transformers.models.siglip import modeling_siglip
from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from liger_kernel.transformers.model.gemma3 import multimodal_forward
_patch_rms_norm_module_for_gemma3 = partial(
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
)
if layer_norm and model is None:
modeling_siglip.nn.LayerNorm = LigerLayerNorm
apply_liger_kernel_to_gemma3_text(
rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
)
if cross_entropy:
modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(multimodal_forward, model)
else:
modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, Gemma3ForConditionalGeneration):
if isinstance(model.model.vision_tower, SiglipVisionModel):
vision_tower = model.model.vision_tower
                if layer_norm:
                    _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
for layer in vision_tower.vision_model.encoder.layers:
layer: SiglipEncoderLayer
if layer_norm:
_patch_layer_norm_module(layer.layer_norm1)
_patch_layer_norm_module(layer.layer_norm2)
else:
raise TypeError("The vision tower must be SiglipVisionModel")
if rms_norm:
_patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)
apply_liger_kernel_to_gemma3_text(
rope=rope,
cross_entropy=False,
fused_linear_cross_entropy=False,
rms_norm=rms_norm,
geglu=geglu,
model=model.model.language_model,
)
else:
raise TypeError("The model must be Gemma3ForConditionalGeneration.")
def apply_liger_kernel_to_paligemma(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
layer_norm: bool = True,
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
# PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.gemma.modeling_gemma import GemmaModel
from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
from transformers.models.paligemma import modeling_paligemma
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from transformers.models.siglip import modeling_siglip
from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from liger_kernel.transformers.model.paligemma import lce_forward
# The vision_tower is a SiglipVisionModel
if layer_norm and model is None:
modeling_siglip.nn.LayerNorm = LigerLayerNorm
# SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
# The multi_modal_projector is Linear, nothing to do
# The language_model is GemmaForCausalLM or Gemma2ForCausalLM
apply_liger_kernel_to_gemma(
rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
)
apply_liger_kernel_to_gemma2(
rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
)
# Handle loss function
if cross_entropy:
modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(lce_forward, model)
else:
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if not isinstance(model, PaliGemmaForConditionalGeneration):
            raise TypeError("model must be of type PaliGemmaForConditionalGeneration")
vision_tower: SiglipVisionModel = model.model.vision_tower
        if layer_norm:
            _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
for layer in vision_tower.vision_model.encoder.layers:
layer: SiglipEncoderLayer
if layer_norm:
_patch_layer_norm_module(layer.layer_norm1)
_patch_layer_norm_module(layer.layer_norm2)
language_model = model.model.language_model
if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
apply_liger_kernel_to_gemma(
rope=rope,
cross_entropy=False,
fused_linear_cross_entropy=False,
rms_norm=rms_norm,
geglu=geglu,
model=language_model,
)
elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
apply_liger_kernel_to_gemma2(
rope=rope,
cross_entropy=False,
fused_linear_cross_entropy=False,
rms_norm=rms_norm,
geglu=geglu,
model=language_model,
)
else:
            raise TypeError(
                "The language_model of a PaliGemma model must be GemmaForCausalLM, GemmaModel, "
                "Gemma2ForCausalLM, or Gemma2Model."
            )
def apply_liger_kernel_to_qwen2(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
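    Example (illustrative sketch):
    ```python
    >>> # Use Liger's plain cross entropy instead of the fused linear variant.
    >>> apply_liger_kernel_to_qwen2(cross_entropy=True, fused_linear_cross_entropy=False)
    ```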
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
if rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(qwen2_lce_forward, model)
else:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
if swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_qwen3(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward
if rope:
modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(qwen3_lce_forward, model)
else:
modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
if swiglu:
modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_qwen3_moe(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models.
    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen3_moe import modeling_qwen3_moe
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
if rope:
modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(qwen3_lce_forward, model)
else:
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
if swiglu:
if IS_TRANSFORMERS_V5_OR_LATER:
modeling_qwen3_moe.Qwen3MoeExperts = LigerExperts
else:
modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
if IS_TRANSFORMERS_V5_OR_LATER:
_patch_swiglu_module(decoder_layer.mlp.experts, LigerExperts)
else:
for mlp_expert in decoder_layer.mlp.experts:
_patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_gpt_oss(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
NOTE: GPT-OSS is supported in transformers >= 4.55.0
NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
implementation with clamping and MXFP4 quantization.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
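    Example (illustrative sketch):
    ```python
    >>> # SwiGLU stays off by default because of GPT-OSS's clamped, MXFP4-quantized experts.
    >>> apply_liger_kernel_to_gpt_oss()  # RoPE, RMSNorm, fused linear CE
    ```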
"""
if version.parse(transformers.__version__) < version.parse("4.55.0"):
logger.warning("GPT-OSS support requires transformers >= 4.55.0")
return
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.gpt_oss import modeling_gpt_oss
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
if rope:
modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(gpt_oss_lce_forward, model)
else:
modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
# Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
# with clamping (swiglu_limit=7.0) and MXFP4 quantization
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_qwen2_vl(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
layer_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
NOTE: Qwen2-VL is not supported in transformers<4.52.4
Args:
        rope (bool): Whether to apply Liger's multimodal rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
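    Example (illustrative sketch; checkpoint name is an example only):
    ```python
    >>> # Patching a loaded VL model also covers the vision blocks' LayerNorms.
    >>> from transformers import Qwen2VLForConditionalGeneration
    >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    >>> apply_liger_kernel_to_qwen2_vl(model=model)
    ```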
"""
if transformer_version < version.parse("4.52.4"):
logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
return
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen2_vl import modeling_qwen2_vl
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
if rope:
modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
if rms_norm:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
if layer_norm and model is None:
modeling_qwen2_vl.LayerNorm = LigerLayerNorm
if cross_entropy:
modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(qwen2_vl_lce_forward, model)
else:
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
if swiglu:
modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, Qwen2VLForConditionalGeneration):
text_model: Qwen2VLTextModel = model.model.language_model
vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
elif isinstance(model, Qwen2VLModel):
text_model: Qwen2VLTextModel = model.language_model
vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
elif isinstance(model, Qwen2VLTextModel):
text_model: Qwen2VLTextModel = model
vision_model = None
else:
# Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
raise TypeError(
f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
)
# Patch Qwen2VisionTransformerPretrainedModel
if vision_model is not None:
for vision_block in vision_model.blocks:
if layer_norm:
_patch_layer_norm_module(vision_block.norm1)
_patch_layer_norm_module(vision_block.norm2)
        # Patch Qwen2VLTextModel
if text_model is not None:
if rms_norm:
_patch_rms_norm_module(text_model.norm)
for decoder_layer in text_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_qwen2_5_vl(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
    NOTE: Qwen2.5-VL is not available in transformers<4.48.2; Liger-Kernel support requires transformers>=4.52.4
Args:
        rope (bool): Whether to apply Liger's multimodal rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
if transformer_version < version.parse("4.52.4"):
logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
return
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
if rope:
modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
if rms_norm:
modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
if cross_entropy:
modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(qwen2_5_vl_lce_forward, model)
else:
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
if swiglu:
modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, Qwen2_5_VLForConditionalGeneration):
text_model: Qwen2_5_VLTextModel = model.model.language_model
vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
elif isinstance(model, Qwen2_5_VLModel):
text_model: Qwen2_5_VLTextModel = model.language_model
vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
elif isinstance(model, Qwen2_5_VLTextModel):
text_model: Qwen2_5_VLTextModel = model
vision_model = None
else:
# Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
raise TypeError(
f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
)
if vision_model is not None:
# Patch Qwen2_5_VisionTransformerPretrainedModel
for vision_block in vision_model.blocks:
if rms_norm:
_patch_rms_norm_module(vision_block.norm1)
_patch_rms_norm_module(vision_block.norm2)
if text_model is not None:
if rms_norm:
_patch_rms_norm_module(text_model.norm)
for decoder_layer in text_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_qwen3_vl(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = False,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
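Example:
A minimal sketch of pre-initialization patching; the checkpoint id is illustrative.
```python
>>> from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
>>> apply_liger_kernel_to_qwen3_vl()  # patch the transformers classes before loading
>>> from transformers import Qwen3VLForConditionalGeneration
>>> model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")  # illustrative id
```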
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen3_vl import modeling_qwen3_vl
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
if rope:
modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
if rms_norm:
modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(qwen3_vl_lce_forward, model)
else:
modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
if model is not None and rms_norm:
if isinstance(model, Qwen3VLForConditionalGeneration):
text_model: Qwen3VLTextModel = model.model.language_model
elif isinstance(model, Qwen3VLModel):
text_model: Qwen3VLTextModel = model.language_model
elif isinstance(model, Qwen3VLTextModel):
text_model = model
else:
raise TypeError(
f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
)
_patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
if text_model is not None:
_patch_qwen3_vl_rms_norm(text_model.norm)
for decoder_layer in text_model.layers:
_patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
_patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
self_attn = getattr(decoder_layer, "self_attn", None)
if self_attn is not None:
if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
_patch_qwen3_vl_rms_norm(self_attn.q_norm)
if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
_patch_qwen3_vl_rms_norm(self_attn.k_norm)
def apply_liger_kernel_to_qwen3_vl_moe(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = False,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
if rope:
modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
if rms_norm:
modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
else:
modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
if model is not None and rms_norm:
if isinstance(model, Qwen3VLMoeForConditionalGeneration):
text_model: Qwen3VLMoeTextModel = model.model.language_model
elif isinstance(model, Qwen3VLMoeModel):
text_model: Qwen3VLMoeTextModel = model.language_model
elif isinstance(model, Qwen3VLMoeTextModel):
text_model = model
else:
raise TypeError(
f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
)
_patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
if text_model is not None:
_patch_qwen3_vl_moe_rms_norm(text_model.norm)
for decoder_layer in text_model.layers:
_patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
_patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
self_attn = getattr(decoder_layer, "self_attn", None)
if self_attn is not None:
if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
_patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
_patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
def apply_liger_kernel_to_phi3(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
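Example:
A minimal sketch; the checkpoint id is illustrative.
```python
>>> from liger_kernel.transformers import apply_liger_kernel_to_phi3
>>> apply_liger_kernel_to_phi3(swiglu=False)  # keep the stock Phi3MLP, patch the rest
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")  # illustrative id
```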
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.phi3 import modeling_phi3
from transformers.models.phi3.modeling_phi3 import Phi3Model
if rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
if rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
if swiglu:
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(phi3_lce_forward, model)
else:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_olmo2(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
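Example:
A minimal sketch of post-initialization patching; the checkpoint id is illustrative.
```python
>>> from transformers import AutoModelForCausalLM
>>> from liger_kernel.transformers import apply_liger_kernel_to_olmo2
>>> model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")  # illustrative id
>>> apply_liger_kernel_to_olmo2(model=model)
```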
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.olmo2 import modeling_olmo2
from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
if rope:
modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
if swiglu:
modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(olmo2_lce_forward, model)
else:
modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
def apply_liger_kernel_to_olmo3(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.olmo3 import modeling_olmo3
from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
# Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
if rope:
modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2 # same as olmo2
if swiglu:
modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(olmo3_lce_forward, model)
else:
modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
def apply_liger_kernel_to_glm4(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
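Example:
A minimal sketch; `rope=True` raises NotImplementedError for GLM-4, so leave it at
its default. The checkpoint id is illustrative.
```python
>>> from liger_kernel.transformers import apply_liger_kernel_to_glm4
>>> apply_liger_kernel_to_glm4()  # rope defaults to False
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")  # illustrative id
```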
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.glm4 import modeling_glm4
from transformers.models.glm4.modeling_glm4 import Glm4Model
from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
if rope:
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
if rms_norm:
modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
if swiglu:
modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(glm4_lce_forward, model)
else:
modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm, in_place=False)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
def apply_liger_kernel_to_glm4v(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.glm4v import modeling_glm4v
from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
from transformers.models.glm4v.modeling_glm4v import Glm4vModel
from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
if rope:
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
if rms_norm:
modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(glm4v_lce_forward, model)
else:
modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, Glm4vForConditionalGeneration):
text_model: Glm4vTextModel = model.model.language_model
vision_model: Glm4vVisionModel = model.model.visual
elif isinstance(model, Glm4vModel):
text_model: Glm4vTextModel = model.language_model
vision_model: Glm4vVisionModel = model.visual
elif isinstance(model, Glm4vTextModel):
text_model: Glm4vTextModel = model
vision_model = None
else:
# Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
raise TypeError(
f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
)
if vision_model is not None:
for vision_block in vision_model.blocks:
if rms_norm:
_patch_rms_norm_module(vision_block.norm1)
_patch_rms_norm_module(vision_block.norm2)
if swiglu:
_patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
if text_model is not None:
if rms_norm:
_patch_rms_norm_module(text_model.norm)
for decoder_layer in text_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
_patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
def apply_liger_kernel_to_glm4v_moe(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.glm4v_moe import modeling_glm4v_moe
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
if rope:
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
if rms_norm:
modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(glm4v_moe_lce_forward, model)
else:
modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, Glm4vMoeForConditionalGeneration):
text_model: Glm4vMoeTextModel = model.model.language_model
vision_model: Glm4vMoeVisionModel = model.model.visual
Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
elif isinstance(model, Glm4vMoeModel):
text_model: Glm4vMoeTextModel = model.language_model
vision_model: Glm4vMoeVisionModel = model.visual
Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
elif isinstance(model, Glm4vMoeTextModel):
text_model: Glm4vMoeTextModel = model
vision_model = None
Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
else:
# Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
raise TypeError(
f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
)
if vision_model is not None:
if rms_norm:
_patch_rms_norm_module(vision_model.post_conv_layernorm)
_patch_rms_norm_module(vision_model.post_layernorm)
for vision_block in vision_model.blocks:
if rms_norm:
_patch_rms_norm_module(vision_block.norm1)
_patch_rms_norm_module(vision_block.norm2)
if swiglu:
_patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
if text_model is not None:
if rms_norm:
_patch_rms_norm_module(text_model.norm)
for decoder_layer in text_model.layers:
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
if swiglu:
if isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
# MoE layers: patch each routed expert and the shared experts
experts = getattr(decoder_layer.mlp, "experts", None)
if experts is not None:
for expert in experts:
_patch_swiglu_module(expert, LigerSwiGLUMLP)
if decoder_layer.mlp.shared_experts is not None:
_patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
else:
# Dense layers use a plain SwiGLU MLP; _patch_swiglu_module patches in place
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
def apply_liger_kernel_to_internvl(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
layer_norm: bool = True,
model: Optional[PreTrainedModel] = None,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
Due to the characteristics of InternVL, the model must be passed to apply Liger-Kernel's patch to other models connected to InternVL.
However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
NOTE: InternVL is not available in transformers<4.52.1
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
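Example:
A minimal sketch. Keyword arguments not listed above (e.g. `swiglu`) are forwarded
to the apply function of the connected language model; the checkpoint id is illustrative.
```python
>>> from transformers import AutoModelForImageTextToText
>>> from liger_kernel.transformers import apply_liger_kernel_to_internvl
>>> model = AutoModelForImageTextToText.from_pretrained("OpenGVLab/InternVL3-1B-hf")  # illustrative id
>>> apply_liger_kernel_to_internvl(model=model, swiglu=True)  # `swiglu` goes to the text model
```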
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
import torch.nn as torch_nn
from transformers.models.internvl import modeling_internvl
from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
from transformers.models.internvl.modeling_internvl import InternVLModel
from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
from transformers.models.internvl.modeling_internvl import InternVLVisionModel
from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNorm
if layer_norm and model is None:
modeling_internvl.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
logger.info("Apply liger cross entropy")
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
if rms_norm:
modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, InternVLForConditionalGeneration):
text_model = model.model.language_model
vision_model: InternVLVisionModel = model.model.vision_tower
elif isinstance(model, InternVLModel):
text_model = model.language_model
vision_model: InternVLVisionModel = model.vision_tower
else:
raise TypeError(
f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
)
text_model_name = model.config.text_config.model_type
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
if text_liger_fn:
accept_params = inspect.signature(text_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
if remain_params:
logger.warning(
f"The following parameters are not supported by {text_model_name} and will be ignored: {list(remain_params)}.\n"
f"Applying the supported parameters {list(text_kwargs.keys())}. "
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
)
text_kwargs["model"] = text_model
text_liger_fn(**text_kwargs)
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
# Patch vision model RMSNorm layers
if rms_norm:
for encoder_layer in vision_model.encoder.layer:
encoder_layer: InternVLVisionLayer
if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
_patch_rms_norm_module(encoder_layer.attention.q_norm)
if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
_patch_rms_norm_module(encoder_layer.attention.k_norm)
# Patch vision model LayerNorm layers
if layer_norm:
# Patch layernorm
if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
_patch_layer_norm_module(vision_model.layernorm)
# Patch encoder layers
for encoder_layer in vision_model.encoder.layer:
encoder_layer: InternVLVisionLayer
if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
_patch_layer_norm_module(encoder_layer.layernorm_before)
if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
_patch_layer_norm_module(encoder_layer.layernorm_after)
def apply_liger_kernel_to_smolvlm(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
layer_norm: bool = True,
model: Optional[PreTrainedModel] = None,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
NOTE: SmolVLM is not available in transformers<4.50.0
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.smolvlm import modeling_smolvlm
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
# Patch LayerNorm for vision model if model is not provided (pre-initialization)
if layer_norm and model is None:
modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
logger.info("Apply liger cross entropy")
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(smolvlm_lce_forward, model)
else:
modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
if rms_norm:
modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, SmolVLMForConditionalGeneration):
text_model = model.model.text_model
vision_model: SmolVLMVisionTransformer = model.model.vision_model
elif isinstance(model, SmolVLMModel):
text_model = model.text_model
vision_model: SmolVLMVisionTransformer = model.vision_model
else:
raise TypeError(
f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
)
text_model_name = model.config.text_config.model_type
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
if text_liger_fn:
accept_params = inspect.signature(text_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
if remain_params:
logger.warning(
f"The following parameters are not supported by {text_model_name} and will be ignored: {list(remain_params)}.\n"
f"Applying the supported parameters {list(text_kwargs.keys())}. "
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
)
text_kwargs["model"] = text_model
text_liger_fn(**text_kwargs)
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
# Patch vision model LayerNorm layers
if layer_norm:
# Patch post_layernorm
_patch_layer_norm_module(vision_model.post_layernorm)
# Patch encoder layers
for encoder_layer in vision_model.encoder.layers:
encoder_layer: SmolVLMEncoderLayer
_patch_layer_norm_module(encoder_layer.layer_norm1)
_patch_layer_norm_module(encoder_layer.layer_norm2)
def apply_liger_kernel_to_falcon_h1(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = False,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False. Currently has no effect for Falcon-H1.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
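Example:
A minimal sketch; the checkpoint id is illustrative.
```python
>>> from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1
>>> apply_liger_kernel_to_falcon_h1()  # swiglu defaults to False and is a no-op here
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-1.5B-Instruct")  # illustrative id
```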
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.falcon_h1 import modeling_falcon_h1
from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
if rope:
logger.info("Apply liger rotary pos emb.")
modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
logger.info("Apply liger RMSNorm")
modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
if swiglu:
logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
if cross_entropy:
logger.info("Apply liger cross entropy")
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(falcon_h1_lce_forward, model)
else:
modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
# get the base model from the model instance
base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.final_layernorm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
def apply_liger_kernel_to_qwen3_next(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3-Next models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
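Example:
A minimal sketch; `rope=True` currently raises NotImplementedError, and the
checkpoint id is illustrative.
```python
>>> from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next
>>> apply_liger_kernel_to_qwen3_next()
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")  # illustrative id
```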
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen3_next import modeling_qwen3_next
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
if rope:
# It might encounter NaN issues
# modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
if rms_norm:
modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
if isinstance(model, Qwen3NextForCausalLM):
model.forward = MethodType(qwen3_next_lce_forward, model)
else:
raise TypeError(
f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
)
else:
modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
if swiglu:
if IS_TRANSFORMERS_V5_OR_LATER:
modeling_qwen3_next.Qwen3NextExperts = LigerExperts
else:
# Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
else:
raise TypeError(
f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
)
_patch_rms_norm_module_for_qwen3_next = partial(
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
)
if rms_norm:
_patch_rms_norm_module_for_qwen3_next(base_model.norm)
for decoder_layer in base_model.layers:
if rms_norm:
_patch_rms_norm_module_for_qwen3_next(decoder_layer.input_layernorm)
_patch_rms_norm_module_for_qwen3_next(decoder_layer.post_attention_layernorm)
# Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
if swiglu:
if isinstance(decoder_layer.mlp, Qwen3NextMLP):
_patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
_patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
experts = getattr(decoder_layer.mlp, "experts", None)
if experts is not None:
if IS_TRANSFORMERS_V5_OR_LATER:
_patch_swiglu_module(experts, LigerExperts)
else:
for expert in experts:
_patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
def apply_liger_kernel_to_qwen3_5(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 dense models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
Not yet supported for Qwen3.5 due to hybrid attention (Gated DeltaNet + Gated Attention).
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
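Example:
A minimal sketch of pre-initialization patching. Both `Qwen3_5ForCausalLM` and,
when available in the installed transformers, `Qwen3_5ForConditionalGeneration`
get the patched forward.
```python
>>> from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5
>>> apply_liger_kernel_to_qwen3_5()  # call before instantiating the model
```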
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen3_5 import modeling_qwen3_5
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel
try:
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration
except ImportError:
Qwen3_5ForConditionalGeneration = None
from liger_kernel.transformers.model.qwen3_5 import lce_forward as qwen3_5_lce_forward
from liger_kernel.transformers.model.qwen3_5 import lce_forward_for_multimodal as qwen3_5_lce_forward_for_multimodal
from liger_kernel.transformers.monkey_patch import _patch_rms_norm_module
from liger_kernel.transformers.monkey_patch import _patch_swiglu_module
from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
if rope:
raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3_5 models.")
if rms_norm:
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3Next
if cross_entropy:
from transformers.loss.loss_utils import nn
from liger_kernel.transformers.cross_entropy import liger_cross_entropy
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
if isinstance(model, Qwen3_5ForCausalLM):
model.forward = MethodType(qwen3_5_lce_forward, model)
elif Qwen3_5ForConditionalGeneration is not None and isinstance(model, Qwen3_5ForConditionalGeneration):
model.forward = MethodType(qwen3_5_lce_forward_for_multimodal, model)
else:
raise TypeError(
f"fused_linear_cross_entropy is only applicable on Qwen3_5ForCausalLM or Qwen3_5ForConditionalGeneration. Got: {type(model)}"
)
else:
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward
if Qwen3_5ForConditionalGeneration is not None:
modeling_qwen3_5.Qwen3_5ForConditionalGeneration.forward = qwen3_5_lce_forward_for_multimodal
if swiglu:
modeling_qwen3_5.Qwen3_5MLP = LigerQwen3MoeSwiGLUMLP
if model is not None:
if isinstance(model, (Qwen3_5ForCausalLM, Qwen3_5TextModel)):
text_model: Qwen3_5TextModel = getattr(model, model.base_model_prefix, model)
elif Qwen3_5ForConditionalGeneration is not None and isinstance(model, Qwen3_5ForConditionalGeneration):
text_model = model.model.language_model
else:
raise TypeError(f"Unsupported qwen3_5 model type. Got: {type(model)}")
_patch_rms_norm_module_for_qwen3_5 = partial(
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
)
if rms_norm:
_patch_rms_norm_module_for_qwen3_5(text_model.norm)
for decoder_layer in text_model.layers:
if rms_norm:
_patch_rms_norm_module_for_qwen3_5(decoder_layer.input_layernorm)
_patch_rms_norm_module_for_qwen3_5(decoder_layer.post_attention_layernorm)
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
def apply_liger_kernel_to_qwen3_5_moe(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.qwen3_5_moe import modeling_qwen3_5_moe
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextModel
from liger_kernel.transformers.model.qwen3_5_moe import lce_forward as qwen3_5_moe_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
if rope:
raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3_5Moe models.")
if rms_norm:
modeling_qwen3_5_moe.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3Next
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
if isinstance(model, Qwen3_5MoeForCausalLM):
model.forward = MethodType(qwen3_5_moe_lce_forward, model)
else:
raise TypeError(
f" fused_linear_cross_entropy is only applicable on Qwen3_5MoeForCausalLM. Got: {type(model)}"
)
else:
modeling_qwen3_5_moe.Qwen3_5MoeForCausalLM.forward = qwen3_5_moe_lce_forward
if swiglu:
modeling_qwen3_5_moe.Qwen3_5MoeExperts = LigerExperts
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, (Qwen3_5MoeForCausalLM, Qwen3_5MoeTextModel)):
base_model: Qwen3_5MoeTextModel = getattr(model, model.base_model_prefix, model)
else:
raise TypeError(
f"Unsupported qwen3_5_moe model type. `model` must be `Qwen3_5MoeForCausalLM`, `Qwen3_5MoeTextModel`. Got: {type(model)}"
)
_patch_rms_norm_module_for_qwen3_5_moe = partial(
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
)
if rms_norm:
_patch_rms_norm_module_for_qwen3_5_moe(base_model.norm)
for decoder_layer in base_model.layers:
if rms_norm:
_patch_rms_norm_module_for_qwen3_5_moe(decoder_layer.input_layernorm)
_patch_rms_norm_module_for_qwen3_5_moe(decoder_layer.post_attention_layernorm)
if swiglu:
_patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
experts = getattr(decoder_layer.mlp, "experts", None)
if experts is not None:
_patch_swiglu_module(experts, LigerExperts)
def apply_liger_kernel_to_hunyuan_v1_dense(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
if rope:
modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(hunyuan_v1_lce_forward, model)
else:
modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
if swiglu:
modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_hunyuan_v1_moe(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
if rope:
modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
else:
modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
if swiglu:
if IS_TRANSFORMERS_V5_OR_LATER:
modeling_hunyuan_v1_moe.HunYuanMoEV1Experts = LigerExperts
else:
modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if swiglu:
if IS_TRANSFORMERS_V5_OR_LATER:
_patch_swiglu_module(decoder_layer.mlp.experts, LigerExperts)
else:
for mlp_expert in decoder_layer.mlp.experts:
_patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
def apply_liger_kernel_to_exaone4(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
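Example:
A minimal sketch; the checkpoint id is illustrative. Note that RMSNorm modules
are patched with `in_place=False` for EXAONE4.
```python
>>> from liger_kernel.transformers import apply_liger_kernel_to_exaone4
>>> apply_liger_kernel_to_exaone4()
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")  # illustrative id
```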
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.exaone4 import modeling_exaone4
from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
if rope:
modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
# EXAONE4 requires in_place=False to avoid gradient issues
class Exaone4LigerRMSNorm(LigerRMSNorm):
def __init__(self, hidden_size, eps=1e-6, **kwargs):
super().__init__(hidden_size, eps, **kwargs)
self.in_place = False
modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(exaone4_lce_forward, model)
else:
modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
if swiglu:
modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
# get the base model from the model instance
base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
if rms_norm:
_patch_rms_norm_module(base_model.norm, in_place=False)
for decoder_layer in base_model.layers:
if swiglu:
_bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
if rms_norm:
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
_patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
"gemma": apply_liger_kernel_to_gemma,
"gemma2": apply_liger_kernel_to_gemma2,
"gemma3_text": apply_liger_kernel_to_gemma3_text,
"gemma3": apply_liger_kernel_to_gemma3,
"glm4": apply_liger_kernel_to_glm4,
"glm4v": apply_liger_kernel_to_glm4v,
"glm4v_moe": apply_liger_kernel_to_glm4v_moe,
"gpt_oss": apply_liger_kernel_to_gpt_oss,
"internvl": apply_liger_kernel_to_internvl,
"llama": apply_liger_kernel_to_llama,
"llama4_text": apply_liger_kernel_to_llama4,
"llama4": apply_liger_kernel_to_llama4,
"llava": apply_liger_kernel_to_llava,
"granite": apply_liger_kernel_to_granite,
"mllama": apply_liger_kernel_to_mllama,
"mllama_text_model": apply_liger_kernel_to_mllama,
"mistral": apply_liger_kernel_to_mistral,
"mixtral": apply_liger_kernel_to_mixtral,
"olmo2": apply_liger_kernel_to_olmo2,
"pixtral": apply_liger_kernel_to_pixtral,
"olmo3": apply_liger_kernel_to_olmo3,
"qwen2": apply_liger_kernel_to_qwen2,
"qwen3": apply_liger_kernel_to_qwen3,
"qwen3_moe": apply_liger_kernel_to_qwen3_moe,
"qwen2_vl": apply_liger_kernel_to_qwen2_vl,
"qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
"qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
"qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
"qwen3_next": apply_liger_kernel_to_qwen3_next,
"qwen3_5": apply_liger_kernel_to_qwen3_5,
"qwen3_5_text": apply_liger_kernel_to_qwen3_5,
"qwen3_5_moe": apply_liger_kernel_to_qwen3_5_moe,
"qwen3_5_moe_text": apply_liger_kernel_to_qwen3_5_moe,
"qwen3_vl": apply_liger_kernel_to_qwen3_vl,
"qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
"qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
"qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
"smollm3": apply_liger_kernel_to_smollm3,
"phi3": apply_liger_kernel_to_phi3,
"paligemma": apply_liger_kernel_to_paligemma,
"falcon_h1": apply_liger_kernel_to_falcon_h1,
"smolvlm": apply_liger_kernel_to_smolvlm,
"hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
"hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
"exaone4": apply_liger_kernel_to_exaone4,
}
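# Dispatch sketch (illustrative, not part of the module): the keys above are the
# `model_type` strings found in a checkpoint's config.json, so selecting a patch
# function is a plain dict lookup.
def _example_registry_lookup(model_type: str = "qwen2") -> None:
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is not None:
        apply_fn()  # patch transformers code with the defaults for this architecture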
def _apply_liger_kernel(model_type: str, **kwargs) -> None:
"""
Applies Liger kernels based on the specified model type. The custom
kernels for the specified model type will be applied with the provided
keyword arguments; any arguments not provided fall back to the default configuration.
** Note: Calling _apply_liger_kernel() after model initialization
will not be able to fully patch models. This must be called before model initialization.
If the model has already been instantiated, use _apply_liger_kernel_to_instance() instead.
Args:
- model_type: the model type as defined in transformers/models/auto/modeling_auto.py
and specified in the model's config.json
- kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
"""
if not model_type:
logger.info("Model type was not provided. No Liger kernels will be applied.")
return
if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
return
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
apply_fn_signature = inspect.signature(apply_fn)
# Filter out the keyword arguments that are not supported by the apply function
applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
# Assume this is invoked pre-model initialization, so we only need to patch transformers code
apply_fn(**applicable_kwargs)
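# Illustrative sketch (example helper, not part of the module): the intended calling
# order is to patch by model type first, then construct the model so it picks up the
# patched transformers classes. The model id below is an assumption for the example.
def _example_patch_before_init():
    _apply_liger_kernel("llama", rope=True, rms_norm=True, swiglu=True)
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed model id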
def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
"""
Applies Liger kernels to the provided model instance.
Args:
- model: the model instance to apply Liger kernels to
- kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
"""
model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
if not model_type:
logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
return
if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
return
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
apply_fn_signature = inspect.signature(apply_fn)
# Filter out the keyword arguments that are not supported by the apply function
applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
logger.info(
f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
)
apply_fn(model=model, **applicable_kwargs)
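# Illustrative sketch (example helper, not part of the module): for an instantiated
# model, the model type is read from model.config.model_type, and kwargs not accepted
# by the matched apply_* function are dropped by the signature filter above.
def _example_patch_instance(model: PreTrainedModel) -> None:
    _apply_liger_kernel_to_instance(model, rms_norm=True, fused_linear_cross_entropy=True)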
import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
from liger_kernel.ops import LigerMultiTokenAttentionFunction
class LigerMultiTokenAttention(nn.Module):
r"""
Multi-Token Attention:
out = mask_{0}(conv2d(softmax(mask_{-\infty}(scores))))
Reference: https://arxiv.org/pdf/2504.00927
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
sparse: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.sparse = sparse
self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.empty(out_channels))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, scores: torch.Tensor) -> torch.Tensor:
return LigerMultiTokenAttentionFunction.apply(
scores,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.sparse,
)
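# Minimal usage sketch (example helper, not part of the module): the module consumes
# pre-softmax attention scores shaped (batch, heads, seq_len, seq_len) and mixes
# attention weights across nearby query/key positions via the conv2d. The sizes below
# are arbitrary example values; the underlying kernel is assumed to require CUDA.
def _example_multi_token_attention() -> torch.Tensor:
    batch, heads, seq_len = 2, 8, 16
    mta = LigerMultiTokenAttention(
        in_channels=heads,
        out_channels=heads,
        kernel_size=3,
        padding=1,  # "same" padding keeps the (seq_len, seq_len) score map size
    ).to("cuda")
    scores = torch.randn(batch, heads, seq_len, seq_len, device="cuda")
    return mta(scores)  # shape: (batch, heads, seq_len, seq_len)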