Commit 9b0e3a30 authored by cmx's avatar cmx
Browse files

first commit

parent fe5cd1fc
Pipeline #3450 failed with stages
in 0 seconds
import torch
import torch.nn as nn
from liger_kernel.ops import LigerPolyNormFunction
class LigerPolyNorm(nn.Module):
    """
    PolyNorm layer wrapper for Liger kernel.

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
    where norm(u) = u / sqrt(mean(u²) + ε)

    Reference:
        https://github.com/BryceZhuo/PolyCom/

    Args:
        eps: epsilon for numerical stability (default: 1e-6)
        in_place: whether to in-place modify grad_output in backward to save memory (default: True).
            Keep True to save memory; set to False if grad_output is needed elsewhere.
    """

    def __init__(self, eps=1e-6, in_place=True):
        super().__init__()
        # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0
        self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
        self.bias = nn.Parameter(torch.tensor(1.0))
        self.variance_epsilon = eps
        self.in_place = in_place

    def forward(self, hidden_states):
        # Delegates the fused forward/backward to the Triton autograd function.
        return LigerPolyNormFunction.apply(
            hidden_states,
            self.weight,
            self.bias,
            self.variance_epsilon,
            self.in_place,
        )

    def extra_repr(self):
        # Shown in repr(module) for debugging / model printouts.
        return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}"
from liger_kernel.ops import LigerQwen2VLMRopeFunction
def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """
    Apply Multimodal Rotary Positional Embedding (M-RoPE) to query and key states.

    Args:
        q (torch.Tensor): Query tensor of shape (bsz, n_q_head, seq_len, head_dim).
        k (torch.Tensor): Key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
        cos (torch.Tensor): Cosine tensor of shape (3, bsz, seq_len, head_dim).
        sin (torch.Tensor): Sine tensor of shape (3, bsz, seq_len, head_dim).
        mrope_section (List[int]): Channel-dimension split sizes for the temporal,
            height and width components of the multimodal rope.
        unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after M-RoPE.
    """
    rope_args = (q, k, cos, sin, mrope_section, unsqueeze_dim)
    return LigerQwen2VLMRopeFunction.apply(*rope_args)
import torch
import torch.nn as nn
from liger_kernel.ops import LigerRMSNormFunction
class LigerRMSNorm(nn.Module):
    """RMSNorm layer backed by the fused Liger Triton kernel.

    Supports optional offset, per-model casting modes, optional in-place
    backward, and disabling the learnable scale via ``elementwise_affine``.
    """

    def __init__(
        self,
        hidden_size,
        eps=1e-6,
        offset=0.0,
        casting_mode="llama",
        init_fn="ones",
        in_place=True,
        row_mode=None,
        elementwise_affine=True,
    ):
        super().__init__()
        assert init_fn in [
            "ones",
            "zeros",
        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            # Gemma-style models initialize the scale at zero (weight is applied as 1 + w).
            initializer = torch.ones if init_fn == "ones" else torch.zeros
            self.weight = nn.Parameter(initializer(hidden_size))
        else:
            # Keep the attribute present (as None) so forward/extra_repr stay uniform.
            self.register_parameter("weight", None)
        self.variance_epsilon = eps
        self.offset = offset
        self.casting_mode = casting_mode
        self.in_place = in_place
        self.row_mode = row_mode

    def forward(self, hidden_states):
        # Fused normalization + scaling in a single Triton kernel launch.
        return LigerRMSNormFunction.apply(
            hidden_states,
            self.weight,
            self.variance_epsilon,
            self.offset,
            self.casting_mode,
            self.in_place,
            self.row_mode,
        )

    def extra_repr(self):
        shape = tuple(self.weight.shape) if self.weight is not None else None
        return (
            f"weight_shape={shape}, eps={self.variance_epsilon}, "
            f"offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
        )
class LigerRMSNormForGemma(LigerRMSNorm):
    """Gemma preset: unit offset, 'gemma' casting mode, zero-initialized weight."""

    def __init__(
        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
    ):
        super().__init__(
            hidden_size,
            eps=eps,
            offset=offset,
            casting_mode=casting_mode,
            init_fn=init_fn,
            in_place=in_place,
            row_mode=row_mode,
        )
class LigerRMSNormForGemma2(LigerRMSNorm):
    """Gemma2 preset: like Gemma but with in_place backward disabled."""

    def __init__(
        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
    ):
        super().__init__(
            hidden_size,
            eps=eps,
            offset=offset,
            casting_mode=casting_mode,
            init_fn=init_fn,
            in_place=in_place,
            row_mode=row_mode,
        )
class LigerRMSNormForGemma3(LigerRMSNorm):
    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""

    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
        super().__init__(
            dim,
            eps=eps,
            offset=offset,
            casting_mode=casting_mode,
            init_fn=init_fn,
            in_place=in_place,
        )
class LigerRMSNormForOlmo2(LigerRMSNorm):
    """Olmo2 preset: llama casting, ones-initialized weight, out-of-place backward."""

    def __init__(
        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
    ):
        super().__init__(
            hidden_size,
            eps=eps,
            offset=offset,
            casting_mode=casting_mode,
            init_fn=init_fn,
            in_place=in_place,
            row_mode=row_mode,
        )
class LigerRMSNormForGlm4(LigerRMSNorm):
    """Glm4 preset: llama casting, ones-initialized weight, out-of-place backward."""

    def __init__(
        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
    ):
        super().__init__(
            hidden_size,
            eps=eps,
            offset=offset,
            casting_mode=casting_mode,
            init_fn=init_fn,
            in_place=in_place,
            row_mode=row_mode,
        )
class LigerRMSNormForQwen3Next(LigerRMSNorm):
    """Qwen3-Next preset: gemma casting with unit offset, out-of-place backward."""

    def __init__(
        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
    ):
        super().__init__(
            hidden_size,
            eps=eps,
            offset=offset,
            casting_mode=casting_mode,
            init_fn=init_fn,
            in_place=in_place,
            row_mode=row_mode,
        )
from typing import Tuple
import torch
from liger_kernel.ops import LigerRopeFunction
def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply Rotary Positional Embedding (RoPE) to query and key states.

    Args:
        q (torch.Tensor): Query tensor of shape (bsz, n_q_head, seq_len, head_dim).
        k (torch.Tensor): Key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
        cos (torch.Tensor): Cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
        sin (torch.Tensor): Sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
        position_ids (torch.Tensor, optional): Position ids tensor. Defaults to None.
        unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after RoPE.
    """
    rope_args = (q, k, cos, sin, position_ids, unsqueeze_dim)
    return LigerRopeFunction.apply(*rope_args)
def liger_rotary_pos_emb_vision(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Modified version of liger_rotary_pos_emb for qwen3_vl's apply_rotary_pos_emb_vision function.
    Manually transposed the input and output to match the expected shape for liger_rotary_pos_emb.
    Reference: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L116

    Args:
        q (torch.Tensor): The query tensor of shape (seq_length, num_heads, head_dim),
            with stride (num_heads * head_dim, head_dim, 1).
        k (torch.Tensor): The key tensor of shape (seq_length, num_heads, head_dim),
            with stride (num_heads * head_dim, head_dim, 1). Same layout as q.
        cos (torch.Tensor): The cosine tensor of shape (seq_length, head_dim).
        sin (torch.Tensor): The sine tensor of shape (seq_length, head_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors with the same shape and stride as inputs.
    """
    # Remember input dtypes so outputs can be cast back after the float32 compute.
    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
    # Transpose to (1, num_heads, seq_length, head_dim) and cast to float32 to match
    # liger_rotary_pos_emb's expected input shape; also unsqueeze a batch dim.
    q32 = q.to(torch.float32).unsqueeze(0).transpose(1, 2)
    k32 = k.to(torch.float32).unsqueeze(0).transpose(1, 2)
    cos32 = cos.to(torch.float32)
    sin32 = sin.to(torch.float32)
    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32)
    # Transpose back to (seq_length, num_heads, head_dim), cast back to the original
    # dtypes, and squeeze out the batch dim so outputs mirror the inputs.
    q_out = q_out.transpose(1, 2).squeeze(0).to(orig_q_dtype)
    k_out = k_out.transpose(1, 2).squeeze(0).to(orig_k_dtype)
    return q_out, k_out
import torch
import torch.nn as nn
from liger_kernel.ops import LigerSoftmaxFunction
class LigerSoftmax(nn.Module):
    """Softmax module that dispatches to the fused Liger softmax kernel."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        """Return softmax of ``x`` computed by LigerSoftmaxFunction."""
        return LigerSoftmaxFunction.apply(x)
import torch
import torch.nn as nn
from liger_kernel.ops import LigerSparsemaxFunction
class LigerSparsemax(nn.Module):
    """Sparsemax activation backed by LigerSparsemaxFunction.

    Args:
        dim: dimension along which sparsemax is computed (default: -1).
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return LigerSparsemaxFunction.apply(x, self.dim)

    def extra_repr(self) -> str:
        # Shown in repr(module) for debugging.
        return f"dim={self.dim}"
import torch
import torch.nn as nn
from liger_kernel.ops import LigerSiLUMulFunction
class LigerSwiGLUMLP(nn.Module):
    """SwiGLU MLP (gate/up/down projections) using the fused LigerSiLUMulFunction.

    Args:
        config: model config providing ``hidden_size``, ``intermediate_size``
            and ``hidden_act`` (must be "silu" or "swish").

    Raises:
        ValueError: if ``config.hidden_act`` is not a SiLU/Swish variant.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Fail fast on an unsupported activation BEFORE allocating the three
        # Linear weights (the original validated after allocation, wasting memory
        # on a config that is about to be rejected).
        if config.hidden_act not in ["silu", "swish"]:
            raise ValueError(f"Activation function {config.hidden_act} not supported.")
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        # down( silu(gate(x)) * up(x) ) with the silu*mul fused in one kernel.
        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
class LigerBlockSparseTop2MLP(nn.Module):
    """Mixtral-style expert MLP (w1/w3 gate-up, w2 down) with fused SiLU-mul."""

    def __init__(self, config):
        super().__init__()
        ffn_dim = config.intermediate_size
        hidden_dim = config.hidden_size
        self.ffn_dim = ffn_dim
        self.hidden_dim = hidden_dim
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        if config.hidden_act not in ["silu", "swish"]:
            raise ValueError(f"Activation function {config.hidden_act} not supported.")

    def forward(self, x):
        fused = LigerSiLUMulFunction.apply(self.w1(x), self.w3(x))
        return self.w2(fused)
class LigerExperts(nn.Module):
    """
    Patch MixtralExperts for transformers v5 or later to use LigerSiLUMulFunction
    https://github.com/huggingface/transformers/blob/393b4b3d28e29b4b05b19b4b7f3242a7fc893637/src/transformers/models/mixtral/modeling_mixtral.py#L63
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): `"num_experts" in config` relies on the config object
        # implementing __contains__ — confirm this holds for every config class
        # routed through here.
        if "num_experts" in config:
            # qwen3_moe, qwen3_next uses num_experts
            self.num_experts = config.num_experts
        else:
            self.num_experts = config.num_local_experts
        if "moe_intermediate_size" in config:
            # qwen3_moe, qwen3_next uses moe_intermediate_size
            self.intermediate_dim = config.moe_intermediate_size
        else:
            self.intermediate_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size
        # Packed per-expert weights: gate and up projections are fused along dim 1.
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        if config.hidden_act not in ["silu", "swish"]:
            raise ValueError(f"Activation function {config.hidden_act} not supported.")

    def forward(self, hidden_states, top_k_index, top_k_weights):
        """Route tokens to their selected experts and accumulate weighted outputs.

        Args:
            hidden_states: flattened token activations; indexed by token along dim 0.
            top_k_index: per-token selected expert indices.
            top_k_weights: per-token routing weights aligned with top_k_index.

        Returns:
            Tensor shaped like ``hidden_states`` with expert outputs accumulated.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # One-hot over experts, permuted to [num_experts, top_k, num_tokens]
            # so expert_mask[e] locates every (slot, token) routed to expert e.
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Only iterate over experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # NOTE(review): with num_classes=self.num_experts, one-hot indices are
            # < num_experts, so this guard looks unreachable — presumably kept to
            # mirror the upstream implementation; confirm before removing.
            if expert_idx == self.num_experts:
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            # Fused gate/up projection, then split into the two halves.
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = LigerSiLUMulFunction.apply(gate, up)
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            # Scale each token's output by its routing weight for this expert slot.
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            # Accumulate into the output (a token may hit multiple experts).
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
        return final_hidden_states
class LigerPhi3SwiGLUMLP(nn.Module):
    """
    Patch Phi3MLP to use LigerSiLUMulFunction
    https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        intermediate = config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = intermediate
        # Phi3 fuses gate and up into a single projection of width 2 * intermediate.
        self.gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
        if config.hidden_act not in ["silu", "swish"]:
            raise ValueError(f"Activation function {config.hidden_act} not supported.")

    def forward(self, x):
        projected = self.gate_up_proj(x)
        gate, up_states = projected.chunk(2, dim=-1)
        return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
class LigerQwen3MoeSwiGLUMLP(nn.Module):
    """
    Patch Qwen3MoeMLP to use LigerSiLUMulFunction.
    https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # Allow per-expert override of the intermediate width; fall back to config.
        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        if config.hidden_act not in ["silu", "swish"]:
            raise ValueError(f"Activation function {config.hidden_act} not supported.")

    def forward(self, x):
        fused = LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
        return self.down_proj(fused)
class LigerHunyuanV1SwiGLUMLP(nn.Module):
    """HunyuanV1 SwiGLU MLP using the fused LigerSiLUMulFunction."""

    def __init__(self, config, layer_idx=None, is_shared_mlp=False):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        intermediate = config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = intermediate
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
        self.layer_idx = layer_idx
        if config.hidden_act not in ["silu", "swish"]:
            raise ValueError(f"Activation function {config.hidden_act} not supported.")

    def forward(self, x):
        fused = LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
        return self.down_proj(fused)
from typing import Optional
import torch.nn as nn
from liger_kernel.ops import LigerGELUMulFunction
from liger_kernel.ops import LigerSiLUMulFunction
from liger_kernel.ops import apply_tiled_mlp
class LigerTiledGEGLUMLP(nn.Module):
    """
    Memory-efficient GEGLU MLP using tiled computation.

    Combines the GEGLU activation with tiled (sharded) processing so very long
    sequences can be handled; the forward pass is recomputed during backward to
    save memory.

    Args:
        config: Model configuration with hidden_size and intermediate_size attributes
        num_shards: Number of shards to split the sequence. If None, automatically
            calculated as ceil(seqlen / hidden_size)
    """

    def __init__(self, config, num_shards: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_shards = num_shards
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Reject non-GELU activations when the config declares one.
        if hasattr(config, "hidden_act") and config.hidden_act not in [
            "gelu",
            "gelu_new",
            "gelu_pytorch_tanh",
        ]:
            raise ValueError(f"LigerTiledGEGLUMLP requires GELU activation, got {config.hidden_act}")

    def _mlp_forward(self, module, x):
        """Plain (untiled) GEGLU MLP applied to one shard."""
        return module.down_proj(LigerGELUMulFunction.apply(module.gate_proj(x), module.up_proj(x)))

    def forward(self, x):
        """
        Forward pass with tiled computation.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size]
                or [seq_len, hidden_size]

        Returns:
            Output tensor of the same shape as input
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        return apply_tiled_mlp(
            fn=self._mlp_forward,
            mlp_module=self,
            x=x,
            num_shards=self.num_shards,
            compute_params=trainable,
        )
class LigerTiledSwiGLUMLP(nn.Module):
    """
    Memory-efficient SwiGLU MLP using tiled computation.

    Combines the SwiGLU activation with tiled (sharded) processing so very long
    sequences can be handled; the forward pass is recomputed during backward to
    save memory.

    Args:
        config: Model configuration with hidden_size and intermediate_size attributes
        num_shards: Number of shards to split the sequence. If None, automatically
            calculated as ceil(seqlen / hidden_size)
    """

    def __init__(self, config, num_shards: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_shards = num_shards
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Reject non-SiLU/Swish activations when the config declares one.
        if hasattr(config, "hidden_act") and config.hidden_act not in ["silu", "swish"]:
            raise ValueError(f"LigerTiledSwiGLUMLP requires SiLU/Swish activation, got {config.hidden_act}")

    def _mlp_forward(self, module, x):
        """Plain (untiled) SwiGLU MLP applied to one shard."""
        return module.down_proj(LigerSiLUMulFunction.apply(module.gate_proj(x), module.up_proj(x)))

    def forward(self, x):
        """
        Forward pass with tiled computation.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size]
                or [seq_len, hidden_size]

        Returns:
            Output tensor of the same shape as input
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        return apply_tiled_mlp(
            fn=self._mlp_forward,
            mlp_module=self,
            x=x,
            num_shards=self.num_shards,
            compute_params=trainable,
        )
try:
    from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer  # noqa: F401
except ImportError as err:
    # Chain the original failure as __cause__ so import errors other than a
    # missing `trl` (e.g. a broken trl install) remain diagnosable.
    raise ImportError("Please `pip install trl` to use LigerORPOTrainer") from err
from typing import Dict
from typing import List
from typing import Literal
from typing import Tuple
from typing import Union
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel
from trl.trainer import ORPOTrainer
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
class LigerORPOTrainer(ORPOTrainer):
    """ORPOTrainer variant that computes the ORPO loss with Liger's fused
    linear+ORPO kernel, skipping the full-logits materialization."""

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.
        """
        # Stack chosen and rejected sequences along the batch dimension (trl helper).
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        model_kwargs = (
            {
                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
            }
            if self.is_encoder_decoder
            else {}
        )
        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True
        if self.is_encoder_decoder:
            labels = concatenated_batch["concatenated_labels"].clone()
        else:
            # Decoder-only: next-token labels come from the inputs themselves.
            labels = concatenated_batch["concatenated_input_ids"].clone()
            attention_mask = concatenated_batch["concatenated_attention_mask"]
            # Mask out padding positions so they are ignored by the loss.
            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
        if isinstance(model, FullyShardedDataParallel):
            # Call the inner base model (not lm_head) through the FSDP wrapper so
            # pre/post-forward hooks (param gather/release) still run.
            outputs = _FSDPForwardRedirection()(
                model,
                model._fsdp_wrapped_module.model,
                concatenated_batch["concatenated_input_ids"],
                attention_mask=concatenated_batch["concatenated_attention_mask"],
                use_cache=False,
                **model_kwargs,
            )
        else:
            if isinstance(model, torch.nn.DataParallel):
                model = model.module
            # Run the backbone only; logits are computed inside the fused loss below.
            outputs = model.model(
                concatenated_batch["concatenated_input_ids"],
                attention_mask=concatenated_batch["concatenated_attention_mask"],
                use_cache=False,
                **model_kwargs,
            )
        orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)

        def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
            # Fused lm_head matmul + ORPO loss over hidden states.
            return orpo_loss_fn(
                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
            )

        # NOTE(review): the redirection wrapper is used here even when `model` is
        # not FSDP-wrapped — presumably it degenerates to a plain call in that
        # case; confirm against _FSDPForwardRedirection.
        # Decoder-only models shift: hidden states drop the last position,
        # labels drop the first, to align prediction t -> token t+1.
        orpo_loss, aux_outputs = _FSDPForwardRedirection()(
            model,
            orpo_partial,
            model.lm_head,
            outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
            concatenated_batch["concatenated_labels"][:, 1:]
            if not self.is_encoder_decoder
            else concatenated_batch["concatenated_labels"],
            labels[:, 1:] if not self.is_encoder_decoder else labels,
        )
        # if aux_loss_enabled, add the aux_loss to the orpo_loss
        if self.aux_loss_enabled:
            orpo_loss += self.aux_loss_coef * outputs.aux_loss
        return orpo_loss, aux_outputs

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}
        loss, aux_outputs = self.concatenated_forward(model, batch)
        # aux_outputs layout (from LigerFusedLinearORPOLoss): first five entries
        # are log-probs/logits/nll, the remaining four are reward statistics.
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
        ) = aux_outputs[:5]
        # return loss, metrics
        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[5:]
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
        # Convert tensors to Python scalars for logging.
        for k, v in metrics.items():
            metrics[k] = v.item()
        return loss, metrics
# To not break HF Trainer integration
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
import torch.nn as nn
from liger_kernel.ops import LigerTVDLossFunction
class LigerTVDLoss(nn.Module):
    """Total-variation-distance loss wrapper around LigerTVDLossFunction.

    Args:
        reduction: reduction mode forwarded to the kernel (default: "batchmean").
        ignore_index: label value excluded from the loss (default: -100).
    """

    def __init__(self, reduction="batchmean", ignore_index: int = -100):
        super().__init__()
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, p, q, shift_labels=None):
        return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index)
from liger_kernel.triton.monkey_patch import apply_liger_triton_cache_manager # noqa: F401
import os
import random
from triton.runtime.cache import FileCacheManager
class LigerTritonFileCacheManager(FileCacheManager):
    """FileCacheManager whose `put` writes via a temp directory + atomic rename,
    avoiding partially-written cache files seen by concurrent compilers."""

    def put(self, data, filename, binary=True) -> str:
        """Write `data` to the cache under `filename` and return the final path."""
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        # NOTE(review): the `binary` parameter is deliberately clobbered — the
        # mode is derived from the payload type, mirroring upstream triton.
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = random.randint(0, 1000000)
        # we use the PID incase a bunch of these around so we can see what PID made it
        pid = os.getpid()
        # use temp dir to be robust against program interruptions
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)
        mode = "wb" if binary else "w"
        with open(temp_path, mode) as f:
            f.write(data)
        # Replace is guaranteed to be atomic on POSIX systems if it succeeds
        # so filepath cannot see a partial write
        os.replace(temp_path, filepath)
        # NOTE(review): removedirs also prunes empty parents; cache_dir itself is
        # non-empty (it now holds filepath), so only temp_dir should be removed.
        os.removedirs(temp_dir)
        return filepath
def apply_liger_triton_cache_manager():
    """
    Experimental feature to get around transient FileNotFoundError in triton compilation.
    For more details please see https://github.com/triton-lang/triton/pull/4295
    """
    # Triton reads this env var to locate a custom cache-manager class.
    manager_path = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
    os.environ["TRITON_CACHE_MANAGER"] = manager_path
# Record at import time whether the optional `peft` package is installed.
try:
    import peft  # noqa: F401
except ImportError:
    PEFT_AVAILABLE = False
else:
    PEFT_AVAILABLE = True
import torch
def is_peft_available():
    """Return True when the optional `peft` package was importable at module load."""
    return PEFT_AVAILABLE
def infer_comm_backend():
    """
    Get communication backend name based on the environment.

    Returns:
        str: one of "nccl", "ascend", "xccl", "mpi", or "gloo".

    Raises:
        RuntimeError: if no distributed backend is available.
    """
    if torch.distributed.is_nccl_available():
        # Works for Nvidia
        # TODO: nccl may not work for AMD devices that may require use of rccl.
        return "nccl"
    elif is_npu_available():
        # Use Ascend NPU if available (torch.npu)
        # Ascend is not standard torch backend and requires extension.
        # Assume that it is installed if NPUs are being used in
        # multi device environment.
        return "ascend"
    # XPU (Intel) if available
    elif torch.distributed.distributed_c10d.is_xccl_available():
        return "xccl"
    elif torch.distributed.is_mpi_available():
        # CPU backend, first option
        return "mpi"
    elif torch.distributed.is_gloo_available():
        # CPU backend, backup option
        return "gloo"
    else:
        raise RuntimeError("There is no distributed backend available.")
def infer_device():
    """
    Get current device name based on available devices

    Checks accelerators in priority order (cuda, npu, xpu) and falls back to cpu.
    """
    if torch.cuda.is_available():  # Works for both Nvidia and AMD
        return "cuda"
    # Ascend NPU (torch.npu) comes next if its extension is installed.
    if is_npu_available():
        return "npu"
    # Intel XPU.
    if torch.xpu.is_available():
        return "xpu"
    return "cpu"
def is_npu_available() -> bool:
    """Detect Ascend NPU availability."""
    # transformers may be absent, or its probe may fail on exotic setups —
    # treat any failure as "no NPU".
    try:
        from transformers.utils import is_torch_npu_available

        return is_torch_npu_available()
    except Exception:
        return False
def transformers_version_dispatch(
    required_version: str,
    before_fn,
    after_fn,
    before_args: tuple = (),
    after_args: tuple = (),
    before_kwargs: "dict | None" = None,
    after_kwargs: "dict | None" = None,
):
    """
    Dispatches to different functions based on package version comparison.

    Args:
        required_version: Version to compare against (e.g. "4.48.0")
        before_fn: Function to call if package_version < required_version
        after_fn: Function to call if package_version >= required_version
        before_args: Positional arguments for before_fn
        after_args: Positional arguments for after_fn
        before_kwargs: Keyword arguments for before_fn
        after_kwargs: Keyword arguments for after_fn

    Returns:
        Result from either before_fn or after_fn

    Example:
        >>> rotary_emb = transformers_version_dispatch(
        ...     "4.48.0",
        ...     LlamaRotaryEmbedding,
        ...     LlamaRotaryEmbedding,
        ...     before_args=(head_dim,),
        ...     after_args=(LlamaConfig(head_dim=head_dim),),
        ...     before_kwargs={'device': device},
        ...     after_kwargs={'device': device}
        ... )
    """
    # Imported lazily so merely importing this module does not require
    # packaging/transformers.
    from packaging import version
    from transformers import __version__ as transformers_version

    # Normalize the None defaults (avoids mutable default arguments).
    before_kwargs = before_kwargs or {}
    after_kwargs = after_kwargs or {}
    if version.parse(transformers_version) < version.parse(required_version):
        return before_fn(*before_args, **before_kwargs)
    else:
        return after_fn(*after_args, **after_kwargs)
def get_total_gpu_memory() -> int:
    """Returns total GPU memory in GBs.

    Raises:
        RuntimeError: if the inferred device has no queryable memory (e.g. cpu).
    """
    device = infer_device()
    if device == "cuda":
        props = torch.cuda.get_device_properties(0)
    elif device == "xpu":
        props = torch.xpu.get_device_properties(0)
    elif device == "npu":
        props = torch.npu.get_device_properties(0)
    else:
        raise RuntimeError(f"Unsupported device: {device}")
    return props.total_memory // (1024**3)
import pytest
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss
from liger_kernel.chunked_loss.functional import liger_fused_linear_cosine
from liger_kernel.utils import infer_device
from test.utils import HFDistillationLoss
from test.utils import assert_verbose_allclose
from test.utils import set_seed
# Pick the first available accelerator ("cuda"/"npu"/"xpu"/"cpu") for all tests below.
device = infer_device()
# Seed RNGs so tensor fixtures are reproducible across runs.
set_seed()
class HFCosineLoss(HFDistillationLoss):
    """
    Reference implementation of a distillation loss using cosine similarity.
    """

    def __init__(
        self,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
        super().__init__(
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            ignore_index=ignore_index,
            temperature=temperature,
        )

    def distillation_loss(self, student_logits, teacher_logits, target=None, ignore_index=None, beta=1.0, **kwargs):
        """Return beta * (1 - cos_sim) averaged over all tokens.

        Removed a stray debug print of student_logits.shape and the
        commented-out alternative formulation that had been left in the loss path.
        """
        # L2-normalize first (cosine_similarity would normalize anyway; kept to
        # mirror the kernel's computation exactly).
        student_norm = F.normalize(student_logits, p=2, dim=-1)
        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
        loss = beta * (1 - cosine_sim)
        return loss.mean()
class TorchCosineLoss(torch.nn.Module):
    """
    Reference implementation for Cosine Similarity Loss using standard torch operations.
    Computes the loss as 1 - cosine_similarity averaged over all tokens.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool,
        device: torch.device,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 1.0,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
        super().__init__()
        # Note: student inputs are expected to have hidden size H//2 while teacher inputs have H.
        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device)
        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device)
        self.beta = beta
        hf_loss = HFCosineLoss(
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            temperature=temperature,
        )
        # Bound method: computes the combined hard/soft batch loss.
        self.cosine = hf_loss.get_batch_loss_metrics

    def forward(self, student_input, teacher_input, target):
        return self.cosine(
            student_input,
            self.student_lin.weight,
            teacher_input,
            self.teacher_lin.weight,
            target,
            self.student_lin.bias,
            self.teacher_lin.bias,
            beta=self.beta,
        )
class LigerCosineLoss(torch.nn.Module):
    """
    Liger implementation that uses the fused linear + cosine-similarity loss.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool,
        device: torch.device,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
        chunk_size: int = 1024,
    ):
        super().__init__()
        self.chunked_cosine = LigerFusedLinearCosineSimilarityLoss(
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            beta=beta,
            ignore_index=ignore_index,
            temperature=temperature,
            compiled=compiled,
            chunk_size=chunk_size,
        )
        # Student operates on half the teacher's hidden size (H//2 vs H).
        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device)
        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device)

    def forward(self, student_input, teacher_input, target):
        loss_args = (
            student_input,
            self.student_lin.weight,
            teacher_input,
            self.teacher_lin.weight,
            target,
            self.student_lin.bias,
            self.teacher_lin.bias,
        )
        return self.chunked_cosine(*loss_args)
###############################################################################
# Test correctness of the module implementations
###############################################################################
@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (8, 128, 1024, 4096),
        (3, 47, 32, 128),  # H must be even
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-1),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize(
    "temperature, weight_hard_loss, weight_soft_loss, beta",
    [
        (1.0, 0.5, 0.5, 0.5),
        (2.0, 0.0, 1.0, 0.8),
        (0.5, 1.0, 0.0, 0.2),
    ],
)
def test_correctness(
    B, T, H, V, scalar, dtype, atol, rtol, bias, temperature, weight_hard_loss, weight_soft_loss, beta
):
    """Module-level parity test: TorchCosineLoss (reference) vs LigerCosineLoss (fused).

    Both modules are given identical parameters and identical inputs; the loss
    value and the gradients w.r.t. the student input/weight/bias must agree
    within the dtype-dependent tolerances.
    """
    torch_cosine = TorchCosineLoss(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        device=device,
        weight_hard_loss=weight_hard_loss,
        weight_soft_loss=weight_soft_loss,
        temperature=temperature,
        beta=beta,
    )
    liger_cosine = LigerCosineLoss(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        device=device,
        weight_hard_loss=weight_hard_loss,
        weight_soft_loss=weight_soft_loss,
        temperature=temperature,
        beta=beta,
    )
    # Ensure both implementations start with the same weights and biases.
    torch_cosine.student_lin.weight.data = liger_cosine.student_lin.weight.data = torch.rand(
        V, H // 2, device=device, dtype=dtype
    )
    torch_cosine.teacher_lin.weight.data = liger_cosine.teacher_lin.weight.data = torch.rand(
        V, H, device=device, dtype=dtype
    )
    if bias:
        torch_cosine.student_lin.bias.data = liger_cosine.student_lin.bias.data = torch.rand(
            V, device=device, dtype=dtype
        )
        torch_cosine.teacher_lin.bias.data = liger_cosine.teacher_lin.bias.data = torch.rand(
            V, device=device, dtype=dtype
        )
    _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar
    # Two independent leaves over the same data so each path accumulates its own grads.
    student_input1 = _tensor.clone().detach().requires_grad_(True)
    student_input2 = _tensor.clone().detach().requires_grad_(True)
    teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar
    # Dummy target (not used in cosine computation)
    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
    loss1 = torch_cosine(student_input1, teacher_input, target)
    loss2 = liger_cosine(student_input2, teacher_input, target)
    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
    loss1.backward()
    # Fixed: removed a stray debug print that used a non-f-string
    # ('print("loss1 shape : {loss1.shape}")' printed the literal braces).
    loss2.backward()
    assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(
        torch_cosine.student_lin.weight.grad, liger_cosine.student_lin.weight.grad, atol=atol, rtol=rtol
    )
    if bias:
        assert_verbose_allclose(
            torch_cosine.student_lin.bias.grad, liger_cosine.student_lin.bias.grad, atol=atol, rtol=rtol
        )
###############################################################################
# Test correctness of the functional interface
###############################################################################
@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (2, 2, 8, 8),
        (9, 7, 40, 40),  # H must be even
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-2),
        (1.0, torch.float32, 1e-4, 5e-3),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize(
    "temperature, weight_hard_loss, weight_soft_loss, beta, ignore_index",
    [
        (1.0, 0.5, 0.5, 0.5, -100),
        (2.0, 0.1, 0.9, 0.5, 42),
    ],
)
def test_correctness_functional(
    B, T, H, V, scalar, dtype, bias, weight_hard_loss, weight_soft_loss, beta, ignore_index, temperature, atol, rtol
):
    """Functional-interface parity: `liger_fused_linear_cosine` wrapper vs
    calling `LigerFusedLinearCosineSimilarityFunction.apply` directly.

    Note: tensor creation order is kept fixed so the global seed produces the
    same values on every run.
    """
    # Prepare weights and biases for functional testing.
    student_weight1 = torch.rand(V, H // 2, device=device, dtype=dtype).detach().clone().requires_grad_(True)
    student_weight2 = student_weight1.clone().detach().requires_grad_(True)
    teacher_weight = torch.rand(V, H, device=device, dtype=dtype)  # teacher side needs no grad
    if bias:
        student_bias1 = torch.rand(V, device=device, dtype=dtype).detach().clone().requires_grad_(True)
        student_bias2 = student_bias1.clone().detach().requires_grad_(True)
        teacher_bias = torch.rand(V, device=device, dtype=dtype)
    else:
        student_bias1 = student_bias2 = teacher_bias = None
    base_student = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar
    student_input1 = base_student.clone().detach().requires_grad_(True)
    student_input2 = base_student.clone().detach().requires_grad_(True)
    teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar
    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
    # Path 1: the functional wrapper.
    output1 = liger_fused_linear_cosine(
        student_input1,
        student_weight1,
        teacher_input,
        teacher_weight,
        target,
        student_bias1,
        teacher_bias,
        weight_hard_loss,
        weight_soft_loss,
        beta,
        ignore_index,
        temperature,
        True,
        1024,
    )
    # Path 2: the autograd Function invoked directly with identical arguments.
    output2 = LigerFusedLinearCosineSimilarityFunction.apply(
        student_input2,
        student_weight2,
        teacher_input,
        teacher_weight,
        target,
        student_bias2,
        teacher_bias,
        weight_hard_loss,
        weight_soft_loss,
        beta,
        ignore_index,
        temperature,
        True,
        1024,
    )
    assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)
    output1.backward()
    output2.backward()
    # Gradients through both paths must match for input, weight, and bias.
    assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol)
    if bias:
        assert_verbose_allclose(student_bias1.grad, student_bias2.grad, atol=atol, rtol=rtol)
from typing import Tuple
import pytest
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo
from liger_kernel.utils import infer_device
from test.utils import HFAlignmentLoss
from test.utils import assert_verbose_allclose
from test.utils import set_seed
device = infer_device()
# set random seed globally
set_seed()
class HFCPOLoss(HFAlignmentLoss):
    """
    HF's implementation of CPO loss in TRL. https://github.com/huggingface/trl/blob/main/trl/trainer/cpo_trainer.py
    """

    def __init__(
        self,
        alpha: float = 1.0,
        beta: float = 0.1,
        ignore_index: int = -100,
        label_smoothing: float = 0.0,
        simpo_gamma: float = 0.5,
        loss_type: str = "sigmoid",
    ):
        super().__init__(alpha=alpha, beta=beta, ignore_index=ignore_index)
        # Sigmoid defaults to the CPO loss defined in the paper listed above.
        self.loss_type = loss_type
        self.label_smoothing = label_smoothing
        self.simpo_gamma = simpo_gamma

    def alignment_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the CPO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the CPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.

        Raises:
            ValueError: if ``self.loss_type`` is neither "sigmoid" nor "simpo".
        """
        logits = policy_chosen_logps - policy_rejected_logps
        # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative CPO loss.
        if self.loss_type == "sigmoid":
            # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "simpo":
            # SimPO subtracts a reward margin gamma/beta before the sigmoid loss.
            logits = logits - (self.simpo_gamma / self.beta)
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        else:
            # Fixed: the message previously listed only ['sigmoid'] although 'simpo' is supported above.
            raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'simpo']")
        chosen_rewards = self.beta * policy_chosen_logps
        rejected_rewards = self.beta * policy_rejected_logps
        return losses, chosen_rewards, rejected_rewards
class TorchLMHeadCPO(torch.nn.Module):
    """Reference LM head + HF CPO loss, used as the ground truth in tests.

    Fixed: the ``alpha`` constructor argument was previously accepted but never
    forwarded to ``HFCPOLoss`` (which takes ``alpha``), so callers passing a
    non-default value were silently ignored. It is now passed through; the
    default (1.0) matches ``HFCPOLoss``'s own default, so existing callers are
    unaffected.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        label_smoothing: float = 0.0,
        loss_type: str = "sigmoid",
        simpo_gamma: float = 0.5,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.cpo_loss = HFCPOLoss(
            alpha=alpha,
            ignore_index=ignore_index,
            beta=beta,
            loss_type=loss_type,
            label_smoothing=label_smoothing,
            simpo_gamma=simpo_gamma,
        ).get_batch_loss_metrics
        # SimPO averages per-token log-probs instead of summing them.
        self.average_log_prob = loss_type == "simpo"

    def forward(self, x, y):
        return self.cpo_loss(self.lin.weight, x, y, self.lin.bias, average_log_prob=self.average_log_prob)
class LigerLMHeadCPO(torch.nn.Module):
    """LM head wired to the Liger fused-linear CPO loss (the implementation under test)."""

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        label_smoothing: float = 0.0,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.cpo_loss = LigerFusedLinearCPOLoss(
            ignore_index=ignore_index,
            beta=beta,
            alpha=alpha,
            label_smoothing=label_smoothing,
        )

    def forward(self, x, y):
        # The fused loss consumes the head's weight/bias and projects internally.
        return self.cpo_loss(self.lin.weight, x, y, self.lin.bias)
@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (8, 128, 1024, 4096),
        (3, 47, 31, 123),  # random shape
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-2),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)])
@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
def test_correctness(
    B,
    T,
    H,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    bias,
    ignore_index,
    beta,
    alpha,
    label_smoothing,
):
    """Parity test: TorchLMHeadCPO (HF reference) vs LigerLMHeadCPO (fused).

    Compares loss, auxiliary outputs, and gradients of input/weight/bias.

    NOTE(review): the parametrized `alpha` is not forwarded to either module
    below, so both implementations run with their default alpha — confirm
    whether `alpha=alpha` should be passed to both constructors.
    """
    B = 2 * B  # cpo loss requires B to be even
    torch_lm_head_cpo = TorchLMHeadCPO(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        ignore_index=ignore_index,
        beta=beta,
        label_smoothing=label_smoothing,
    )
    liger_lm_head_cpo = LigerLMHeadCPO(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        ignore_index=ignore_index,
        beta=beta,
        label_smoothing=label_smoothing,
    )
    # Share identical head parameters so outputs are directly comparable.
    torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn(
        V, H, device=device, dtype=dtype
    )
    if bias:
        torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype)
    _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
    # Independent leaves over identical data, one per implementation.
    input1 = _input.detach().clone().requires_grad_(True)
    input2 = _input.detach().clone().requires_grad_(True)
    target = torch.randint(
        0,
        V,
        (
            B,
            T,
        ),
        device=device,
        dtype=torch.long,
    )
    # Assign some random number of elements as ignore_index
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index
    loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target)
    loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target)
    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
    assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2)
    for i in range(len(aggregated_aux_outputs1)):
        assert_verbose_allclose(
            aggregated_aux_outputs1[i],
            aggregated_aux_outputs2[i],
            atol=atol,
            rtol=rtol,
        )
    loss1.backward()
    loss2.backward()
    # Gradients must agree for input, weight, and (optionally) bias.
    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(
        torch_lm_head_cpo.lin.weight.grad,
        liger_lm_head_cpo.lin.weight.grad,
        atol=atol,
        rtol=rtol,
    )
    if bias:
        assert_verbose_allclose(
            torch_lm_head_cpo.lin.bias.grad,
            liger_lm_head_cpo.lin.bias.grad,
            atol=atol,
            rtol=rtol,
        )
@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (2, 2, 8, 8),
        (3, 47, 31, 123),  # random shape
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-1),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias):
    """Parity of `LigerFusedLinearCPOFunction.apply` and the `liger_fused_linear_cpo` wrapper."""
    B = 2 * B  # CPO requires an even batch (chosen/rejected pairs)
    base_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
    input1 = base_input.detach().clone().requires_grad_(True)
    input2 = base_input.detach().clone().requires_grad_(True)
    target = torch.randint(
        0,
        V,
        (
            B,
            T,
        ),
        device=device,
        dtype=torch.long,
    )
    base_weight = torch.randn(V, H, device=device, dtype=dtype)
    weight1 = base_weight.detach().clone().requires_grad_(True)
    weight2 = base_weight.detach().clone().requires_grad_(True)
    base_bias = torch.randn(V, device=device, dtype=dtype) if bias else None
    bias1 = base_bias.detach().clone().requires_grad_(True) if bias else None
    bias2 = base_bias.detach().clone().requires_grad_(True) if bias else None
    # Same computation through the raw Function and the functional wrapper.
    loss1, aggregated_aux_outputs1 = LigerFusedLinearCPOFunction.apply(input1, weight1, target, bias1)
    loss2, aggregated_aux_outputs2 = liger_fused_linear_cpo(input2, weight2, target, bias2)
    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
    loss1.backward()
    loss2.backward()
    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
    if bias:
        assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
import pytest
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.chunked_loss.functional import liger_fused_linear_dpo
from liger_kernel.utils import infer_device
from test.utils import HFAlignmentLoss
from test.utils import assert_verbose_allclose
from test.utils import set_seed
device = infer_device()
# set random seed globally
set_seed()
class HFDPOLoss(HFAlignmentLoss):
    """
    Implementation of the Direct Preference Optimization (DPO) loss,
    adapted from Hugging Face's implementation.
    Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        use_ref_model: bool = True,
        compute_nll_loss: bool = False,
    ):
        super().__init__(
            beta=beta,
            ignore_index=ignore_index,
            use_ref_model=use_ref_model,
            compute_nll_loss=compute_nll_loss,
        )

    def alignment_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
    ):
        """Compute the per-example DPO loss and the chosen/rejected rewards.

        All four arguments are per-example summed (or averaged) log-probs of
        shape (batch_size,). Returns (losses, chosen_rewards, rejected_rewards).
        """
        # Derived from https://huggingface.co/papers/2305.18290
        chosen_logratios = policy_chosen_logps - ref_chosen_logps
        rejected_logratios = policy_rejected_logps - ref_rejected_logps
        # Rewards are the beta-scaled log-ratios against the reference model.
        chosen_rewards = self.beta * chosen_logratios
        rejected_rewards = self.beta * rejected_logratios
        logits_diff = self.beta * (chosen_logratios - rejected_logratios)
        return -F.logsigmoid(logits_diff), chosen_rewards, rejected_rewards
class HFAPOZeroLoss(HFAlignmentLoss):
    """
    Implementation of the APO-zero loss.
    Reference: https://huggingface.co/papers/2408.06266
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        use_ref_model: bool = True,
        compute_nll_loss: bool = False,
    ):
        super().__init__(
            beta=beta,
            ignore_index=ignore_index,
            use_ref_model=use_ref_model,
            compute_nll_loss=compute_nll_loss,
        )

    def alignment_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
    ):
        """Compute the per-example APO-zero loss and chosen/rejected rewards.

        Eqn (7) of the APO paper: increase chosen likelihood while decreasing
        rejected likelihood. Use this variant when the chosen outputs are
        believed to be better than the model's default output.
        """
        chosen_logratios = policy_chosen_logps - ref_chosen_logps
        rejected_logratios = policy_rejected_logps - ref_rejected_logps
        chosen_rewards = self.beta * chosen_logratios
        rejected_rewards = self.beta * rejected_logratios
        # Increase chosen likelihood; decrease rejected likelihood.
        losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios)
        losses_rejected = F.sigmoid(self.beta * rejected_logratios)
        return losses_chosen + losses_rejected, chosen_rewards, rejected_rewards
class HFAPODownLoss(HFAlignmentLoss):
    """
    Implementation of the APO-down loss.
    Reference: https://huggingface.co/papers/2408.06266
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        use_ref_model: bool = True,
        compute_nll_loss: bool = False,
    ):
        super().__init__(
            beta=beta,
            ignore_index=ignore_index,
            use_ref_model=use_ref_model,
            compute_nll_loss=compute_nll_loss,
        )

    def alignment_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
    ):
        """Compute the per-example APO-down loss and chosen/rejected rewards.

        Eqn (8) of the APO paper: use this variant when the chosen outputs are
        believed to be worse than the model's default output — it decreases
        chosen likelihood and decreases rejected likelihood even more.
        """
        chosen_logratios = policy_chosen_logps - ref_chosen_logps
        rejected_logratios = policy_rejected_logps - ref_rejected_logps
        chosen_rewards = self.beta * chosen_logratios
        rejected_rewards = self.beta * rejected_logratios
        losses_chosen = F.sigmoid(self.beta * chosen_logratios)
        losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios))
        return losses_chosen + losses_rejected, chosen_rewards, rejected_rewards
class HFSPPPOHARDLoss(HFAlignmentLoss):
    """SPPO-hard reference loss: squared distance of each logratio from +/- 1/(2*beta)."""

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        use_ref_model: bool = True,
        compute_nll_loss: bool = False,
    ):
        super().__init__(
            beta=beta,
            ignore_index=ignore_index,
            use_ref_model=use_ref_model,
            compute_nll_loss=compute_nll_loss,
        )

    def alignment_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
    ):
        """Return (losses, chosen_rewards, rejected_rewards) for SPPO-hard."""
        chosen_logratios = policy_chosen_logps - ref_chosen_logps
        rejected_logratios = policy_rejected_logps - ref_rejected_logps
        chosen_rewards = self.beta * chosen_logratios
        rejected_rewards = self.beta * rejected_logratios
        # Push the chosen logratio toward +1/(2*beta) and the rejected one
        # toward -1/(2*beta). (The original recomputed the logratios here;
        # reusing them is bitwise-identical.)
        losses = (chosen_logratios - 0.5 / self.beta) ** 2 + (rejected_logratios + 0.5 / self.beta) ** 2
        return losses, chosen_rewards, rejected_rewards
class HFNCAPAIRLoss(HFAlignmentLoss):
    """NCA-pair reference loss built from log-sigmoids of the beta-scaled rewards."""

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        use_ref_model: bool = True,
        compute_nll_loss: bool = False,
    ):
        super().__init__(
            beta=beta,
            ignore_index=ignore_index,
            use_ref_model=use_ref_model,
            compute_nll_loss=compute_nll_loss,
        )

    def alignment_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
    ):
        """Return (losses, chosen_rewards, rejected_rewards) for NCA-pair."""
        chosen_logratios = policy_chosen_logps - ref_chosen_logps
        rejected_logratios = policy_rejected_logps - ref_rejected_logps
        chosen_rewards = self.beta * chosen_logratios
        rejected_rewards = self.beta * rejected_logratios
        # Loss combines the chosen reward's log-sigmoid with half-weighted
        # negative terms for both rewards.
        losses = (
            -F.logsigmoid(chosen_rewards) - 0.5 * F.logsigmoid(-chosen_rewards) - 0.5 * F.logsigmoid(-rejected_rewards)
        )
        return losses, chosen_rewards, rejected_rewards
class TorchLMHeadDPO(torch.nn.Module):
    """Reference module: policy head + frozen-style reference head + HF DPO loss."""

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ref_bias: bool = False,
        compute_nll_loss: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
        loss_impl = HFDPOLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
            compute_nll_loss=compute_nll_loss,
        )
        self.dpo_loss = loss_impl.get_batch_loss_metrics

    def forward(self, x, ref_x, y):
        # Hand both heads' parameters to the loss; projection happens inside.
        return self.dpo_loss(
            self.lin.weight,
            x,
            y,
            self.lin.bias,
            ref_x,
            self.ref_lin.weight,
            self.ref_lin.bias,
            average_log_prob=True,
        )
class TorchLMHeadAPOZero(torch.nn.Module):
    """Reference module: policy head + reference head + HF APO-zero loss."""

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ref_bias: bool = False,
        compute_nll_loss: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
        loss_impl = HFAPOZeroLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
            compute_nll_loss=compute_nll_loss,
        )
        self.apo_loss = loss_impl.get_batch_loss_metrics

    def forward(self, x, ref_x, y):
        # Hand both heads' parameters to the loss; projection happens inside.
        return self.apo_loss(
            self.lin.weight,
            x,
            y,
            self.lin.bias,
            ref_x,
            self.ref_lin.weight,
            self.ref_lin.bias,
            average_log_prob=True,
        )
class TorchLMHeadAPODown(torch.nn.Module):
    """Reference module: policy head + reference head + HF APO-down loss."""

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ref_bias: bool = False,
        compute_nll_loss: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
        loss_impl = HFAPODownLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
            compute_nll_loss=compute_nll_loss,
        )
        self.apo_loss = loss_impl.get_batch_loss_metrics

    def forward(self, x, ref_x, y):
        # Hand both heads' parameters to the loss; projection happens inside.
        return self.apo_loss(
            self.lin.weight,
            x,
            y,
            self.lin.bias,
            ref_x,
            self.ref_lin.weight,
            self.ref_lin.bias,
            average_log_prob=True,
        )
class TorchLMHeadSPPOHARD(torch.nn.Module):
    """Reference module: policy head + reference head + HF SPPO-hard loss."""

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ref_bias: bool = False,
        compute_nll_loss: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
        loss_impl = HFSPPPOHARDLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
            compute_nll_loss=compute_nll_loss,
        )
        self.sppo_hard = loss_impl.get_batch_loss_metrics

    def forward(self, x, ref_x, y):
        # Hand both heads' parameters to the loss; projection happens inside.
        return self.sppo_hard(
            self.lin.weight,
            x,
            y,
            self.lin.bias,
            ref_x,
            self.ref_lin.weight,
            self.ref_lin.bias,
            average_log_prob=True,
        )
class TorchLMHeadNCAPAIR(torch.nn.Module):
    """Reference module: policy head + reference head + HF NCA-pair loss."""

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ref_bias: bool = False,
        compute_nll_loss: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
        loss_impl = HFNCAPAIRLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
            compute_nll_loss=compute_nll_loss,
        )
        self.nca_pair = loss_impl.get_batch_loss_metrics

    def forward(self, x, ref_x, y):
        # Hand both heads' parameters to the loss; projection happens inside.
        return self.nca_pair(
            self.lin.weight,
            x,
            y,
            self.lin.bias,
            ref_x,
            self.ref_lin.weight,
            self.ref_lin.bias,
            average_log_prob=True,
        )
class LigerLMHeadDPO(torch.nn.Module):
    """LM head wired to the Liger fused-linear DPO loss (implementation under test).

    `loss_type` selects the preference-loss variant ("sigmoid", "apo_zero",
    "apo_down", "sppo_hard", "nca_pair", ...), mirroring the HF reference
    classes above.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ref_bias: bool = False,
        compute_nll_loss: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
        loss_type: str = "sigmoid",
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
        self.dpo_loss = LigerFusedLinearDPOLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
            compute_nll_loss=compute_nll_loss,
            average_log_prob=True,
            loss_type=loss_type,
        )

    def forward(self, x, ref_x, y):
        # The fused loss consumes both heads' weights/biases and projects internally.
        return self.dpo_loss(
            self.lin.weight,
            x,
            y,
            self.lin.bias,
            ref_x,
            self.ref_lin.weight,
            self.ref_lin.bias,
        )
@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (8, 128, 1024, 4096),
        (3, 47, 31, 123),  # random shape
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-1),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ref_bias", [True, False])
@pytest.mark.parametrize("compute_nll_loss", [True, False])
@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
def test_correctness(
    B,
    T,
    H,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    bias,
    ref_bias,
    compute_nll_loss,
    ignore_index,
    beta,
):
    """Parity test: TorchLMHeadDPO (HF reference) vs LigerLMHeadDPO (fused).

    Compares loss, auxiliary outputs (with a looser tolerance for the later
    bf16 entries), and gradients of the policy head's input/weight/bias.
    The reference head receives the same parameters but its grads are not
    compared.
    """
    B = 2 * B  # dpo loss requires B to be even
    torch_lm_head_dpo = TorchLMHeadDPO(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        ref_bias=ref_bias,
        compute_nll_loss=compute_nll_loss,
        ignore_index=ignore_index,
        beta=beta,
    )
    liger_lm_head_dpo = LigerLMHeadDPO(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        ref_bias=ref_bias,
        compute_nll_loss=compute_nll_loss,
        ignore_index=ignore_index,
        beta=beta,
    )
    # Share identical policy and reference parameters across both implementations.
    torch_lm_head_dpo.lin.weight.data = liger_lm_head_dpo.lin.weight.data = torch.randn(
        V, H, device=device, dtype=dtype
    )
    torch_lm_head_dpo.ref_lin.weight.data = liger_lm_head_dpo.ref_lin.weight.data = torch.randn(
        V, H, device=device, dtype=dtype
    )
    if bias:
        torch_lm_head_dpo.lin.bias.data = liger_lm_head_dpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype)
    if ref_bias:
        torch_lm_head_dpo.ref_lin.bias.data = liger_lm_head_dpo.ref_lin.bias.data = torch.randn(
            V, device=device, dtype=dtype
        )
    _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
    input1 = _input.detach().clone().requires_grad_(True)
    input2 = _input.detach().clone().requires_grad_(True)
    # The reference-model input never needs gradients.
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
    target = torch.randint(
        0,
        V,
        (
            B,
            T,
        ),
        device=device,
        dtype=torch.long,
    )
    # Assign some random number of elements as ignore_index
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index
    loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, ref_input, target)
    loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, ref_input, target)
    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
    assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2)
    for i in range(len(aggregated_aux_outputs1)):
        if i > 4 and dtype == torch.bfloat16:
            # numerical instability in bf16 for chosen_rewards and rejected_rewards
            # temporary fix. TODO: investigate how to reduce numerical instability issue
            assert_verbose_allclose(
                aggregated_aux_outputs1[i],
                aggregated_aux_outputs2[i],
                atol=5e-1,
                rtol=rtol,
            )
            continue
        assert_verbose_allclose(
            aggregated_aux_outputs1[i],
            aggregated_aux_outputs2[i],
            atol=atol,
            rtol=rtol,
        )
    loss1.backward()
    loss2.backward()
    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(
        torch_lm_head_dpo.lin.weight.grad,
        liger_lm_head_dpo.lin.weight.grad,
        atol=atol,
        rtol=rtol,
    )
    if bias:
        assert_verbose_allclose(
            torch_lm_head_dpo.lin.bias.grad,
            liger_lm_head_dpo.lin.bias.grad,
            atol=atol,
            rtol=rtol,
        )
@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (2, 2, 8, 8),
        (3, 47, 31, 123),  # random shape
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-1),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ref_bias", [True, False])
@pytest.mark.parametrize("compute_nll_loss", [True, False])
def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss):
    """Parity of `LigerFusedLinearDPOFunction.apply` and the `liger_fused_linear_dpo` wrapper."""
    B = 2 * B  # DPO requires an even batch (chosen/rejected pairs)
    base_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
    input1 = base_input.detach().clone().requires_grad_(True)
    input2 = base_input.detach().clone().requires_grad_(True)
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
    target = torch.randint(
        0,
        V,
        (
            B,
            T,
        ),
        device=device,
        dtype=torch.long,
    )
    base_weight = torch.randn(V, H, device=device, dtype=dtype)
    weight1 = base_weight.detach().clone().requires_grad_(True)
    weight2 = base_weight.detach().clone().requires_grad_(True)
    base_ref_weight = torch.randn(V, H, device=device, dtype=dtype)
    ref_weight1 = base_ref_weight.detach().clone().requires_grad_(True)
    ref_weight2 = base_ref_weight.detach().clone().requires_grad_(True)
    base_bias = torch.randn(V, device=device, dtype=dtype) if bias else None
    bias1 = base_bias.detach().clone().requires_grad_(True) if bias else None
    bias2 = base_bias.detach().clone().requires_grad_(True) if bias else None
    base_ref_bias = torch.randn(V, device=device, dtype=dtype) if ref_bias else None
    ref_bias1 = base_ref_bias.detach().clone().requires_grad_(True) if ref_bias else None
    ref_bias2 = base_ref_bias.detach().clone().requires_grad_(True) if ref_bias else None
    # Same computation through the raw Function and the functional wrapper
    # (ignore_index=-100, beta=0.1 in both).
    loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply(
        input1,
        weight1,
        target,
        bias1,
        ref_input,
        ref_weight1,
        ref_bias1,
        -100,
        0.1,
        compute_nll_loss,
    )
    loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
        input2,
        weight2,
        target,
        bias2,
        ref_input,
        ref_weight2,
        ref_bias2,
        -100,
        0.1,
        compute_nll_loss,
    )
    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
    loss1.backward()
    loss2.backward()
    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
    if bias:
        assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (8, 128, 1024, 4096),
        (3, 47, 31, 123),  # random shape
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-1),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ref_bias", [True, False])
@pytest.mark.parametrize("compute_nll_loss", [True, False])
@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
@pytest.mark.parametrize("loss_type", ["apo_zero", "apo_down", "sppo_hard", "nca_pair"])
def test_correctness_apo_loss_types(
    B,
    T,
    H,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    bias,
    ref_bias,
    compute_nll_loss,
    ignore_index,
    beta,
    loss_type,
):
    """Parity of the Liger DPO module's alternative loss types against the HF references.

    For each `loss_type`, the matching Torch reference class is instantiated
    and compared to `LigerLMHeadDPO(loss_type=...)` on loss, aux outputs, and
    gradients of the policy head.
    """
    B = 2 * B  # dpo loss requires B to be even
    # Dispatch table mapping loss_type to its HF reference implementation.
    reference_heads = {
        "apo_zero": TorchLMHeadAPOZero,
        "apo_down": TorchLMHeadAPODown,
        "sppo_hard": TorchLMHeadSPPOHARD,
        "nca_pair": TorchLMHeadNCAPAIR,
    }
    head_cls = reference_heads.get(loss_type)
    if head_cls is None:
        raise ValueError(f"Unsupported loss_type: {loss_type}")
    torch_lm_head_apo = head_cls(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        ref_bias=ref_bias,
        compute_nll_loss=compute_nll_loss,
        ignore_index=ignore_index,
        beta=beta,
    )
    liger_lm_head_apo = LigerLMHeadDPO(
        H=H,
        V=V,
        dtype=dtype,
        bias=bias,
        ref_bias=ref_bias,
        compute_nll_loss=compute_nll_loss,
        ignore_index=ignore_index,
        beta=beta,
        loss_type=loss_type,
    )
    # Share identical policy and reference parameters across both implementations.
    torch_lm_head_apo.lin.weight.data = liger_lm_head_apo.lin.weight.data = torch.randn(
        V, H, device=device, dtype=dtype
    )
    torch_lm_head_apo.ref_lin.weight.data = liger_lm_head_apo.ref_lin.weight.data = torch.randn(
        V, H, device=device, dtype=dtype
    )
    if bias:
        torch_lm_head_apo.lin.bias.data = liger_lm_head_apo.lin.bias.data = torch.randn(V, device=device, dtype=dtype)
    if ref_bias:
        torch_lm_head_apo.ref_lin.bias.data = liger_lm_head_apo.ref_lin.bias.data = torch.randn(
            V, device=device, dtype=dtype
        )
    _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
    input1 = _input.detach().clone().requires_grad_(True)
    input2 = _input.detach().clone().requires_grad_(True)
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
    target = torch.randint(
        0,
        V,
        (
            B,
            T,
        ),
        device=device,
        dtype=torch.long,
    )
    # Scatter ignore_index into a random subset of the targets.
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index
    loss1, aggregated_aux_outputs1 = torch_lm_head_apo(input1, ref_input, target)
    loss2, aggregated_aux_outputs2 = liger_lm_head_apo(input2, ref_input, target)
    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
    assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2)
    for i in range(len(aggregated_aux_outputs1)):
        if i > 4 and dtype == torch.bfloat16:
            # numerical instability in bf16 for chosen_rewards and rejected_rewards
            # temporary fix. TODO: investigate how to reduce numerical instability issue
            assert_verbose_allclose(
                aggregated_aux_outputs1[i],
                aggregated_aux_outputs2[i],
                atol=5e-1,
                rtol=rtol,
            )
            continue
        assert_verbose_allclose(
            aggregated_aux_outputs1[i],
            aggregated_aux_outputs2[i],
            atol=atol,
            rtol=rtol,
        )
    loss1.backward()
    loss2.backward()
    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(
        torch_lm_head_apo.lin.weight.grad,
        liger_lm_head_apo.lin.weight.grad,
        atol=atol,
        rtol=rtol,
    )
    if bias:
        assert_verbose_allclose(
            torch_lm_head_apo.lin.bias.grad,
            liger_lm_head_apo.lin.bias.grad,
            atol=atol,
            rtol=rtol,
        )
@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (2, 2, 8, 8),
        (3, 47, 31, 123),  # random shape
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-1),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ref_bias", [True, False])
@pytest.mark.parametrize("compute_nll_loss", [True, False])
@pytest.mark.parametrize("loss_type", ["apo_zero", "apo_down", "sppo_hard", "nca_pair"])
def test_correctness_functional_apo_loss_types(
    B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss, loss_type
):
    """Check that the raw autograd Function and the module-style DPO loss agree.

    For each alternative preference-loss variant, both the functional entry
    point (``LigerFusedLinearDPOFunction.apply``) and the configured
    ``LigerFusedLinearDPOLoss`` module are run on identical data; the loss
    values and all gradients (input, weight, bias) must match.
    """
    B = 2 * B  # preference data comes in chosen/rejected pairs

    def grad_clone(t):
        # Independent leaf tensor sharing t's values, so each path gets its own grads.
        return t.detach().clone().requires_grad_(True)

    _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
    input1, input2 = grad_clone(_input), grad_clone(_input)
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
    target = torch.randint(
        0,
        V,
        (
            B,
            T,
        ),
        device=device,
        dtype=torch.long,
    )
    _weight = torch.randn(V, H, device=device, dtype=dtype)
    weight1, weight2 = grad_clone(_weight), grad_clone(_weight)
    _ref_weight = torch.randn(V, H, device=device, dtype=dtype)
    ref_weight1, ref_weight2 = grad_clone(_ref_weight), grad_clone(_ref_weight)
    _bias = torch.randn(V, device=device, dtype=dtype) if bias else None
    bias1 = grad_clone(_bias) if bias else None
    bias2 = grad_clone(_bias) if bias else None
    _ref_bias = torch.randn(V, device=device, dtype=dtype) if ref_bias else None
    ref_bias1 = grad_clone(_ref_bias) if ref_bias else None
    ref_bias2 = grad_clone(_ref_bias) if ref_bias else None
    # Functional path: call with loss_type parameter for LigerFusedLinearDPOFunction.
    loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply(
        input1,
        weight1,
        target,
        bias1,
        ref_input,
        ref_weight1,
        ref_bias1,
        -100,  # ignore_index
        0.1,  # beta
        compute_nll_loss,
        True,  # compiled
        True,  # use_ref_model
        False,  # average_log_prob
        1,  # chunk_size
        loss_type,  # loss_type
    )
    # Module path: a LigerFusedLinearDPOLoss configured with the same loss_type.
    dpo_loss_fn = LigerFusedLinearDPOLoss(
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=compute_nll_loss,
        loss_type=loss_type,
    )
    loss2, aggregated_aux_outputs2 = dpo_loss_fn(
        weight2,
        input2,
        target,
        bias2,
        ref_input,
        ref_weight2,
        ref_bias2,
    )
    # Forward values must agree, then gradients after a full backward pass.
    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
    loss1.backward()
    loss2.backward()
    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
    if bias:
        assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
def test_invalid_loss_type():
    """An unknown loss_type must be rejected at construction time."""
    with pytest.raises(ValueError, match="Unsupported loss_type"):
        LigerFusedLinearDPOLoss(loss_type="invalid_loss_type")
    # Every supported variant constructs cleanly and records its loss_type.
    for accepted in ("sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"):
        assert LigerFusedLinearDPOLoss(loss_type=accepted).loss_type == accepted
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment