Commit 9b0e3a30 authored by cmx's avatar cmx
Browse files

first commit

parent fe5cd1fc
Pipeline #3450 failed with stages
in 0 seconds
# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
import json
import os
import pathlib
from dataclasses import dataclass
from dataclasses import field
from typing import Dict
from typing import Optional
import torch
import transformers
from callback import EfficiencyCallback
from medusa_util import add_medusa_heads
from safetensors.torch import save_file
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother
from liger_kernel.transformers import AutoLigerKernelForCausalLM
# Label value that the loss ignores; reuse HF LabelSmoother's ignore index
# so masked positions are skipped by the cross-entropy loss.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@dataclass
class ModelArguments:
    """Arguments selecting the base model to fine-tune."""

    # Hugging Face hub id or local path of the base causal LM.
    model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B-Instruct")
@dataclass
class DataArguments:
    """Arguments describing the training/evaluation data sources."""

    # Path (or dataset id) of the ShareGPT-style training JSON.
    data_path: str = field(
        default="Aeala/ShareGPT_Vicuna_unfiltered",
        metadata={"help": "Path to the training data."},
    )
    # Optional separate eval set; when None, train data is split instead
    # (see make_supervised_data_module).
    eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})
    # Tokenize lazily per item (LazySupervisedDataset) instead of up front.
    lazy_preprocess: bool = True
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF TrainingArguments extended with Medusa-head hyperparameters."""

    cache_dir: Optional[str] = field(default=None)
    report_to: Optional[str] = None
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    medusa_num_heads: int = field(
        default=1,
        metadata={"help": "Number of Medusa heads."},
    )
    medusa_num_layers: int = field(
        default=1,
        metadata={"help": "Number of layers for each Medusa head."},
    )
    medusa_heads_coefficient: float = field(
        default=1.0,
        metadata={"help": "Coefficient for the Medusa heads."},
    )
    medusa_decay_coefficient: float = field(
        default=1.0,
        metadata={"help": "Coefficient for the Medusa heads."},
    )
    medusa_scheduler: str = field(
        default="constant",
        metadata={"help": "Scheduler for the Medusa heads."},
    )
    medusa_lr_multiplier: float = field(
        default=0.0,
        metadata={"help": "Learning rate multiplier for the Medusa heads."},
    )
    medusa_return: bool = field(
        default=False,
        metadata={
            "help": "If medusa is not applied, the default is False, and the regular lm_head will be used for single-token prediction."
        },
    )
    medusa_only_heads: bool = field(
        default=False,
        metadata={"help": "If train medusa heads only, default is False, the whole model will be trained"},
    )
    use_liger: bool = field(
        default=False,
        metadata={"help": "If apply liger kernel to the model."},
    )
# Process rank, set by train() from training_args.local_rank; gates rank0_print.
local_rank = None
def rank0_print(*args):
    """Print only on the rank-0 process; a no-op on every other rank."""
    if local_rank != 0:
        return
    print(*args)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """
    Save the model's state dictionary to a specified directory.

    Weights are copied to CPU before serialization; only processes where
    ``trainer.args.should_save`` is True actually write.

    Args:
        trainer (transformers.Trainer): The Hugging Face Trainer object.
        output_dir (str): The directory where the model state dictionary will be saved.
    """
    state_dict = trainer.model.state_dict()
    if not trainer.args.should_save:
        return
    cpu_state_dict = {name: tensor.cpu() for name, tensor in state_dict.items()}
    # Drop the (possibly GPU-resident) original dict before writing.
    del state_dict
    trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """
    Preprocess conversation data and tokenize it for model input.

    Args:
        sources: A list of conversation records, each carrying a
            "conversations" list of {"from": "human"|"gpt", "value": str} turns.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for tokenization.

    Returns:
        Dict: A dictionary containing tokenized inputs, labels (every
        non-assistant token masked with IGNORE_TOKEN_ID), and attention mask.
    """
    # Apply prompt templates.
    # FIX: the original iterated `sources[:50]` (a debugging leftover next to a
    # commented-out pdb call), silently capping the dataset at 50 conversations.
    conversations = []
    prompts = []
    for conversation in sources:
        tokenizer_compatible_conv = [
            {
                "role": "user" if c["from"] == "human" else "assistant",
                "content": c["value"],
            }
            for c in conversation["conversations"]
        ]
        prompt = tokenizer.apply_chat_template(tokenizer_compatible_conv, tokenize=False)
        prompts.append(prompt)
        conversations.append(tokenizer_compatible_conv)
    # Tokenize conversations; offset mapping is needed to map character spans
    # of assistant replies back to token indices.
    encoding = tokenizer(
        prompts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        return_offsets_mapping=True,
    )
    # Set everything to be ignored, except the assistant part.
    targets = torch.full_like(encoding.input_ids, IGNORE_TOKEN_ID)
    input_ids = encoding.input_ids
    # Mask targets. Only compute loss on the assistant outputs.
    for conv_index, (conversation, target, prompt) in enumerate(zip(conversations, targets, prompts)):
        for turn in conversation:
            if turn["role"] == "assistant":
                content = turn["content"]
                # Unfortunate strip() necessary because chat templates are doing the same.
                start = prompt.index(content.strip())
                stop = start + len(content)
                indices = []
                for tok_index, (tok_start, tok_stop) in enumerate(encoding.offset_mapping[conv_index]):
                    # FIX: keep tokens whose character span overlaps [start, stop).
                    # The original condition (`tok_stop >= start or tok_start < tok_stop`)
                    # was true for essentially every non-empty token, so the loss was
                    # computed on the whole sequence instead of assistant turns only.
                    if tok_stop >= start and tok_start < stop:
                        indices.append(tok_index)
                target[indices] = encoding.input_ids[conv_index][indices]
    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    All examples are tokenized eagerly at construction time.

    Args:
        raw_data (list): A list of raw data examples.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
    """

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super().__init__()
        rank0_print("Formatting inputs...")
        processed = preprocess(raw_data, tokenizer)
        self.input_ids = processed["input_ids"]
        self.labels = processed["labels"]
        self.attention_mask = processed["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return {
            "input_ids": self.input_ids[i],
            "labels": self.labels[i],
            "attention_mask": self.attention_mask[i],
        }
class LazySupervisedDataset(Dataset):
    """Lazy dataset for supervised fine-tuning.

    This dataset loads data on-the-fly when requested, which can be
    memory-efficient but slower, and memoizes each tokenized example.

    Args:
        raw_data (list): A list of raw data examples.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
    """

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super().__init__()
        rank0_print("Formatting inputs...Skip in lazy mode")
        # FIX: the original assigned self.tokenizer twice; once is enough.
        self.tokenizer = tokenizer
        self.raw_data = raw_data
        # Per-index memoization cache: i -> tokenized example dict.
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]
        # Tokenize a single conversation and strip the batch dimension.
        ret = preprocess([self.raw_data[i]], self.tokenizer)
        ret = dict(
            input_ids=ret["input_ids"][0],
            labels=ret["labels"][0],
            attention_mask=ret["attention_mask"][0],
        )
        self.cached_data_dict[i] = ret
        return ret
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, test_size=0.05) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Args:
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
        data_args: Data arguments (provides ``data_path`` and ``lazy_preprocess``).
        test_size: evaluation data ratio (default: 0.05)

    Returns:
        dict: A dictionary containing train and eval datasets.
    """
    dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
    rank0_print("Loading data...")
    # Load the entire dataset.
    # FIX: use a context manager; `json.load(open(...))` leaked the file handle.
    with open(data_args.data_path, "r") as f:
        train_json = json.load(f)
    # Deterministic train/eval split based on test_size.
    train_data, eval_data = train_test_split(train_json, test_size=test_size, random_state=42)
    # Create the train and eval datasets.
    train_dataset = dataset_cls(train_data, tokenizer=tokenizer)
    eval_dataset = dataset_cls(eval_data, tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
def train():
    """Entry point: parse arguments, build tokenizer/data/model, run training,
    and save either the Medusa heads alone or the full model."""
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )
    # Llama-3-style tokenizers ship without a pad token; reuse EOS for padding.
    # FIX: dropped the preceding dead assignment of unk_token (it was
    # immediately overwritten, and unk_token can be None for these tokenizers).
    tokenizer.pad_token = tokenizer.eos_token

    # Making sure the tokenizer works before loading the model.
    print(tokenizer(["This is a test", "secondary"], padding=True))
    print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}]))

    def _model_loader():
        # we use a customized model loader to inject medusa heads to FSDP-wrapped model variables properly.
        # see https://github.com/linkedin/Liger-Kernel/issues/309#issuecomment-2455077623 for details.
        # Load model
        if training_args.use_liger:
            model_builder = AutoLigerKernelForCausalLM.from_pretrained
        else:
            model_builder = transformers.AutoModelForCausalLM.from_pretrained
        model = model_builder(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            # NOTE(review): newer transformers accept `dtype=`; older releases
            # expect `torch_dtype=` -- confirm the pinned transformers version.
            dtype=torch.bfloat16,
        )
        # Freeze the base model so only the injected heads receive gradients
        # (add_medusa_heads may re-enable more, depending on medusa_only_heads).
        for param in model.base_model.parameters():
            param.requires_grad = False
        # Inject Medusa heads
        add_medusa_heads(
            model,
            training_args.medusa_num_heads,
            training_args.medusa_num_layers,
            training_args.medusa_return,
            training_args.medusa_only_heads,
            training_args.use_liger,
        )
        return model

    # Encode the key hyperparameters into the output directory name.
    training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"

    # Load data
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    # Start trainer
    trainer = Trainer(
        model_init=_model_loader,
        tokenizer=tokenizer,
        args=training_args,
        callbacks=[EfficiencyCallback()],
        **data_module,
    )
    # Resume from the latest checkpoint when one already exists.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    if training_args.medusa_return and training_args.medusa_only_heads:
        # Save only the updated heads without saving the backbone model.
        state_dict = {
            k.replace("medusa_head.", ""): v.to(torch.bfloat16)
            for k, v in trainer.accelerator.get_state_dict(trainer.model).items()
            if "medusa_head" in k
        }
        # Save Medusa heads (rank 0 writes; everyone synchronizes after).
        if local_rank == 0:
            save_file(
                state_dict,
                os.path.join(training_args.output_dir, "medusa_lm_head.safetensors"),
            )
        trainer.accelerator.wait_for_everyone()
    else:
        # Save the whole model weight
        trainer.save_model(training_args.output_dir)
# Script entry point.
if __name__ == "__main__":
    train()
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
\ No newline at end of file
MIT License
Copyright (c) 2023 MIT HAN Lab
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
\ No newline at end of file
MIT License
Copyright (c) 2024 mgmalek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
\ No newline at end of file
MIT License
Copyright (c) 2024 Andrej Karpathy
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
\ No newline at end of file
/*
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
\ No newline at end of file
site_name: Liger-Kernel Docs
site_url: https://linkedin.github.io/Liger-Kernel/
site_author: LinkedIn
site_description: Efficient Triton Kernels for LLM Training
theme:
name: material
font:
text: Merriweather Sans
code: Red Hat Mono
features:
- navigation.footer
- toc.follow
- navigation.top
- navigation.sections
palette:
# Dark Mode
- scheme: slate
toggle:
icon: material/weather-sunny
name: Dark mode
primary: green
accent: deep purple
# Light Mode
- scheme: default
toggle:
icon: material/weather-night
name: Light mode
primary: blue
accent: deep purple
nav:
- Home: index.md
- Examples: Examples.md
- Getting Started: Getting-Started.md
- High Level APIs: High-Level-APIs.md
- Low Level APIs: Low-Level-APIs.md
- Contributing: contributing.md
- Acknowledgment: acknowledgement.md
- License: license.md
markdown_extensions:
- attr_list
- toc:
permalink: true
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences:
custom_fences:
- name: mermaid
class: mermaid
format: !!python/name:pymdownx.superfences.fence_code_format
- pymdownx.tabbed:
alternate_style: true
- admonition
- pymdownx.details
plugins:
- search
- mkdocstrings:
handlers:
python:
paths: [src]
options:
show_root_heading: true
show_source: true
docstring_style: google
docstring_section_style: table
heading_level: 3
show_signature_annotations: false # Hides type annotations to save space
separate_signature: true # Separates signature from description
# Repository
repo_name: linkedin/Liger-Kernel
repo_url: https://github.com/linkedin/Liger-Kernel
edit_uri: edit/main/docs/
extra:
social:
- icon: simple/github
link: https://github.com/linkedin/Liger-Kernel
[build-system]
requires = ["setuptools>=42", "wheel", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
name = "liger_kernel"
version = "0.7.0"
description = "Efficient Triton kernels for LLM Training"
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
dynamic = ["dependencies", "optional-dependencies"]
[tool.setuptools]
package-dir = {"" = "src"}
[tool.setuptools.packages.find]
where = ["src"]
include = ["liger_kernel*"]
namespaces = false
[tool.pytest.ini_options]
pythonpath = ["src", "."]
asyncio_mode = "auto"
log_cli = true
log_cli_level = "INFO"
addopts = [
"--cov=src/liger_kernel",
"--cov-report=term-missing",
"--cov-report=html",
"--cov-config=pyproject.toml",
"--durations=0"
]
python_files = "test_*.py"
testpaths = ["test/"]
[tool.coverage.run]
branch = true
parallel = true
source = ["src/liger_kernel"]
# xdist uses subprocesses; "multiprocessing" is a safe concurrency choice
concurrency = ["multiprocessing"]
[tool.coverage.paths]
liger_kernel = [
"src/liger_kernel",
"*/site-packages/liger_kernel"
]
[tool.coverage.report]
omit = ["test/*"]
show_missing = true
skip_covered = false
[tool.ruff]
line-length = 120
target-version = "py310"
respect-gitignore = true
src = ["src"]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
]
ignore = ["E501", "B006", "E731", "A002", "E203"]
exclude = [
".git",
"__pycache__",
"benchmark_internal/others",
".venv",
]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
[tool.ruff.lint.isort]
known-first-party = ["liger_kernel"]
force-single-line = true
lines-between-types = 1
# setup.py
import subprocess
from typing import Literal
from setuptools import setup
def get_default_dependencies():
    """Determine the appropriate dependencies based on detected hardware.

    Returns:
        list[str]: pip requirement strings for the detected platform.
    """
    platform = get_platform()
    if platform in ["cuda", "cpu"]:
        return [
            "torch>=2.1.2",
            "triton>=2.3.1",
        ]
    elif platform == "rocm":
        return [
            "triton>=3.0.0",
        ]
    elif platform == "xpu":
        return [
            "torch>=2.6.0",
        ]
    # TODO: Currently, triton-ascend is not compatible with torch 2.7.1. We will upgrade it later.
    elif platform == "npu":
        return ["torch==2.6.0", "torch_npu==2.6.0", "triton-ascend"]
    # FIX: the original fell through and implicitly returned None for an
    # unmatched platform, which would crash setuptools' install_requires.
    return []
def get_optional_dependencies():
    """Get optional dependency groups (currently only the "dev" extra)."""
    dev_requirements = [
        "transformers>=4.52.0",
        "matplotlib>=3.7.2",
        "ruff>=0.12.0",
        "pytest>=7.1.2",
        "pytest-xdist",
        "pytest-cov",
        "pytest-asyncio",
        "pytest-rerunfailures",
        "datasets>=2.19.2",
        "seaborn",
        "mkdocs-material",
        "torchvision>=0.20",
        "prek>=0.2.28",
    ]
    return {"dev": dev_requirements}
def is_xpu_available():
    """
    Check if Intel XPU is available.
    xpu-smi is often missing right now.
    """
    # Preferred probe: xpu-smi exiting 0 indicates an XPU is present.
    try:
        subprocess.run(["xpu-smi"], check=True)
    except (subprocess.SubprocessError, FileNotFoundError):
        pass
    else:
        return True
    # Fallback probe: look for a Level Zero GPU in the sycl-ls device list.
    try:
        listing = subprocess.run("sycl-ls", check=True, capture_output=True, shell=True)
        return "level_zero:gpu" in listing.stdout.decode()
    except (subprocess.SubprocessError, FileNotFoundError):
        return False
def is_ascend_available() -> bool:
    """Best-effort Ascend NPU detection.

    Runs ``npu-smi info`` and treats a zero exit status as "NPU present".
    Returns False if the tool is missing or fails. (No environment-variable
    checks are performed, contrary to what an earlier docstring claimed.)
    """
    try:
        subprocess.run(["npu-smi", "info"], check=True)
        return True
    except (subprocess.SubprocessError, FileNotFoundError):
        pass
    return False
def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu", "npu"]:
    """
    Detect whether the system has NVIDIA or AMD GPU without torch dependency.
    """
    # Probe vendors in priority order, falling through on failure.
    try:
        subprocess.run(["nvidia-smi"], check=True)
        print("NVIDIA GPU detected")
        return "cuda"
    except (subprocess.SubprocessError, FileNotFoundError):
        pass
    try:
        subprocess.run(["rocm-smi"], check=True)
        print("ROCm GPU detected")
        return "rocm"
    except (subprocess.SubprocessError, FileNotFoundError):
        pass
    if is_xpu_available():
        print("Intel GPU detected")
        return "xpu"
    if is_ascend_available():
        print("Ascend NPU detected")
        return "npu"
    print("No GPU detected")
    return "cpu"
# Build/install entry point; dependency lists are resolved dynamically from
# the detected accelerator platform (see get_default_dependencies).
setup(
    name="liger_kernel",
    package_dir={"": "src"},
    packages=["liger_kernel"],
    install_requires=get_default_dependencies(),
    extras_require=get_optional_dependencies(),
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "Programming Language :: Python :: 3",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development :: Libraries :: Python Modules",
        # NOTE(review): this is not a registered trove classifier (the valid
        # form is "License :: OSI Approved :: BSD License"), and it appears to
        # conflict with the license texts elsewhere in the repo -- confirm.
        "License :: OSI Approved :: BSD-2-Clause Software License",
        "Operating System :: OS Independent",
    ],
)
# Liger FlexChunkLoss: Alignment and Distillation loss
Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
### User interface
FlexChunkLoss offers two flexible usage options:
1. **Via `Liger[Custom Loss]Trainer`**
For example, by simply replacing the HuggingFace `ORPOTrainer` with `LigerORPOTrainer` in your code, you can leverage our optimized ORPO implementation and immediately benefit from improved performance.
2. **Using `nn.Module` Implementations of Custom Loss Functions**
Explore the [LigerORPOTrainer implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/orpo_trainer.py) to see how the modular design integrates custom loss functions seamlessly.
### What's under the hood?
We employ chunking and fused kernel optimizations to enhance performance. By fusing the final linear layer with loss computation and calculating backward gradients during the forward pass, we significantly reduce the need for storing intermediate activations. All operations are implemented in PyTorch, leveraging `torch.compile` to streamline kernel execution without relying on extensive low-level optimizations. Additionally, we minimize `torch.compile` recompilations to reduce overhead and ensure consistent performance gains.
### Extending to custom loss functions
We provide two base classes: `LigerFusedLinearPreferenceBase` for alignment use cases and `LigerFusedLinearDistillationBase` for distillation use cases. These base classes manage chunking, kernel fusions, and Torch compilation.
To implement a custom loss function, you need to create a subclass that defines the custom preference or distillation loss function, capable of processing a given input chunk. The base class will take care of the optimizations, handling most of the heavy lifting for you.
For a working example, refer to the [ORPO loss implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/orpo_loss.py).
\ No newline at end of file
from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss # noqa:F401
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
from typing import Tuple
from typing import Union
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
    """Fused linear + cosine-similarity distillation loss (chunked autograd Function).

    Plugs a cosine-similarity soft loss into ``LigerFusedLinearDistillationBase``,
    which handles chunking, the fused student/teacher output projections, and
    computing gradients during the forward pass.
    """

    @staticmethod
    def distillation_loss_fn(
        student_logits,
        teacher_logits,
        target=None,
        ignore_index=None,
        beta=1.0,
    ):
        """
        Compute the cosine-similarity distillation loss for one chunk.

        Args:
            student_logits (torch.Tensor): Student logits for the chunk; the last
                dim is the vocabulary (the base class pads the student logits to
                the teacher's vocab size when they differ).
            teacher_logits (torch.Tensor): Teacher logits, same shape as
                ``student_logits``.
            target: Unused here; accepted to satisfy the base-class loss interface.
            ignore_index: Unused here; accepted to satisfy the base-class loss interface.
            beta (float): Coefficient applied to ``1 - cosine_similarity``. Default: ``1.0``.

        Returns:
            torch.Tensor: Sum of ``beta * (1 - cosine_similarity)`` over the chunk.
            The base class divides by the number of valid tokens afterwards.
        """
        # NOTE(review): F.cosine_similarity already L2-normalizes its inputs, so
        # the explicit F.normalize calls are redundant (though harmless).
        student_norm = F.normalize(student_logits, p=2, dim=-1)
        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
        loss = beta * (1 - cosine_sim)
        return loss.sum()

    @classmethod
    def forward(
        cls,
        ctx,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
        student_bias: torch.Tensor,
        teacher_bias: torch.Tensor,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
        chunk_size: int = 1024,
        return_soft_hard_loss: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Delegate the chunked fused forward to the distillation base class.

        Returns the scalar combined loss, or ``(loss, soft_loss, hard_loss)``
        when ``return_soft_hard_loss`` is True.
        """
        return super().forward(
            cls=cls,
            ctx=ctx,
            student_input=student_input,
            student_weight=student_weight,
            teacher_input=teacher_input,
            teacher_weight=teacher_weight,
            target=true_labels,
            student_bias=student_bias,
            teacher_bias=teacher_bias,
            chunk_size=chunk_size,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            beta=beta,
            ignore_index=ignore_index,
            temperature=temperature,
            compiled=compiled,
            return_soft_hard_loss=return_soft_hard_loss,
        )

    @staticmethod
    def backward(ctx, grad_output, *args):
        # The base class returns one gradient per tensor arg, in forward-arg
        # order: (student_input, student_weight, teacher_input, teacher_weight,
        # true_labels, student_bias). The remaining forward args are
        # hyperparameters and receive no gradient.
        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
        return (
            *grads,
            None,  # teacher_bias
            None,  # weight_hard_loss
            None,  # weight_soft_loss
            None,  # beta
            None,  # ignore_index
            None,  # temperature
            None,  # compiled
            None,  # chunk_size
            None,  # return_soft_hard_loss
        )
class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
    """nn.Module wrapper around ``LigerFusedLinearCosineSimilarityFunction``.

    Holds the distillation hyperparameters and forwards them, together with the
    student/teacher tensors, to the fused chunked autograd function.
    """

    def __init__(
        self,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
        chunk_size: int = 1024,
        return_soft_hard_loss: bool = False,
    ):
        """
        Args:
            weight_hard_loss (float): Weight for the hard (CE/task) loss.
            weight_soft_loss (float): Weight for the soft (cosine distillation) loss.
            beta (float): Coefficient applied to ``1 - cosine_similarity``.
            ignore_index (int): Target index ignored in the hard loss.
            temperature (float): Temperature used to scale the logits; must be non-zero.
            compiled (bool): Whether to torch.compile the chunk accumulation.
            chunk_size (int): Number of rows processed per chunk.
            return_soft_hard_loss (bool): If True, forward also returns the soft
                and hard losses separately.

        Raises:
            ValueError: If ``temperature`` is zero.
        """
        super().__init__()
        # Fix: validate with an explicit exception instead of `assert`, which
        # is silently stripped when Python runs with -O.
        if temperature == 0:
            raise ValueError("Temperature cannot be 0.")
        self.weight_hard_loss = weight_hard_loss
        self.weight_soft_loss = weight_soft_loss
        self.ignore_index = ignore_index
        self.temperature = temperature
        self.compiled = compiled
        self.beta = beta
        self.chunk_size = chunk_size
        self.return_soft_hard_loss = return_soft_hard_loss

    def forward(
        self,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
        student_bias: torch.Tensor = None,
        teacher_bias: torch.Tensor = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute the fused linear + cosine-similarity distillation loss.

        Returns the scalar loss, or ``(loss, soft_loss, hard_loss)`` when
        ``return_soft_hard_loss`` is enabled.
        """
        return LigerFusedLinearCosineSimilarityFunction.apply(
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            true_labels,
            student_bias,
            teacher_bias,
            self.weight_hard_loss,
            self.weight_soft_loss,
            self.beta,
            self.ignore_index,
            self.temperature,
            self.compiled,
            self.chunk_size,
            self.return_soft_hard_loss,
        )
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
    """Autograd Function implementing the fused linear + CPO preference loss."""

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
        """
        Paper: https://arxiv.org/pdf/2401.08417

        Formula:
        L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]

        Where:
            - π_θ(y|x): Policy (model) probability
            - y_w: Chosen sequence
            - y_l: Rejected sequence
            - σ: Sigmoid function
            - β: Temperature parameter
            - E: Expected value over the dataset D
            - D: Dataset of preferences

        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            full_target (torch.Tensor): Non chunked full target tensor
            beta (float): Weight for the CPO loss
            label_smoothing (float): Label smoothing factor; the loss reduces to
                the formula above as this approaches 0.

        Returns:
            tuple: (loss, chosen_rewards, rejected_rewards)
        """
        # full_target stacks chosen+rejected rows, so half its length is the
        # number of preference pairs.
        num_pairs = full_target.shape[0] // 2
        margin = beta * (chosen_logps - rejected_logps)
        # Label-smoothed binary NLL of preferring chosen over rejected.
        per_pair = -F.logsigmoid(margin) * (1 - label_smoothing) - F.logsigmoid(-margin) * label_smoothing
        loss = per_pair.sum() / num_pairs
        chosen_rewards = beta * chosen_logps
        rejected_rewards = beta * rejected_logps
        return loss, chosen_rewards, rejected_rewards

    @classmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        label_smoothing=0.0,
        compute_nll_loss=True,
        compiled=True,
        average_log_prob=False,
        chunk_size=1,
    ):
        """
        Fused linear layer with CPO loss.

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
            ignore_index (int): Index to ignore in loss computation
            beta (float): Weight for the odds ratio loss
            alpha (float): Weight for the alpha parameter
            label_smoothing (float): Label smoothing factor
            compute_nll_loss (bool): Whether to compute the NLL loss
            compiled (bool): Whether to use torch compile
            average_log_prob (bool): Whether to average the log probability per non-masked token
            chunk_size (int): Size of chunks for processing.

        Returns:
            torch.Tensor: Computed loss
        """
        # All heavy lifting (chunking, fused projection, torch.compile) lives
        # in the preference base class; this just forwards the configuration.
        return super().forward(
            cls=cls,
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            label_smoothing=label_smoothing,
            compute_nll_loss=compute_nll_loss,
            average_log_prob=average_log_prob,
            compiled=compiled,
            chunk_size=chunk_size,
        )

    @staticmethod
    def backward(ctx, *grad_output):
        # Tensor grads for (_input, weight, target, bias); the remaining eight
        # forward args are plain hyperparameters and receive no gradient.
        input_grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        return *input_grads, None, None, None, None, None, None, None, None
class LigerFusedLinearCPOLoss(torch.nn.Module):
    """
    Fused linear layer with CPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        label_smoothing: float = 0.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        average_log_prob: bool = False,
        chunk_size: int = 1,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the odds ratio loss.
            alpha (float): Weight for the alpha parameter.
            label_smoothing (float): Label smoothing factor.
            compute_nll_loss (bool): Whether to compute the NLL loss.
            compiled (bool): Whether to use the torch compiled kernel.
            average_log_prob (bool): Whether to average the log probability per non-masked token.
            chunk_size (int): Size of chunks for processing.
        """
        super().__init__()
        # Plain attribute storage; forward() threads every hyperparameter
        # through to the autograd Function positionally.
        self.beta = beta
        self.alpha = alpha
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.compute_nll_loss = compute_nll_loss
        self.average_log_prob = average_log_prob
        self.compiled = compiled
        self.chunk_size = chunk_size

    def forward(self, lin_weight, _input, target, bias=None):
        """Compute the fused linear + CPO loss.

        Note the (weight, input) argument order of this wrapper; the underlying
        Function takes (input, weight).
        """
        call_args = (
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.label_smoothing,
            self.compute_nll_loss,
            self.compiled,
            self.average_log_prob,
            self.chunk_size,
        )
        return LigerFusedLinearCPOFunction.apply(*call_args)
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
    """Autograd Function implementing the fused linear + DPO preference loss."""

    @staticmethod
    def preference_loss_fn(
        chosen_logps,
        rejected_logps,
        full_target,
        ref_chosen_logps=None,
        ref_rejected_logps=None,
        beta=0.1,
        loss_type="sigmoid",
    ):
        """
        Paper: https://arxiv.org/pdf/2305.18290

        Formula:
        L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]

        Where:
            - π(y|x): Policy (model) probability
            - π_ref(y|x): Reference model probability
            - y_w: Chosen sequence
            - y_l: Rejected sequence
            - β: Weight for the direct preference loss
            - E: Expected value over the dataset

        Args:
            chosen_logps: Log probabilities of chosen tokens (batch_size,)
            rejected_logps: Log probabilities of rejected tokens (batch_size,)
            full_target: Non chunked full target tensor
            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
            beta: Weight for the direct preference loss
            loss_type: Loss variant; one of "sigmoid", "apo_zero", "apo_down",
                "sppo_hard", "nca_pair"

        Returns:
            tuple: (loss, chosen_rewards, rejected_rewards)
        """
        # With no reference model the log-ratios reduce to the policy logps.
        if ref_chosen_logps is None:
            ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
        if ref_rejected_logps is None:
            ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)

        chosen_logratios = chosen_logps - ref_chosen_logps
        rejected_logratios = rejected_logps - ref_rejected_logps

        chosen_rewards = beta * chosen_logratios
        rejected_rewards = beta * rejected_logratios

        if loss_type == "sigmoid":
            # Standard DPO objective: -log σ(β · Δ log-ratio). full_target
            # stacks chosen+rejected rows, hence the division by half its length.
            logits_diff = beta * (chosen_logratios - rejected_logratios)
            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
        elif loss_type == "apo_zero":
            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
            # Use this loss when you believe the chosen outputs are better than your model's default output
            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
            losses_rejected = F.sigmoid(beta * rejected_logratios)
            losses = losses_chosen + losses_rejected
            loss = losses.sum() / (full_target.shape[0] // 2)
        elif loss_type == "apo_down":
            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
            # Use this loss when you believe the chosen outputs are worse than your model's default output.
            # Decrease chosen likelihood and decrease rejected likelihood more
            losses_chosen = F.sigmoid(beta * chosen_logratios)
            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
            losses = losses_chosen + losses_rejected
            loss = losses.sum() / (full_target.shape[0] // 2)
        elif loss_type == "sppo_hard":
            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
            # set to 1 for the winner and 0 for the loser.
            a = chosen_logps - ref_chosen_logps
            b = rejected_logps - ref_rejected_logps
            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
            loss = losses.sum() / (full_target.shape[0] // 2)
        elif loss_type == "nca_pair":
            # Noise-Contrastive Alignment pairwise loss on the rewards.
            losses = (
                -F.logsigmoid(chosen_rewards)
                - 0.5 * F.logsigmoid(-chosen_rewards)
                - 0.5 * F.logsigmoid(-rejected_rewards)
            )
            loss = losses.sum() / (full_target.shape[0] // 2)
        else:
            raise ValueError(
                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
            )
        return loss, chosen_rewards, rejected_rewards

    @classmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=False,
        compiled=True,
        use_ref_model=True,
        average_log_prob=False,
        chunk_size=1,
        loss_type="sigmoid",
    ):
        """
        Fused linear layer with DPO loss.

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
            ignore_index (int): Index to ignore in loss computation
            beta (float): Weight for the odds ratio loss
            compute_nll_loss (bool): Whether to compute the NLL loss
            compiled (bool): Whether to use torch compile
            use_ref_model (bool): Whether to use a reference model
            average_log_prob (bool): Whether to average the log probability per non-masked token
            chunk_size (int): Size of chunks for processing.
            loss_type (str): Loss variant passed through to preference_loss_fn.

        Returns:
            torch.Tensor: Computed loss
        """
        # Chunking, fused projection, and compilation are handled by the
        # preference base class.
        return super().forward(
            cls=cls,
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            ignore_index=ignore_index,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
            use_ref_model=use_ref_model,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            average_log_prob=average_log_prob,
            chunk_size=chunk_size,
            loss_type=loss_type,
        )

    @staticmethod
    def backward(ctx, *grad_output):
        # Tensor grads for (_input, weight, target, bias); the remaining eleven
        # forward args (ref_* tensors and hyperparameters) get no gradient.
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        return *grads, None, None, None, None, None, None, None, None, None, None, None
class LigerFusedLinearDPOLoss(torch.nn.Module):
    """
    Fused linear layer with DPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = False,
        compiled: bool = True,
        use_ref_model: bool = True,
        average_log_prob: bool = False,
        chunk_size: int = 1,
        loss_type: str = "sigmoid",
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the odds ratio loss.
            compute_nll_loss (bool): Whether to compute the NLL loss.
            compiled (bool): Whether to use the torch compiled kernel.
            use_ref_model (bool): Whether to use a reference model for the DPO loss.
            average_log_prob (bool): Whether to average the log probability per non-masked token.
            chunk_size (int): Size of chunks for processing.
            loss_type (str): One of "sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair".

        Raises:
            ValueError: If ``loss_type`` is not a supported variant.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled
        self.use_ref_model = use_ref_model
        self.average_log_prob = average_log_prob
        self.chunk_size = chunk_size
        self.loss_type = loss_type
        # Fail fast on an unknown loss variant rather than erroring mid-training.
        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
        if self.loss_type not in supported_loss_types:
            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")

    def forward(
        self,
        lin_weight,
        _input,
        target,
        bias=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        """Compute the fused linear + DPO loss.

        Note the (weight, input) argument order of this wrapper; the underlying
        Function takes (input, weight).
        """
        fn_args = (
            _input,
            lin_weight,
            target,
            bias,
            ref_input,
            ref_weight,
            ref_bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
            self.use_ref_model,
            self.average_log_prob,
            self.chunk_size,
            self.loss_type,
        )
        return LigerFusedLinearDPOFunction.apply(*fn_args)
from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
# Functional-style aliases: each fused-linear loss is exposed as a plain
# callable via its autograd Function's `.apply`, mirroring the
# torch.nn.functional naming convention.
liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
from abc import abstractmethod
from functools import partial
from typing import Tuple
from typing import Union
import torch
from torch.nn import functional as F
class LigerFusedLinearDistillationBase(torch.autograd.Function):
    """Base autograd Function for fused linear + distillation losses.

    Subclasses supply ``distillation_loss_fn``; this class handles chunking
    the batch, fusing the student/teacher output projections with the loss,
    and computing the student gradients during the forward pass (via
    ``torch.func.grad_and_value``) so full-vocabulary logits for the whole
    batch are never materialized at once.
    """

    @abstractmethod
    def distillation_loss_fn(
        student_logits,
        teacher_logits,
        target=None,
        ignore_index=None,
    ):
        """
        Compute distillation loss.

        Args:
            student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
            teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
            target: Ground-truth labels for the chunk, if the loss needs them.
            ignore_index: Label value to ignore, if the loss needs it.

        Returns:
            torch.Tensor: Sum of distillation losses for the chunk. The class will handle
            converting this to mean loss by dividing by the full batch size * sequence length in _compute_loss.
        """
        raise NotImplementedError("Distillation loss function must be implemented.")

    @staticmethod
    def chunk_forward(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        ignore_index=-100,
        compute_ce_loss=True,
    ):
        """Project one chunk through the student and teacher output heads.

        Returns:
            tuple: ``(student_logits_chunk, teacher_logits_chunk, ce_loss)``
            where ``ce_loss`` is the summed student cross-entropy over the
            chunk (``0.0`` when ``compute_ce_loss`` is False).
        """
        # Student
        student_logits_chunk = student_input_chunk @ student_weight.t()
        if student_bias is not None:
            student_logits_chunk += student_bias
        student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)

        # Teacher — no gradients are ever needed for the teacher head.
        with torch.no_grad():
            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
            if teacher_bias is not None:
                teacher_logits_chunk += teacher_bias

        # The hard/task loss (summed; normalized later by valid-token count).
        ce_loss = 0.0
        if compute_ce_loss:
            ce_loss = F.nll_loss(
                student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
                target_chunk.view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )
        return student_logits_chunk, teacher_logits_chunk, ce_loss

    @staticmethod
    def _compute_loss(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        distillation_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        temperature=1,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using a knowledge distillation loss function.

        Args:
            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
            ignore_index (int): Index to ignore for loss computation.
            weight_hard_loss (float): Weight for hard loss.
            weight_soft_loss (float): Weight for soft loss.
            compute_ce_loss (bool): Whether to compute CE loss.
            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        (
            student_logits_chunk,
            teacher_logits_chunk,
            hard_loss,
        ) = LigerFusedLinearDistillationBase.chunk_forward(
            student_input_chunk,
            student_weight,
            teacher_input_chunk,
            teacher_weight,
            target_chunk,
            student_bias=student_bias,
            teacher_bias=teacher_bias,
            ignore_index=ignore_index,
            compute_ce_loss=compute_ce_loss,
        )

        student_logits_chunk /= temperature
        teacher_logits_chunk /= temperature

        # If the teacher and student token size is different, pad student logits to match the teacher's.
        # This only applies to cases where they share exactly the same vocab and tokenizer just
        # that teacher logit is padded for some training efficiency such as
        # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
        teacher_vocab_size = teacher_weight.shape[0]
        student_vocab_size = student_weight.shape[0]
        if teacher_vocab_size > student_vocab_size:
            pad_size = teacher_vocab_size - student_vocab_size
            pad_tensor = torch.zeros(
                (*student_logits_chunk.shape[:-1], pad_size),
                dtype=student_logits_chunk.dtype,
                device=student_logits_chunk.device,
            )
            student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)

        # Both losses are normalized by the valid-token count of the FULL batch,
        # so per-chunk sums accumulate to the overall mean.
        num_valid_tokens = (full_target != ignore_index).sum()
        num_valid_tokens = num_valid_tokens.clamp_min(1)  # to avoid division by zero
        hard_loss /= num_valid_tokens

        soft_loss = distillation_loss_fn(
            student_logits_chunk, teacher_logits_chunk, target=target_chunk, ignore_index=ignore_index, **loss_kwargs
        )
        soft_loss /= num_valid_tokens

        loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
        return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)

    @staticmethod
    def forward(
        cls,
        ctx,
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        target,
        student_bias=None,
        teacher_bias=None,
        chunk_size=1024,
        ignore_index=-100,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        beta=0.5,
        compute_ce_loss=True,
        temperature=1.0,
        compiled=True,
        return_soft_hard_loss=False,
        **loss_kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Base class for fused linear layer with distillation loss.
        Only need to compute gradients for student model.

        Args:
            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len).
            student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk.
            ignore_index (int): Index to ignore for loss computation.
            weight_hard_loss (float): Weight for hard/task loss.
            weight_soft_loss (float): Weight for soft/distillation loss.
            beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
            compute_ce_loss (bool): Whether to compute CE loss.
            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
            compiled (bool): Whether to use torch compile for chunk accumulation.
            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
            loss_kwargs (dict): Other possible arguments that a loss function might need
        """
        CHUNK_SIZE = chunk_size
        # Accumulators for the student's gradients and losses; gradients are
        # produced chunk-by-chunk during this forward pass.
        grad_weight = torch.zeros_like(student_weight)
        grad_inputs = []
        grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
        loss_acc = torch.zeros((), device=student_input.device)
        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

        # Bind all hyperparameters once; only the chunk tensors vary per call.
        loss_func_to_call = partial(
            LigerFusedLinearDistillationBase._compute_loss,
            distillation_loss_fn=cls.distillation_loss_fn,
            full_target=target,
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            compute_ce_loss=compute_ce_loss,
            temperature=temperature,
            beta=beta,
            **loss_kwargs,
        )

        def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
            # torch.func.grad_and_value computes the chunk's loss and its
            # gradients in one pass; argnums selects (student_input_chunk,
            # student_weight[, student_bias]) as the differentiated args.
            if student_bias is not None:
                (
                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                    (
                        chunk_loss,
                        (
                            chunk_soft_loss,
                            chunk_hard_loss,
                            chunk_student_logits,
                            chunk_teacher_logits,
                        ),
                    ),
                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)(
                    student_input_chunk,
                    student_weight,
                    teacher_input_chunk,
                    teacher_weight,
                    target_chunk,
                    student_bias,
                    teacher_bias,
                )
                grad_bias.add_(chunk_grad_bias)
            else:
                (
                    (chunk_grad_input, chunk_grad_weight),
                    (
                        chunk_loss,
                        (
                            chunk_soft_loss,
                            chunk_hard_loss,
                            chunk_student_logits,
                            chunk_teacher_logits,
                        ),
                    ),
                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)(
                    student_input_chunk,
                    student_weight,
                    teacher_input_chunk,
                    teacher_weight,
                    target_chunk,
                    student_bias,
                    teacher_bias,
                )
            grad_weight.add_(chunk_grad_weight)
            loss_acc.add_(chunk_loss)
            if return_soft_hard_loss:
                soft_loss_acc.add_(chunk_soft_loss)
                hard_loss_acc.add_(chunk_hard_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        # NOTE(review): with num_chunks = max(1, B // CHUNK_SIZE), torch.chunk
        # may yield chunks slightly larger than CHUNK_SIZE when B is not an
        # exact multiple of it.
        num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
        _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
        _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
        _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)

        for student_input_chunk, teacher_input_chunk, target_chunk in zip(
            _student_input_chunks, _teacher_input_chunks, _target_chunks
        ):
            grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk)
            grad_inputs.append(grad_input)

        # Gradients are fully computed already; backward() just returns them.
        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        if return_soft_hard_loss:
            return loss_acc, soft_loss_acc, hard_loss_acc
        return loss_acc

    @staticmethod
    def backward(ctx, grad_output, *args):
        """Return the gradients pre-computed in forward, scaled by grad_output."""
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        # Skip the multiply in the common loss.backward() case (grad_output == 1).
        # NOTE(review): this data-dependent check assumes grad_output is a
        # scalar tensor — confirm for the return_soft_hard_loss path.
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            grad_bias = grad_bias * grad_output if grad_bias is not None else None

        # One entry per tensor forward arg: (student_input, student_weight,
        # teacher_input, teacher_weight, target, student_bias); subclasses
        # append Nones for their extra hyperparameter args.
        return grad_input, grad_weight, None, None, None, grad_bias
from abc import abstractmethod
from functools import partial
import torch
import torch._dynamo.config
import torch.nn.functional as F
class LigerFusedLinearPPOBase(torch.autograd.Function):
@abstractmethod
def ppo_loss_fn(*args, **kwargs):
"""
To be extended by subclasses.
"""
raise NotImplementedError("PPO loss function must be implemented.")
@staticmethod
def forward(
cls,
ctx,
_input,
weight,
selected_token_ids,
attention_mask,
advantages,
bias=None,
ref_per_token_logps=None,
old_per_token_logps=None,
ref_input=None,
ref_weight=None,
ref_bias=None,
epsilon_low=0.2,
epsilon_high=0.2,
beta=0.04,
loss_type="dapo",
max_completion_length=None,
importance_sampling_level="token",
temperature=1.0,
compiled=True,
use_ref_model=False,
chunk_size=1,
sapo_temperature_pos=1.0,
sapo_temperature_neg=1.05,
vllm_is_ratio=None,
delta=None,
use_bias_correction_kl=False,
):
# TODO: check torch compile matmul
"""Chunked forward pass for PPO loss computation.
Args:
cls: The class
ctx: Context for backward
_input: Input tensor
weight: Weight tensor
selected_token_ids: Selected token ids tensor
attention_mask: Attention mask tensor
advantages: Advantages tensor
bias: Bias tensor
ref_per_token_logps: Reference model log probs per token tensor
old_per_token_logps: Old per token log probabilities tensor
ref_input: Reference model input tensor
ref_weight: Reference model weight tensor
ref_bias: Reference model bias tensor
epsilon_low: Lower bound for clipping the importance sampling ratio
epsilon_high: Upper bound for clipping the importance sampling ratio
beta: Weight for the KL penalty
loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo")
max_completion_length: Maximum completion length required for "dr_grpo"
importance_sampling_level: Level of importance sampling ("token" or "sequence")
temperature: Temperature for the logits
compiled: Whether to use torch compile
use_ref_model: Whether to use a reference model
chunk_size: Size of chunks for processing in other loss modules
sapo_temperature_pos: Temperature for positive advantages in SAPO
sapo_temperature_neg: Temperature for negative advantages in SAPO
vllm_is_ratio: vLLM importance sampling ratio tensor (batch_size, seq_len) or (batch_size, 1) or None.
Used to correct for distribution mismatch when using vLLM for generation.
"""
if use_ref_model:
assert ref_per_token_logps is not None or ref_input is not None, (
"If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
)
if ref_per_token_logps is not None and ref_input is not None:
raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
if loss_type == "dr_grpo":
assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
if vllm_is_ratio is not None:
B, T = attention_mask.shape
assert vllm_is_ratio.dim() in (1, 2), (
f"vllm_is_ratio must be 1D (B,) or 2D (B, T) / (B, 1), got {vllm_is_ratio.dim()}D"
)
if vllm_is_ratio.dim() == 2:
assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, T), (
f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {T}), got {tuple(vllm_is_ratio.shape)}"
)
else:
assert vllm_is_ratio.shape[0] == B, (
f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}"
)
vllm_is_ratio = vllm_is_ratio.unsqueeze(-1) # (B,) -> (B, 1) for broadcasting
# Initialize accumulators
loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
grad_weight = torch.zeros_like(weight) # [V, H]
grad_inputs = []
grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
aggregated_metrics = []
# Create a partial function with fixed arguments
compute_loss = partial(
LigerFusedLinearPPOBase._compute_chunk_loss,
ref_weight=ref_weight,
ref_bias=ref_bias,
full_attention_mask=attention_mask,
epsilon_low=epsilon_low,
epsilon_high=epsilon_high,
beta=beta,
loss_type=loss_type,
max_completion_length=max_completion_length,
importance_sampling_level=importance_sampling_level,
temperature=temperature,
use_ref_model=use_ref_model,
ppo_loss_fn=cls.ppo_loss_fn,
sapo_temperature_pos=sapo_temperature_pos,
sapo_temperature_neg=sapo_temperature_neg,
delta=delta,
use_bias_correction_kl=use_bias_correction_kl,
)
def fused_fwd_bwd(
input_chunk,
selected_token_ids_chunk,
attention_mask_chunk,
advantages_chunk,
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
):
"""Fused forward and backward for a chunk."""
argnums = (0, 1, 5) if bias is not None else (0, 1)
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
input_chunk, # arg 0
weight, # arg 1
selected_token_ids_chunk, # arg 2
attention_mask_chunk, # arg 3
advantages_chunk, # arg 4
bias, # arg 5
ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
ref_input_chunk=ref_input_chunk, # arg 8
vllm_is_ratio_chunk=vllm_is_ratio_chunk, # arg 9
)
def accumulate_chunk(
input_chunk,
selected_token_ids_chunk,
attention_mask_chunk,
advantages_chunk,
ref_per_token_logps_chunk=None,
old_per_token_logps_chunk=None,
ref_input_chunk=None,
vllm_is_ratio_chunk=None,
):
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
input_chunk,
selected_token_ids_chunk,
attention_mask_chunk,
advantages_chunk,
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
)
if bias is not None:
grad_bias.add_(chunk_grad_bias[0])
# Accumulate gradients and loss
grad_weight.add_(chunk_grad_weight)
grad_inputs.append(chunk_grad_input)
loss_acc.add_(chunk_loss)
# Initialize storage for metrics on first chunk
if len(aggregated_metrics) == 0:
for metric in chunk_metrics:
if metric.ndim == 0:
aggregated_metrics.append(torch.zeros((), device=metric.device))
else:
aggregated_metrics.append([])
# Accumulate metrics
for i, metric in enumerate(chunk_metrics):
if metric.ndim == 0:
aggregated_metrics[i].add_(metric)
else:
aggregated_metrics[i].append(metric)
if compiled:
# TODO: Figure out what is better to compile here
# accumulate_chunk = torch.compile(accumulate_chunk)
fused_fwd_bwd = torch.compile(fused_fwd_bwd)
# Process input in chunks based on chunk_size
chunks = max(1, _input.shape[0] // chunk_size)
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
_selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
_advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
_ref_per_token_logps_chunks = (
torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
if use_ref_model and ref_per_token_logps is not None
else [None] * chunks
)
_old_per_token_logps_chunks = (
torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
if old_per_token_logps is not None
else [None] * chunks
)
# if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
_ref_input_chunks = (
torch.chunk(ref_input, chunks=chunks, dim=0)
if use_ref_model and ref_per_token_logps is None
else [None] * chunks
)
_vllm_is_ratio_chunks = (
torch.chunk(vllm_is_ratio, chunks=chunks, dim=0) if vllm_is_ratio is not None else [None] * chunks
)
for (
input_chunk,
selected_token_ids_chunk,
attention_mask_chunk,
advantages_chunk,
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
) in zip(
_input_chunks,
_selected_token_ids_chunks,
_attention_mask_chunks,
_advantages_chunks,
_ref_per_token_logps_chunks,
_old_per_token_logps_chunks,
_ref_input_chunks,
_vllm_is_ratio_chunks,
):
# Mark dynamic dimensions
torch._dynamo.mark_dynamic(input_chunk, 1)
torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
if ref_per_token_logps_chunk is not None:
torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
if ref_input_chunk is not None:
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
if old_per_token_logps_chunk is not None:
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
if vllm_is_ratio_chunk is not None:
torch._dynamo.mark_dynamic(vllm_is_ratio_chunk, 1)
accumulate_chunk(
input_chunk,
selected_token_ids_chunk,
attention_mask_chunk,
advantages_chunk,
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
)
# Combine gradients
grad_input = torch.cat(grad_inputs, dim=0)
# Save for backward
ctx.save_for_backward(grad_input, grad_weight, grad_bias)
# Finalize metrics
final_metrics = []
for metric in aggregated_metrics:
if isinstance(metric, list):
final_metrics.append(torch.cat(metric, dim=0))
else:
final_metrics.append(metric)
return loss_acc, tuple(final_metrics)
@staticmethod
def _compute_dapo_normalizer(attention_mask):
"""Global active tokens averaged per process."""
normalizer = attention_mask.to(torch.float32).sum()
world_size = 1
if torch.distributed.is_available() and torch.distributed.is_initialized():
import torch.distributed as dist
normalizer = normalizer.clone()
dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
world_size = dist.get_world_size()
normalizer = normalizer / world_size
return torch.clamp(normalizer, min=1.0)
@staticmethod
def _compute_chunk_loss(
    input_chunk,
    weight,
    selected_token_ids_chunk,
    attention_mask_chunk,
    advantages_chunk,
    bias=None,
    ref_per_token_logps_chunk=None,
    old_per_token_logps_chunk=None,
    ref_input_chunk=None,
    vllm_is_ratio_chunk=None,
    ref_weight=None,
    ref_bias=None,
    full_attention_mask=None,
    epsilon_low=0.2,
    epsilon_high=0.2,
    beta=0.04,
    loss_type="dapo",
    max_completion_length=None,
    importance_sampling_level="token",
    temperature=1.0,
    use_ref_model=False,
    ppo_loss_fn=None,
    sapo_temperature_pos=1.0,
    sapo_temperature_neg=1.05,
    delta=None,
    use_bias_correction_kl=False,
):
    """Compute the PPO loss and metrics for a single chunk.

    Projects the chunk through the policy LM head via ``chunk_forward``,
    optionally computes reference-model log-probs on the fly (only when a
    precomputed ``ref_per_token_logps_chunk`` is not supplied), and delegates
    the actual loss math to the subclass-provided ``ppo_loss_fn``.

    Returns:
        Tuple of ``(chunk_loss, chunk_metrics)`` exactly as produced by
        ``ppo_loss_fn``.
    """
    # Get policy log probabilities using chunk_forward
    log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
    # Get reference log probabilities if needed
    ref_log_probs = None
    if use_ref_model and ref_per_token_logps_chunk is None:
        # The reference model is frozen: keep its projection out of the
        # gradient trace built by torch.func.grad_and_value.
        with torch.no_grad():
            ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
                ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
            )
    # Compute chunk loss and metrics using the provided loss function.
    # Precomputed per-token logps are upcast with .float() to match the
    # float32 log_probs produced by chunk_forward.
    chunk_loss, chunk_metrics = ppo_loss_fn(
        log_probs=log_probs,
        selected_token_ids=selected_token_ids_chunk,
        attention_mask=attention_mask_chunk,
        advantages=advantages_chunk,
        full_attention_mask=full_attention_mask,
        ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
        old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
        ref_log_probs=ref_log_probs,  # used when ref_per_token_logps is None
        epsilon_low=epsilon_low,
        epsilon_high=epsilon_high,
        beta=beta,
        loss_type=loss_type,
        max_completion_length=max_completion_length,
        importance_sampling_level=importance_sampling_level,
        sapo_temperature_pos=sapo_temperature_pos,
        sapo_temperature_neg=sapo_temperature_neg,
        vllm_is_ratio=vllm_is_ratio_chunk,
        delta=delta,
        use_bias_correction_kl=use_bias_correction_kl,
    )
    return chunk_loss, chunk_metrics
@staticmethod
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
"""Forward pass computation for a single chunk without explicit reshaping."""
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
logits = torch.matmul(input_chunk, weight.t())
if bias is not None:
logits = logits + bias # Broadcasts bias to [B, T, V]
if temperature != 1.0:
logits = logits / temperature
# Compute log probabilities using softmax over the last dimension
log_probs = F.log_softmax(logits.float(), dim=-1)
return log_probs, logits
@staticmethod
def backward(ctx, grad_output, *grad_metrics):
"""Backward pass for PPO loss."""
grad_input, grad_weight, grad_bias = ctx.saved_tensors
if grad_output != 1.0:
grad_input = grad_input * grad_output
grad_weight = grad_weight * grad_output
if grad_bias is not None:
grad_bias = grad_bias * grad_output
return (
grad_input,
grad_weight,
None, # grad_selected_token_ids
None, # grad_attention_mask
None, # grad_advantages
grad_bias,
None, # grad_ref_per_token_logps
None, # grad_old_per_token_logps
None, # grad_ref_input
None, # grad_ref_weight
None, # grad_ref_bias
None, # grad_epsilon_low
None, # grad_epsilon_high
None, # grad_beta
None, # grad_loss_type
None, # grad_max_completion_length
None, # grad_importance_sampling_level
None, # grad_temperature
None, # grad_compiled
None, # grad_use_ref_model
None, # grad_chunk_size
None, # grad_sapo_temperature_pos
None, # grad_sapo_temperature_neg
None, # grad_vllm_is_ratio
None, # grad_delta
None, # grad_use_bias_correction_kl
)
from abc import abstractmethod
from functools import partial
import torch
from torch.nn import functional as F
class LigerFusedLinearPreferenceBase(torch.autograd.Function):
    """Chunked fused linear + paired preference loss base (DPO-family).

    Subclasses supply ``preference_loss_fn``; this base class runs the LM-head
    projection and the loss chunk-by-chunk, accumulating gradients manually so
    the full (batch, seq_len, vocab) logits tensor is never materialized at
    once. ``forward`` takes an explicit ``cls`` first argument — presumably
    bound by a wrapper in the concrete subclass; confirm against callers.
    """

    @abstractmethod
    def preference_loss_fn(*args, **kwargs):
        """
        To be extended by subclasses.
        """
        raise NotImplementedError("Preference loss function must be implemented.")

    @staticmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        bias=None,
        chunk_size=1,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        nll_target=None,
        compiled=True,
        use_ref_model=False,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        average_log_prob=True,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with preference loss.
        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

        The mental model is:
        forward()
        ├── Loop over chunks
            └── compute_loss()
                ├── chunk_forward()  # Compute logits and log probs
                └── prefer_loss()    # Calculate preference loss

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the preference loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
            loss_kwargs (dict): Other possible arguments that a loss function might need
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size

        # Gradients to be accumulated
        grad_weight = torch.zeros_like(weight)
        grad_chosen_inputs = []
        grad_rejected_inputs = []
        grad_bias = torch.zeros_like(bias) if bias is not None else None

        # Loss to be accumulated
        loss_acc = torch.zeros((), device=_input.device)

        # Metrics to be recorded
        policy_chosen_logps = []
        policy_rejected_logps = []
        policy_chosen_logits_mean = torch.zeros((), device=_input.device)
        policy_rejected_logits_mean = torch.zeros((), device=_input.device)
        policy_nll_loss = torch.zeros((), device=_input.device)
        aggregated_aux_outputs = []  # aggregated aux outputs from all chunks

        # Bind every chunk-invariant argument once so the per-chunk call
        # only varies in the chunk tensors.
        compute_loss = partial(
            LigerFusedLinearPreferenceBase._compute_loss,
            preference_loss_fn=cls.preference_loss_fn,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            full_target=target,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            full_nll_target=nll_target,
            average_log_prob=average_log_prob,
            **loss_kwargs,
        )

        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
            """
            Fused forward and backward pass for a chunk of input and target.
            """
            if bias is not None:
                # argnums=(0, 1, 3): differentiate w.r.t. input_chunk, weight, bias.
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)(
                    input_chunk,
                    weight,
                    target_chunk,
                    bias,
                    ref_input_chunk=ref_input_chunk,
                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                )
            else:
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
                    input_chunk,
                    weight,
                    target_chunk,
                    ref_input_chunk=ref_input_chunk,
                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                )

        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
            # Runs one fused fwd/bwd and folds the results into the
            # accumulators defined in the enclosing scope.
            if bias is not None:
                (
                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                    (
                        chunk_loss,
                        (
                            chunk_chosen_logps,
                            chunk_rejected_logps,
                            chunk_chosen_logits_mean,
                            chunk_rejected_logits_mean,
                            chunk_nll_loss,
                            *aux_outputs,
                        ),
                    ),
                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
                grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
            else:
                (
                    (chunk_grad_input, chunk_grad_weight),
                    (
                        chunk_loss,
                        (
                            chunk_chosen_logps,
                            chunk_rejected_logps,
                            chunk_chosen_logits_mean,
                            chunk_rejected_logits_mean,
                            chunk_nll_loss,
                            *aux_outputs,
                        ),
                    ),
                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

            # Accumulate gradients
            grad_weight.add_(chunk_grad_weight)
            # NOTE(review): chosen_target_chunk is not a parameter — it is read
            # from the enclosing for-loop via closure and gives the size of the
            # chosen half of input_chunk (= cat([chosen, rejected])). This only
            # works because accumulate_chunk is called from inside that loop.
            grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
            grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :])

            # Accumulate loss
            loss_acc.add_(chunk_loss)

            # Accumulate metrics
            policy_chosen_logps.append(chunk_chosen_logps)
            policy_rejected_logps.append(chunk_rejected_logps)
            policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
            policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
            policy_nll_loss.add_(chunk_nll_loss)

            # aux_outputs
            # Initialize storage for aux_outputs: scalars are summed across
            # chunks, per-example tensors are collected and concatenated later.
            if len(aggregated_aux_outputs) == 0:
                for aux in aux_outputs:
                    if aux.ndim == 0:
                        aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
                    else:
                        aggregated_aux_outputs.append([])

            # Process each aux_output
            for i, aux in enumerate(aux_outputs):
                if aux.ndim == 0:
                    aggregated_aux_outputs[i].add_(aux)
                else:
                    aggregated_aux_outputs[i].append(aux)

        if compiled:
            fused_fwd_bwd = torch.compile(fused_fwd_bwd)

        # _input/target stack chosen examples first, rejected second; chunk the
        # two halves in lockstep so each chunk keeps the same layout.
        len_chosen = target.shape[0] // 2
        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
        if nll_target is not None:
            _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
        if use_ref_model:
            _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
            _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
        for (
            chosen_input_chunk,
            rejected_input_chunk,
            chosen_target_chunk,
            rejected_target_chunk,
            ref_chosen_input_chunk,
            ref_rejected_input_chunk,
            chosen_nll_target_chunk,
        ) in zip(
            _chosen_input_chunks,
            _rejected_input_chunks,
            _chosen_target_chunks,
            _rejected_target_chunks,
            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
        ):
            # Re-stack so the chunk preserves the chosen-first layout expected
            # by chunk_forward/_compute_loss.
            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
            ref_input_chunk = (
                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None
            )
            target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)

            # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
            torch._dynamo.mark_dynamic(input_chunk, 1)
            torch._dynamo.mark_dynamic(target_chunk, 1)
            torch._dynamo.mark_dynamic(target, 1)
            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
            torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None

            # accumulate loss, gradients, and metrics
            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

        # combine grad_chosen_inputs and grad_rejected_inputs (list concat:
        # all chosen-half grads first, then all rejected-half grads, matching
        # the original stacked layout of _input once concatenated below)
        grad_inputs = grad_chosen_inputs + grad_rejected_inputs
        policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0)
        policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0)

        # Aggregate aux outputs lists into tensors
        for i, aux in enumerate(aggregated_aux_outputs):
            if isinstance(aux, list):
                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return_vars = (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits_mean,
            policy_rejected_logits_mean,
            policy_nll_loss,
        )
        return loss_acc, (*return_vars, *aggregated_aux_outputs)

    @staticmethod
    def backward(ctx, *grad_output):
        """Rescale the cached gradients by the incoming loss gradient.

        Only the loss output can carry a meaningful gradient; the metric
        outputs are non-differentiable and their grads are ignored.
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
            grad_input = grad_input * grad_output[0][0]
            grad_weight = grad_weight * grad_output[0][0]
            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
        # One entry per leading forward argument: _input, weight, target,
        # bias, then Nones for the trailing non-differentiable arguments.
        return grad_input, grad_weight, None, grad_bias, None, None, None, None

    @staticmethod
    def chunk_forward(
        input_chunk,
        weight,
        target_chunk,
        bias=None,
        ignore_index=-100,
        compute_nll_loss=True,
        chosen_nll_target_chunk=None,
        average_log_prob=True,
    ):
        """Project one stacked (chosen-first) chunk and gather per-sequence log-probs."""
        # First half of the chunk is chosen, second half is rejected.
        len_chosen_chunk = target_chunk.shape[0] // 2
        logits_chunk = input_chunk @ weight.t()
        if bias is not None:
            logits_chunk = logits_chunk + bias
        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

        chosen_nll_loss = 0.0
        if compute_nll_loss:
            # NLL is computed on the chosen half only, optionally against a
            # dedicated NLL target.
            nll_labels = (
                chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
            )
            chosen_nll_loss = F.nll_loss(
                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
                nll_labels.view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        loss_mask = target_chunk != ignore_index
        # Redirect ignored positions to label 0 so gather stays in-bounds;
        # they are masked out of the sums below.
        label_chunk = torch.where(loss_mask, target_chunk, 0)
        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
        if average_log_prob:
            log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            log_prob = (per_token_logps * loss_mask).sum(-1)

        chosen_logps = log_prob[:len_chosen_chunk]
        rejected_logps = log_prob[len_chosen_chunk:]
        chosen_logits = logits_chunk[:len_chosen_chunk]
        rejected_logits = logits_chunk[len_chosen_chunk:]
        return (
            chosen_logps,
            rejected_logps,
            chosen_logits,
            rejected_logits,
            chosen_nll_loss,
        )

    @staticmethod
    def _compute_loss(
        input_chunk,
        weight,
        target_chunk,
        bias=None,
        preference_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        use_ref_model=False,
        ref_input_chunk=None,
        ref_weight=None,
        ref_bias=None,
        full_nll_target=None,
        chosen_nll_target_chunk=None,
        average_log_prob=True,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.

        Args:
            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the preference loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
            chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used.
            average_log_prob (bool): Whether to average log probabilities or the sum.
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        (
            chosen_logps,
            rejected_logps,
            chosen_logits,
            rejected_logits,
            chosen_nll_loss,
        ) = LigerFusedLinearPreferenceBase.chunk_forward(
            input_chunk,
            weight,
            target_chunk,
            bias=bias,
            ignore_index=ignore_index,
            compute_nll_loss=compute_nll_loss,
            chosen_nll_target_chunk=chosen_nll_target_chunk,
            average_log_prob=average_log_prob,
        )
        # Normalize the chunk's summed NLL by the FULL batch's active-token
        # count so the per-chunk values add up to a batch-level mean.
        if full_nll_target is not None:
            chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
        else:
            chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
        # Logit means are likewise normalized against full-batch element counts.
        chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
        rejected_logits_mean = rejected_logits.sum() / (
            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
        )
        if use_ref_model:
            # Reference model is frozen: no gradient tracking needed.
            with torch.no_grad():
                (
                    ref_chosen_logps,
                    ref_rejected_logps,
                    _,
                    _,
                    _,
                ) = LigerFusedLinearPreferenceBase.chunk_forward(
                    ref_input_chunk,
                    ref_weight,
                    target_chunk,
                    ref_bias,
                    ignore_index=ignore_index,
                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
                    chosen_nll_target_chunk=None,
                    average_log_prob=average_log_prob,
                )
            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
        preference_loss_outputs = preference_loss_fn(
            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
        )
        # Loss functions may return a bare loss or (loss, *aux_outputs).
        if isinstance(preference_loss_outputs, tuple):
            preference_loss, *aux_outputs = preference_loss_outputs
        else:
            preference_loss, aux_outputs = preference_loss_outputs, []
        loss = alpha * chosen_nll_loss + preference_loss
        return_vars = (
            chosen_logps,
            rejected_logps,
            chosen_logits_mean,
            rejected_logits_mean,
            chosen_nll_loss,
        )
        return loss, (*return_vars, *aux_outputs)
from abc import abstractmethod
from functools import partial
import torch
from torch.nn import functional as F
class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
    """Chunked fused linear + unpaired preference loss base (e.g. KTO).

    Unlike the paired base class, chosen vs. rejected examples are not stacked
    half/half on the batch dimension; a boolean ``preference_labels`` tensor
    marks each row instead. ``forward`` takes an explicit ``cls`` first
    argument — presumably bound by a wrapper in the concrete subclass.
    """

    @abstractmethod
    def preference_loss_fn(*args, **kwargs):
        """
        To be extended by subclasses.
        """
        raise NotImplementedError("Preference loss function must be implemented.")

    @staticmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        preference_labels,
        bias=None,
        chunk_size=1,
        ignore_index=-100,
        compiled=True,
        use_ref_model=False,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        average_log_prob=False,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with unpaired preference loss like KTO
        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

        The mental model is:
        forward()
        ├── Loop over chunks
            └── compute_loss()
                ├── chunk_forward()  # Compute logits and log probs
                └── prefer_loss()    # Calculate preference loss

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
            ignore_index (int): Index to ignore for loss computation.
            beta (float): Weight for the preference loss.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples.
                Shape: (batch_size,).
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            average_log_prob (bool): Whether to average the log probability per non-masked token.
            loss_kwargs (dict): Other possible arguments that a loss function might need
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size

        # Gradients to be accumulated
        grad_inputs = []
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias) if bias is not None else None

        # Loss to be accumulated
        loss_acc = torch.zeros((), device=_input.device)

        # Metrics to be recorded
        chosen_logps_sum = torch.zeros((), device=_input.device)
        rejected_logps_sum = torch.zeros((), device=_input.device)
        chosen_logits_sum = torch.zeros((), device=_input.device)
        rejected_logits_sum = torch.zeros((), device=_input.device)
        aggregated_aux_outputs = []

        # Bind every chunk-invariant argument once so the per-chunk call
        # only varies in the chunk tensors.
        compute_loss = partial(
            LigerFusedLinearUnpairedPreferenceBase._compute_loss,
            preference_loss_fn=cls.preference_loss_fn,
            full_target=target,
            ignore_index=ignore_index,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            average_log_prob=average_log_prob,
            **loss_kwargs,
        )

        def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk):
            """
            Fused forward and backward pass for a chunk of input and target.
            """
            # argnums: 0 -> input_chunk, 1 -> weight, 4 -> bias (when present).
            argnums = (0, 1, 4) if bias is not None else (0, 1)
            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
                input_chunk,
                weight,
                target_chunk,
                preference_labels_chunk,
                bias,
                ref_input_chunk=ref_input_chunk,
            )

        def accumulate_chunk(
            input_chunk,
            target_chunk,
            preference_labels_chunk=None,
            ref_input_chunk=None,
        ):
            # Runs one fused fwd/bwd and folds the results into the
            # accumulators defined in the enclosing scope.
            (
                (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias),
                (
                    chunk_loss,
                    (
                        chunk_chosen_logps_sum,
                        chunk_rejected_logps_sum,
                        chunk_chosen_logits_sum,
                        chunk_rejected_logits_sum,
                        *aux_outputs,
                    ),
                ),
            ) = fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
            if bias is not None:
                grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient

            # Accumulate gradients
            grad_weight.add_(chunk_grad_weight)
            grad_inputs.append(chunk_grad_input)

            # Accumulate loss
            loss_acc.add_(chunk_loss)

            # Accumulate metrics
            chosen_logps_sum.add_(chunk_chosen_logps_sum)
            rejected_logps_sum.add_(chunk_rejected_logps_sum)
            chosen_logits_sum.add_(chunk_chosen_logits_sum)
            rejected_logits_sum.add_(chunk_rejected_logits_sum)

            # aux_outputs
            # Initialize storage for aux_outputs
            # NOTE(review): only scalar accumulators are created here, and the
            # loop below only handles aux.ndim == 0 — non-scalar aux outputs
            # would be silently dropped; confirm subclasses only emit scalars.
            if len(aggregated_aux_outputs) == 0:
                for aux in aux_outputs:
                    aggregated_aux_outputs.append(torch.zeros((), device=aux.device))

            # Process each aux_output
            for i, aux in enumerate(aux_outputs):
                if aux.ndim == 0:
                    aggregated_aux_outputs[i].add_(aux)

        if compiled:
            fused_fwd_bwd = torch.compile(fused_fwd_bwd)

        # When not paired, use labels to separate chosen and rejected
        assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss"

        chunks = max(1, _input.shape[0] // CHUNK_SIZE)
        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
        _target_chunks = torch.chunk(target, chunks=chunks, dim=0)
        _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0)
        if use_ref_model:
            _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0)
        for (
            input_chunk,
            target_chunk,
            ref_input_chunk,
            preference_labels_chunk,
        ) in zip(
            _input_chunks,
            _target_chunks,
            (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)),
            _preference_labels_chunks,
        ):
            # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation
            torch._dynamo.mark_dynamic(input_chunk, 1)
            torch._dynamo.mark_dynamic(target_chunk, 1)
            torch._dynamo.mark_dynamic(target, 1)
            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
            # NOTE(review): preference_labels is documented as (batch_size,);
            # marking dim 1 dynamic looks inconsistent with a 1D tensor —
            # confirm the actual label shape against callers.
            torch._dynamo.mark_dynamic(preference_labels_chunk, 1)

            # accumulate loss, gradients, and metrics
            accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)

        # Aggregate aux outputs lists into tensors
        for i, aux in enumerate(aggregated_aux_outputs):
            if isinstance(aux, list):
                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return_vars = (
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        )
        return loss_acc, (*return_vars, *aggregated_aux_outputs)

    @staticmethod
    def backward(ctx, *grad_output):
        """Rescale the cached gradients by the incoming loss gradient.

        Metric outputs are non-differentiable; their grads are ignored.
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
            grad_input = grad_input * grad_output[0][0]
            grad_weight = grad_weight * grad_output[0][0]
            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
        # Entries correspond to the leading forward arguments:
        # _input, weight, target, preference_labels, bias.
        return grad_input, grad_weight, None, None, grad_bias

    @staticmethod
    def chunk_forward(
        input_chunk,
        weight,
        target_chunk,
        preference_labels_chunk,
        bias=None,
        ignore_index=-100,
        average_log_prob=False,
    ):
        """Project one chunk and split per-sequence log-prob/logit sums by label."""
        logits_chunk = input_chunk @ weight.t()
        if bias is not None:
            logits_chunk = logits_chunk + bias
        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

        # Gather per-token log-probs of the target tokens; ignored positions
        # are redirected to index 0 and masked out of the sums.
        loss_mask_chunk = target_chunk != ignore_index
        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
        per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
        if average_log_prob:
            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
        else:
            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1)

        # Split the summed statistics into chosen vs. rejected via the labels.
        # NOTE(review): log_probs is reduced over the sequence dim above, while
        # unsqueeze(1) shapes the labels as (B, 1) — for B > 1 this broadcast
        # looks suspect; verify the intended chunk size / label shape.
        chosen_logps_sum = (log_probs * preference_labels_chunk.unsqueeze(1)).sum()
        rejected_logps_sum = (log_probs * (~preference_labels_chunk).unsqueeze(1)).sum()
        chosen_logits_sum = (logits_chunk * preference_labels_chunk.unsqueeze(1)).sum()
        rejected_logits_sum = (logits_chunk * (~preference_labels_chunk).unsqueeze(1)).sum()
        return (
            log_probs,
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        )

    @staticmethod
    def _compute_loss(
        input_chunk,
        weight,
        target_chunk,
        preference_labels_chunk,
        bias=None,
        preference_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        use_ref_model=False,
        ref_input_chunk=None,
        ref_weight=None,
        ref_bias=None,
        average_log_prob=False,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.

        Args:
            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
            ignore_index (int): Index to ignore for loss computation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            average_log_prob (bool): Whether to average the log probability per non-masked token.
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        (
            log_prob_chunk,
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
            input_chunk,
            weight,
            target_chunk,
            preference_labels_chunk,
            bias=bias,
            ignore_index=ignore_index,
            average_log_prob=average_log_prob,
        )
        if use_ref_model:
            # Reference model is frozen: no gradient tracking needed.
            with torch.no_grad():
                (
                    ref_log_prob_chunk,
                    _,
                    _,
                    _,
                    _,
                ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
                    ref_input_chunk,
                    ref_weight,
                    target_chunk,
                    preference_labels_chunk,
                    ref_bias,
                    ignore_index=ignore_index,
                    average_log_prob=average_log_prob,
                )
            loss_kwargs["ref_log_prob_chunk"] = ref_log_prob_chunk
        preference_loss_outputs = preference_loss_fn(
            log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
        )
        # Loss functions may return a bare loss or (loss, *aux_outputs).
        if isinstance(preference_loss_outputs, tuple):
            preference_loss_chunk, *aux_outputs = preference_loss_outputs
        else:
            preference_loss_chunk, aux_outputs = preference_loss_outputs, []
        return_vars = (
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        )
        return preference_loss_chunk, (*return_vars, *aux_outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment