Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
<div align="center">
<a href="https://unsloth.ai"><picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png">
<img alt="unsloth logo" src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png" height="110" style="max-width: 100%;">
</picture></a>
<a href="https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/start free finetune button.png" height="48"></a>
<a href="https://discord.gg/unsloth"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/Discord button.png" height="48"></a>
<a href="https://ko-fi.com/unsloth"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/buy me a coffee button.png" height="48"></a>
### Finetune Llama 3.1, Mistral, Phi-3 & Gemma 2-5x faster with 80% less memory!
![](https://i.ibb.co/sJ7RhGG/image-41.png)
</div>
## ✨ Finetune for Free
All notebooks are **beginner friendly**! Add your dataset, click "Run All", and you'll get a 2x faster finetuned model which can be exported to GGUF, Ollama, vLLM or uploaded to Hugging Face.
| Unsloth supports | Free Notebooks | Performance | Memory use |
|-----------|---------|--------|----------|
| **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) | 2x faster | 60% less |
| **Mistral Nemo (12B)** | [▶️ Start for free](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing) | 2x faster | 60% less |
| **Gemma 2 (9B)** | [▶️ Start for free](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) | 2x faster | 63% less |
| **Phi-3 (mini)** | [▶️ Start for free](https://colab.research.google.com/drive/1lN6hPQveB_mHSnTOYifygFcrO8C1bxq4?usp=sharing) | 2x faster | 50% less |
| **Ollama** | [▶️ Start for free](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing) | 1.9x faster | 43% less |
| **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/drive/1_yNCks4BTD5zOnjozppphh5GzMFaMKq_?usp=sharing) | 2.2x faster | 73% less |
| **ORPO** | [▶️ Start for free](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) | 1.9x faster | 43% less |
| **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) | 1.9x faster | 43% less |
| **TinyLlama** | [▶️ Start for free](https://colab.research.google.com/drive/1AZghoNBQaMDgWJpi4RbffGM1h6raLUj9?usp=sharing) | 3.9x faster | 74% less |
- **Kaggle Notebooks** for [Llama 3.1 (8B)](https://www.kaggle.com/danielhanchen/kaggle-llama-3-1-8b-unsloth-notebook), [Gemma 2 (9B)](https://www.kaggle.com/code/danielhanchen/kaggle-gemma-7b-unsloth-notebook/), [Mistral (7B)](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook)
- Run [Llama 3 conversational notebook](https://colab.research.google.com/drive/1XamvWYinY6FOSX9GLvnqSjjsNflxdhNc?usp=sharing) and [Mistral v0.3 ChatML](https://colab.research.google.com/drive/15F1xyn8497_dUbxZP4zWmPZ3PJx1Oymv?usp=sharing)
- This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining / raw text
- This [continued pretraining notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) is for learning another language
- Click [here](https://github.com/unslothai/unsloth/wiki) for detailed documentation for Unsloth.
## 🦥 Unsloth.ai News
- 📣 NEW! [Llama 3.1 8b, 70b](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) both Base and Instruct now supported
- 📣 NEW! [Mistral Nemo-12b](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing) both Base and Instruct now supported
- 📣 NEW! [Gemma-2-9b](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) and Gemma-2-27b now supported
- 📣 UPDATE! [Phi-3 mini](https://colab.research.google.com/drive/1hhdhBa1j_hsymiW9m-WzxQtgqTH_NHqi?usp=sharing) model updated. [Phi-3 Medium](https://colab.research.google.com/drive/1hhdhBa1j_hsymiW9m-WzxQtgqTH_NHqi?usp=sharing) now finetunes 2x faster.
- 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) for other languages like Korean!
- 📣 NEW! Qwen2 now works
- 📣 [Mistral v0.3 Base](https://colab.research.google.com/drive/1_yNCks4BTD5zOnjozppphh5GzMFaMKq_?usp=sharing) and [Mistral v0.3 Instruct]
- 📣 [ORPO support](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) is here + [2x faster inference](https://colab.research.google.com/drive/1aqlNQi7MMJbynFDyOQteD2t0yVfjb9Zh?usp=sharing) added for all our models
- 📣 We cut memory usage by a [further 30%](https://unsloth.ai/blog/long-context) and now support [4x longer context windows](https://unsloth.ai/blog/long-context)!
## 🔗 Links and Resources
| Type | Links |
| ------------------------------- | --------------------------------------- |
| 📚 **Documentation & Wiki** | [Read Our Wiki](https://github.com/unslothai/unsloth/wiki) |
| <img height="14" src="https://upload.wikimedia.org/wikipedia/commons/6/6f/Logo_of_Twitter.svg" />&nbsp; **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai)|
| 💾 **Installation** | [unsloth/README.md](https://github.com/unslothai/unsloth/tree/main#installation-instructions)|
| 🥇 **Benchmarking** | [Performance Tables](https://github.com/unslothai/unsloth/tree/main#-performance-benchmarking)|
| 🌐 **Released Models** | [Unsloth Releases](https://huggingface.co/unsloth)|
| ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog)|
## ⭐ Key Features
- All kernels written in [OpenAI's Triton](https://openai.com/research/triton) language. **Manual backprop engine**.
- **0% loss in accuracy** - no approximation methods - all exact.
- No change of hardware. Supports NVIDIA GPUs since 2018+. Minimum CUDA Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc.) [Check your GPU!](https://developer.nvidia.com/cuda-gpus) GTX 1070 and 1080 work, but are slow. See the snippet after this list for a quick local check.
- Works on **Linux** and **Windows** via WSL.
- Supports 4bit and 16bit QLoRA / LoRA finetuning via [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
- Open source trains 5x faster - see [Unsloth Pro](https://unsloth.ai/) for up to **30x faster training**!
- If you trained a model with 🦥Unsloth, you can use this cool sticker! &nbsp; <img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/made with unsloth.png" height="50" align="center" />
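The compute-capability requirement above can be checked locally. Below is a minimal sketch using PyTorch's standard device queries (not an official Unsloth utility):

```python
# Quick check that the GPU meets Unsloth's CUDA Capability 7.0 minimum.
import torch

major, minor = torch.cuda.get_device_capability()
print(torch.cuda.get_device_name(), f"- compute capability {major}.{minor}")
assert (major, minor) >= (7, 0), "Unsloth needs CUDA Capability 7.0+ (V100 / T4 / RTX 20xx or newer)"
```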
## 🥇 Performance Benchmarking
- For the full list of **reproducible** benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)
| 1 A100 40GB | 🤗Hugging Face | Flash Attention | 🦥Unsloth Open Source | 🦥[Unsloth Pro](https://unsloth.ai/pricing) |
|--------------|--------------|-----------------|---------------------|-----------------|
| Alpaca | 1x | 1.04x | 1.98x | **15.64x** |
| LAION Chip2 | 1x | 0.92x | 1.61x | **20.73x** |
| OASST | 1x | 1.19x | 2.17x | **14.83x** |
| Slim Orca | 1x | 1.18x | 2.22x | **14.82x** |
- The benchmarks in the table below were run by [🤗Hugging Face](https://huggingface.co/blog/unsloth-trl).
| Free Colab T4 | Dataset | 🤗Hugging Face | Pytorch 2.1.1 | 🦥Unsloth | 🦥 VRAM reduction |
| --- | --- | --- | --- | --- | --- |
| Llama-2 7b | OASST | 1x | 1.19x | 1.95x | -43.3% |
| Mistral 7b | Alpaca | 1x | 1.07x | 1.56x | -13.7% |
| Tiny Llama 1.1b | Alpaca | 1x | 2.06x | 3.87x | -73.8% |
| DPO with Zephyr | Ultra Chat | 1x | 1.09x | 1.55x | -18.6% |
![](https://i.ibb.co/sJ7RhGG/image-41.png)
## 💾 Installation Instructions
### Conda Installation
Select either `pytorch-cuda=11.8` for CUDA 11.8 or `pytorch-cuda=12.1` for CUDA 12.1. If you have `mamba`, use `mamba` instead of `conda` for faster solving. See this [Github issue](https://github.com/unslothai/unsloth/issues/73) for help on debugging Conda installs.
```bash
conda create --name unsloth_env \
python=3.10 \
pytorch-cuda=<11.8/12.1> \
pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers \
-y
conda activate unsloth_env
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps "trl<0.9.0" peft accelerate bitsandbytes
```
### Pip Installation
Do **NOT** use this if you have Anaconda. You must use the Conda install method, or else stuff will BREAK.
For DCU (dtk24.04.1, Pytorch 2.1.0):
```bash
cd unsloth
pip install .
```
Then apply the following modifications to the Unsloth sources:
```
# modified unsloth
vim unsloth/kernels/cross_entropy_loss.py:
MAX_FUSED_SIZE = 65536 -> MAX_FUSED_SIZE = 16384
num_warps = 32 -> num_warps = 8  # under _chunked_cross_entropy_forward[(n_rows, n_chunks,)] in the Fast_CrossEntropyLoss class
vim unsloth/kernels/utils.py
if BLOCK_SIZE >= 32768: num_warps = 32 -> if BLOCK_SIZE >= 32768: num_warps = 8
elif BLOCK_SIZE >= 8192: num_warps = 16 -> elif BLOCK_SIZE >= 8192: num_warps = 8
# under the function calculate_settings
vim unsloth/models/_utils.py
model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2",] -> model_architectures = ["llama", "mistral", "qwen2",]
vim unsloth/models/llama.py
Q = Q.transpose(1, 2) -> Q = Q.transpose(1, 2).half()
K = K.transpose(1, 2) -> K = K.transpose(1, 2).half()
V = V.transpose(1, 2) -> V = V.transpose(1, 2).half()
# under the `elif HAS_FLASH_ATTENTION and attention_mask is None` branch of LlamaAttention_fast_forward
```
For CUDA:
1. Find your CUDA version via
```python
import torch; torch.version.cuda
```
2. For Pytorch 2.1.0: You can update Pytorch via Pip (interchange `cu121` / `cu118`). Go to https://pytorch.org/ to learn more. Select either `cu118` for CUDA 11.8 or `cu121` for CUDA 12.1. If you have an RTX 3060 or higher (A100, H100 etc.), use the `"ampere"` path. For Pytorch 2.1.1: go to step 3. For Pytorch 2.2.0: go to step 4.
```bash
pip install --upgrade --force-reinstall --no-cache-dir torch==2.1.0 triton \
--index-url https://download.pytorch.org/whl/cu121
```
```bash
pip install "unsloth[cu118] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere] @ git+https://github.com/unslothai/unsloth.git"
```
3. For Pytorch 2.1.1: Use the `"ampere"` path for newer RTX 30xx GPUs or higher.
```bash
pip install --upgrade --force-reinstall --no-cache-dir torch==2.1.1 triton \
--index-url https://download.pytorch.org/whl/cu121
```
```bash
pip install "unsloth[cu118-torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere-torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere-torch211] @ git+https://github.com/unslothai/unsloth.git"
```
4. For Pytorch 2.2.0: Use the `"ampere"` path for newer RTX 30xx GPUs or higher.
```bash
pip install --upgrade --force-reinstall --no-cache-dir torch==2.2.0 triton \
--index-url https://download.pytorch.org/whl/cu121
```
```bash
pip install "unsloth[cu118-torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere-torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere-torch220] @ git+https://github.com/unslothai/unsloth.git"
```
5. If you get errors, try the below first, then go back to step 1:
```bash
pip install --upgrade pip
```
6. For Pytorch 2.2.1:
```bash
# RTX 3090, 4090 Ampere GPUs:
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
# Pre Ampere RTX 2080, T4, GTX 1080 GPUs:
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps xformers "trl<0.9.0" peft accelerate bitsandbytes
```
7. For Pytorch 2.3.0: Use the `"ampere"` path for newer RTX 30xx GPUs or higher.
```bash
pip install "unsloth[cu118-torch230] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-torch230] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
```
8. To troubleshoot installs, try running the commands below (all must succeed). Xformers should be available in most environments.
```bash
nvcc
python -m xformers.info
python -m bitsandbytes
```
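If those commands succeed but Unsloth still fails to import, a quick PyTorch-level sanity check can rule out a broken torch build. This is a minimal sketch, not part of the official instructions:

```python
# Confirm that the installed torch is a CUDA build and can see the GPU.
import torch

print("CUDA available:", torch.cuda.is_available())
print("Torch CUDA version:", torch.version.cuda)
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```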
## 📜 Documentation
- Go to our [Wiki page](https://github.com/unslothai/unsloth/wiki) for saving to GGUF, checkpointing, evaluation and more!
- We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code!
- We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
```python
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
max_seq_length = 2048 # Supports RoPE Scaling internally, so choose any!
# Get LAION dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train")
# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
"unsloth/mistral-7b-v0.3-bnb-4bit", # New Mistral v3 2x faster!
"unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
"unsloth/llama-3-8b-bnb-4bit", # Llama-3 15 trillion tokens model 2x faster!
"unsloth/llama-3-8b-Instruct-bnb-4bit",
"unsloth/llama-3-70b-bnb-4bit",
"unsloth/Phi-3-mini-4k-instruct", # Phi-3 2x faster!
"unsloth/Phi-3-medium-4k-instruct",
"unsloth/mistral-7b-bnb-4bit",
"unsloth/gemma-7b-bnb-4bit", # Gemma 2.2x faster!
] # More models at https://huggingface.co/unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = max_seq_length,
dtype = None,
load_in_4bit = True,
)
# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
max_seq_length = max_seq_length,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
trainer = SFTTrainer(
model = model,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
tokenizer = tokenizer,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_steps = 10,
max_steps = 60,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
output_dir = "outputs",
optim = "adamw_8bit",
seed = 3407,
),
)
trainer.train()
# Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
# (1) Saving to GGUF / merging to 16bit for vLLM
# (2) Continued training from a saved LoRA adapter
# (3) Adding an evaluation loop / OOMs
# (4) Customized chat templates
```
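As a follow-up to point (1) in the comments above, the sketch below shows one way to export the trained model. It assumes the `save_pretrained_merged` and `save_pretrained_gguf` helpers that Unsloth patches onto the model; see the wiki for the authoritative options and quantization methods:

```python
# Hedged sketch of post-training export paths (names follow the Unsloth wiki).
# 1) Save just the LoRA adapter (small, reloadable via from_pretrained):
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")

# 2) Merge the LoRA weights into 16-bit weights for vLLM / Hugging Face:
model.save_pretrained_merged("merged_16bit", tokenizer, save_method = "merged_16bit")

# 3) Export to GGUF for llama.cpp / Ollama (the quantization method here is an assumption):
model.save_pretrained_gguf("gguf_model", tokenizer, quantization_method = "q4_k_m")
```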
<a name="DPO"></a>
## DPO Support
DPO (Direct Preference Optimization), PPO, Reward Modelling all seem to work as per 3rd party independent testing from [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory). We have a preliminary Google Colab notebook for reproducing Zephyr on Tesla T4 here: [notebook](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing).
We're in 🤗Hugging Face's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
```python
from unsloth import FastLanguageModel, PatchDPOTrainer
from unsloth import is_bfloat16_supported
PatchDPOTrainer()
import torch
from transformers import TrainingArguments
from trl import DPOTrainer
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/zephyr-sft-bnb-4bit",
max_seq_length = max_seq_length,
dtype = None,
load_in_4bit = True,
)
# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
model,
r = 64,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 64,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
max_seq_length = max_seq_length,
)
dpo_trainer = DPOTrainer(
model = model,
ref_model = None,
args = TrainingArguments(
per_device_train_batch_size = 4,
gradient_accumulation_steps = 8,
warmup_ratio = 0.1,
num_train_epochs = 3,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_8bit",
seed = 42,
output_dir = "outputs",
),
beta = 0.1,
train_dataset = YOUR_DATASET_HERE,
# eval_dataset = YOUR_DATASET_HERE,
tokenizer = tokenizer,
max_length = 1024,
max_prompt_length = 512,
)
dpo_trainer.train()
```
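For the 2x faster inference mentioned in the news section, a minimal sketch looks like the following. It assumes `FastLanguageModel.for_inference` and the tokenizer's chat template, so adapt it to your model:

```python
# Hedged sketch: switch the trained model into Unsloth's faster inference mode.
FastLanguageModel.for_inference(model)

messages = [{"role": "user", "content": "Explain LoRA in one sentence."}]
inputs = tokenizer.apply_chat_template(
    messages, tokenize = True, add_generation_prompt = True, return_tensors = "pt",
).to("cuda")
outputs = model.generate(input_ids = inputs, max_new_tokens = 64, use_cache = True)
print(tokenizer.batch_decode(outputs, skip_special_tokens = True))
```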
## 🥇 Detailed Benchmarking Tables
- Click "Code" for fully reproducible examples
- "Unsloth Equal" is a preview of our PRO version, with code stripped out. All settings and the loss curve remains identical.
- For the full list of benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)
| 1 A100 40GB | 🤗Hugging Face | Flash Attention 2 | 🦥Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Alpaca | 1x | 1.04x | 1.98x | 2.48x | 5.32x | **15.64x** |
| code | [Code](https://colab.research.google.com/drive/1u4dBeM-0vGNVmmO6X7cScAut-Hyt4KDF?usp=sharing) | [Code](https://colab.research.google.com/drive/1fgTOxpMbVjloQBvZyz4lF4BacKSZOB2A?usp=sharing) | [Code](https://colab.research.google.com/drive/1YIPY_18xm-K0iJDgvNkRoJsgkPMPAO3G?usp=sharing) | [Code](https://colab.research.google.com/drive/1ANW8EFL3LVyTD7Gq4TkheC1Z7Rxw-rHp?usp=sharing) | | |
| seconds| 1040 | 1001 | 525 | 419 | 196 | 67 |
| memory MB| 18235 | 15365 | 9631 | 8525 | | |
| % saved| | 15.74 | 47.18 | 53.25 | | |
### Llama-Factory 3rd party benchmarking
- [Link to performance table.](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-Comparison) TGS: tokens per GPU per second. Model: LLaMA2-7B. GPU: NVIDIA A100 * 1. Batch size: 4. Gradient accumulation: 2. LoRA rank: 8. Max length: 1024.
| Method | Bits | TGS | GRAM | Speed |
| --- | --- | --- | --- | --- |
| HF | 16 | 2392 | 18GB | 100% |
| HF+FA2 | 16 | 2954 | 17GB | 123% |
| Unsloth+FA2 | 16 | 4007 | 16GB | **168%** |
| HF | 4 | 2415 | 9GB | 101% |
| Unsloth+FA2 | 4 | 3726 | 7GB | **160%** |
### Performance comparisons between popular models
<details>
<summary>Click for specific model benchmarking tables (Mistral 7b, CodeLlama 34b etc.)</summary>
### Mistral 7b
| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Mistral 7B Slim Orca | 1x | 1.15x | 2.15x | 2.53x | 4.61x | **13.69x** |
| code | [Code](https://colab.research.google.com/drive/1mePk3KzwTD81hr5mcNcs_AX3Kbg_Ha0x?usp=sharing) | [Code](https://colab.research.google.com/drive/1dgHxjvTmX6hb0bPcLp26RXSE6_n9DKj7?usp=sharing) | [Code](https://colab.research.google.com/drive/1SKrKGV-BZoU4kv5q3g0jtE_OhRgPtrrQ?usp=sharing) | [Code](https://colab.research.google.com/drive/18yOiyX0T81mTwZqOALFSCX_tSAqju6aD?usp=sharing) | |
| seconds | 1813 | 1571 | 842 | 718 | 393 | 132 |
| memory MB | 32853 | 19385 | 12465 | 10271 | | |
| % saved| | 40.99 | 62.06 | 68.74 | | |
### CodeLlama 34b
| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Code Llama 34B | OOM ❌ | 0.99x | 1.87x | 2.61x | 4.27x | 12.82x |
| code | [▶️ Code](https://colab.research.google.com/drive/1ykfz3BqrtC_AUFegCzUQjjfUNlxp6Otc?usp=sharing) | [Code](https://colab.research.google.com/drive/12ZypxQh7OC6kBXvWZI-5d05I4m-B_hoR?usp=sharing) | [Code](https://colab.research.google.com/drive/1gdHyAx8XJsz2yNV-DHvbHjR1iCef5Qmh?usp=sharing) | [Code](https://colab.research.google.com/drive/1fm7wqx9MJ0kRrwKOfmLkK1Rmw-pySahB?usp=sharing) | |
| seconds | 1953 | 1982 | 1043 | 748 | 458 | 152 |
| memory MB | 40000 | 33217 | 27413 | 22161 | | |
| % saved| | 16.96 | 31.47 | 44.60 | | |
### 1 Tesla T4
| 1 T4 16GB | Hugging Face | Flash Attention | Unsloth Open | Unsloth Pro Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-----------------|-----------------|---------------|---------------|-------------|
| Alpaca | 1x | 1.09x | 1.69x | 1.79x | 2.93x | **8.3x** |
| code | [▶️ Code](https://colab.research.google.com/drive/1XpLIV4s8Bj5uryB-X2gqM88oRGHEGdaB?usp=sharing) | [Code](https://colab.research.google.com/drive/1LyXu6CjuymQg6ddHX8g1dpUvrMa1nn4L?usp=sharing) | [Code](https://colab.research.google.com/drive/1gsv4LpY7C32otl1rgRo5wXTk4HIitXoM?usp=sharing) | [Code](https://colab.research.google.com/drive/1VtULwRQwhEnVdNryjm27zXfdSM1tNfFK?usp=sharing) | | |
| seconds | 1599 | 1468 | 942 | 894 | 545 | 193 |
| memory MB | 7199 | 7059 | 6459 | 5443 | | |
| % saved | | 1.94 | 10.28 | 24.39 | | |
### 2 Tesla T4s via DDP
| 2 T4 DDP | Hugging Face | Flash Attention | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|----------|-------------|-----------------|--------------|---------------|-------------|
| Alpaca | 1x | 0.99x | 4.95x | 4.44x | 7.28x | **20.61x** |
| code | [▶️ Code](https://www.kaggle.com/danielhanchen/hf-original-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/hf-sdpa-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp) | | |
| seconds | 9882 | 9946 | 1996 | 2227 | 1357 | 480 |
| memory MB| 9176 | 9128 | 6904 | 6782 | | |
| % saved | | 0.52 | 24.76 | 26.09 | | |
</details>
### Performance comparisons on 1 Tesla T4 GPU:
<details>
<summary>Click for Time taken for 1 epoch</summary>
One Tesla T4 on Google Colab
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`
| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 1 T4 | 23h 15m | 56h 28m | 8h 38m | 391h 41m |
| Unsloth Open | 1 T4 | 13h 7m (1.8x) | 31h 47m (1.8x) | 4h 27m (1.9x) | 240h 4m (1.6x) |
| Unsloth Pro | 1 T4 | 3h 6m (7.5x) | 5h 17m (10.7x) | 1h 7m (7.7x) | 59h 53m (6.5x) |
| Unsloth Max | 1 T4 | 2h 39m (8.8x) | 4h 31m (12.5x) | 0h 58m (8.9x) | 51h 30m (7.6x) |
**Peak Memory Usage**
| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 1 T4 | 7.3GB | 5.9GB | 14.0GB | 13.3GB |
| Unsloth Open | 1 T4 | 6.8GB | 5.7GB | 7.8GB | 7.7GB |
| Unsloth Pro | 1 T4 | 6.4GB | 6.4GB | 6.4GB | 6.4GB |
| Unsloth Max | 1 T4 | 11.4GB | 12.4GB | 11.9GB | 14.4GB |
</details>
<details>
<summary>Click for Performance Comparisons on 2 Tesla T4 GPUs via DDP:</summary>
**Time taken for 1 epoch**
Two Tesla T4s on Kaggle
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`
| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 84h 47m | 163h 48m | 30h 51m | 1301h 24m * |
| Unsloth Pro | 2 T4 | 3h 20m (25.4x) | 5h 43m (28.7x) | 1h 12m (25.7x) | 71h 40m (18.1x) * |
| Unsloth Max | 2 T4 | 3h 4m (27.6x) | 5h 14m (31.3x) | 1h 6m (28.1x) | 54h 20m (23.9x) * |
**Peak Memory Usage on a Multi GPU System (2 GPUs)**
| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 8.4GB \| 6GB | 7.2GB \| 5.3GB | 14.3GB \| 6.6GB | 10.9GB \| 5.9GB * |
| Unsloth Pro | 2 T4 | 7.7GB \| 4.9GB | 7.5GB \| 4.9GB | 8.5GB \| 4.9GB | 6.2GB \| 4.7GB * |
| Unsloth Max | 2 T4 | 10.5GB \| 5GB | 10.6GB \| 5GB | 10.6GB \| 5GB | 10.5GB \| 5GB * |
* Slim Orca `bsz=1` for all benchmarks since `bsz=2` OOMs. We can handle `bsz=2`, but we benchmark it with `bsz=1` for consistency.
</details>
![](https://i.ibb.co/sJ7RhGG/image-41.png)
<br>
### Thank You to
- [HuyNguyen-hust](https://github.com/HuyNguyen-hust) for making [RoPE Embeddings 28% faster](https://github.com/unslothai/unsloth/pull/238)
- [RandomInternetPreson](https://github.com/RandomInternetPreson) for confirming WSL support
- [152334H](https://github.com/152334H) for experimental DPO support
- [atgctg](https://github.com/atgctg) for syntax highlighting
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import importlib
import sys
from packaging.version import Version
# # Define a list of modules to check
# MODULES_TO_CHECK = ["bitsandbytes"]
# # Check if any of the modules in the list have been imported
# for module in MODULES_TO_CHECK:
# if module in sys.modules:
# raise ImportError(f"Unsloth: Please import Unsloth before {module}.")
# pass
# pass
# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
# enabling it will require much more work, so we have to prioritize. Please understand!
# We do have a beta version, which you can contact us about!
# Thank you for your understanding and we appreciate it immensely!
if "CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
devices = os.environ["CUDA_VISIBLE_DEVICES"]
# Check if there are multiple cuda devices set in env
if not devices.isdigit():
first_id = devices.split(",")[0]
warnings.warn(
f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {devices} \n"\
"Unsloth currently does not support multi GPU setups - but we are working on it!\n"\
"Multiple CUDA devices detected but we require a single device.\n"\
f"We will override CUDA_VISIBLE_DEVICES to first device: {first_id}."
)
os.environ["CUDA_VISIBLE_DEVICES"] = str(first_id)
else:
# warnings.warn("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
pass
# Reduce VRAM usage by reducing fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
try:
import torch
except:
raise ImportError("Pytorch is not installed. Go to https://pytorch.org/.\n"\
"We have some installation instructions on our Github page.")
pass
# Hugging Face Hub faster downloads (only enable during Colab and Kaggle sessions)
keynames = "\n" + "\n".join(os.environ.keys())
if "\nCOLAB_" in keynames or "\nKAGGLE_" in keynames:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
pass
# We support Pytorch 2
# Fixes https://github.com/unslothai/unsloth/issues/38
torch_version = torch.__version__.split(".")
major_torch, minor_torch = torch_version[0], torch_version[1]
major_torch, minor_torch = int(major_torch), int(minor_torch)
if (major_torch < 2):
raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\
"We have some installation instructions on our Github page.")
elif (major_torch == 2) and (minor_torch < 2):
# Disable expandable_segments
del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
pass
# Torch 2.5 has including_emulation
major_version, minor_version = torch.cuda.get_device_capability()
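# bfloat16 needs compute capability 8.0+ (Ampere or newer), hence the check below.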
SUPPORTS_BFLOAT16 = (major_version >= 8)
if (major_torch == 2) and (minor_torch >= 5):
old_is_bf16_supported = torch.cuda.is_bf16_supported
def is_bf16_supported(including_emulation = False):
return old_is_bf16_supported(including_emulation)
torch.cuda.is_bf16_supported = is_bf16_supported
else:
def is_bf16_supported(): return SUPPORTS_BFLOAT16
torch.cuda.is_bf16_supported = is_bf16_supported
pass
# Try loading bitsandbytes and triton
import bitsandbytes as bnb
import triton
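# Triton >= 3.0 moved libcuda_dirs; default to a no-op and import the right one below.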
libcuda_dirs = lambda: None
if Version(triton.__version__) >= Version("3.0.0"):
try: from triton.backends.nvidia.driver import libcuda_dirs
except: pass
else: from triton.common.build import libcuda_dirs
import os
import re
import numpy as np
import subprocess
try:
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
libcuda_dirs()
except:
warnings.warn(
"Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
)
if os.path.exists("/usr/lib64-nvidia"):
os.system("ldconfig /usr/lib64-nvidia")
elif os.path.exists("/usr/local"):
# Sometimes bitsandbytes cannot be linked properly in Runpod for example
possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
possible_cudas = [find_cuda.search(x) for x in possible_cudas]
possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
# Try linking cuda folder, or everything in local
if len(possible_cudas) == 0:
os.system(f"ldconfig /usr/local/")
else:
find_number = re.compile(r"([\d\.]{2,})")
latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
latest_cuda = possible_cudas[latest_cuda]
os.system(f"ldconfig /usr/local/{latest_cuda}")
pass
importlib.reload(bnb)
importlib.reload(triton)
try:
libcuda_dirs = lambda: None
if Version(triton.__version__) >= Version("3.0.0"):
try: from triton.backends.nvidia.driver import libcuda_dirs
except: pass
else: from triton.common.build import libcuda_dirs
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
libcuda_dirs()
except:
warnings.warn(
"Unsloth: CUDA is not linked properly.\n"\
"Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
"We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
"You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
"Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
"Unsloth will still run for now, but maybe it might crash - let's hope it works!"
)
pass
from .models import *
from .save import *
from .chat_templates import *
from .tokenizer_utils import *
from .trainer import *
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
"get_chat_template",
"test_chat_templates",
"test_hf_gguf_equivalence",
"remove_special_tokens",
"to_sharegpt",
"standardize_sharegpt",
"apply_chat_template",
"train_on_responses_only",
"test_construct_chat_template",
]
from transformers import StoppingCriteria, StoppingCriteriaList
from torch import LongTensor, FloatTensor
from transformers.models.llama.modeling_llama import logger
from .save import patch_saving_functions
import os
import shutil
from .tokenizer_utils import *
from .models._utils import patch_tokenizer
import re
CHAT_TEMPLATES = {}
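# Each entry maps a template name to a 4-tuple:
#   (jinja chat template, EOS / stop word, whether the EOS token is remapped, Ollama Modelfile)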
# =========================================== Unsloth
# The Unsloth efficient template is adapted from Zephyr
unsloth_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'You are a helpful assistant to the user\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '>>> User: ' + message['content'] + '\n' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '>>> Assistant: ' }}"\
"{% endif %}"
pass
unsloth_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
{{ end }}>>> Assistant: {{ .Response }}{__EOS_TOKEN__}
"""
PARAMETER stop "{__EOS_TOKEN__}"
SYSTEM """You are a helpful assistant to the user"""
'''
unsloth_eos_token = "eos_token"
CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
pass
# =========================================== Zephyr
# Zephyr has no BOS!
zephyr_template = \
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"\
"{% else %}"\
"{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|assistant|>\n' }}"\
"{% endif %}"
pass
zephyr_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|system|>
{{ .System }}{__EOS_TOKEN__}
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}{__EOS_TOKEN__}
{{ end }}<|assistant|>
{{ .Response }}{__EOS_TOKEN__}
"""
PARAMETER stop "{__EOS_TOKEN__}"
'''
zephyr_eos_token = "eos_token"
CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
pass
# =========================================== ChatML
# ChatML has no BOS and no EOS! Rather, <|im_start|> and <|im_end|> act as BOS / EOS.
chatml_template = \
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}"\
"{% elif message['role'] == 'assistant' %}"\
"{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}"\
"{% else %}"\
"{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|im_start|>assistant\n' }}"\
"{% endif %}"
pass
chatml_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
"""
PARAMETER stop "<|im_start|>"
PARAMETER stop "<|im_end|>"
'''
chatml_eos_token = "<|im_end|>"
CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
pass
# =========================================== Mistral-1
# Mistral Instruct doesn't allow system prompts, so we prepend the system prompt to the first user message.
mistral_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{% if messages[1]['role'] == 'user' %}"\
"{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[2:] %}"\
"{% else %}"\
"{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% endif %}"\
"{% else %}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ message['content'] + eos_token }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"
pass
# Ollama from https://www.ollama.com/library/mistral
mistral_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
PARAMETER stop "{__EOS_TOKEN__}"
'''
mistral_eos_token = "eos_token"
CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
pass
# =========================================== Llama-2
# Adds BOS to every convo! And weird <<SYS>> system messages.
llama_template = \
"{% if messages[0]['role'] == 'system' %}"\
"{% if messages[1]['role'] == 'user' %}"\
"{{ bos_token + '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[2:] %}"\
"{% else %}"\
"{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% endif %}"\
"{% else %}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ ' ' + message['content'].strip() + ' ' + eos_token }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"
pass
# Ollama from https://www.ollama.com/library/llama3
llama_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """[INST] <<SYS>>{{ .System }}<</SYS>>
{{ .Prompt }} [/INST]"""
PARAMETER stop "{__EOS_TOKEN__}"
'''
llama_eos_token = "eos_token"
CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
pass
# =========================================== Vicuna
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
vicuna_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{ messages[0]['content'] + ' ' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ 'USER: ' + message['content'] + ' ' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ 'ASSISTANT: ' + message['content'] + eos_token }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ 'ASSISTANT:' }}"\
"{% endif %}"
pass
# Ollama from https://www.ollama.com/library/vicuna
vicuna_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
PARAMETER stop "{__EOS_TOKEN__}"
'''
vicuna_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
pass
# =========================================== Vicuna Old
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
vicuna_old_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '### Human: ' + message['content'] + '\n' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '### Assistant: ' + message['content'] + eos_token + '\n' }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '### Assistant:' }}"\
"{% endif %}"
pass
vicuna_old_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
{{ end }}### Assistant: {{ .Response }}{__EOS_TOKEN__}
"""
PARAMETER stop "{__EOS_TOKEN__}"
SYSTEM """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."""
'''
vicuna_old_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
pass
# =========================================== Alpaca multi turn
# https://github.com/tatsu-lab/stanford_alpaca Changed for multi-turn convos
alpaca_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{ messages[0]['content'] + '\n\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\n\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '### Instruction:\n' + message['content'] + '\n\n' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '### Response:\n' + message['content'] + eos_token + '\n\n' }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '### Response:\n' }}"\
"{% endif %}"
pass
alpaca_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
{{ .Prompt }}{{ end }}
### Response:
{{ .Response }}{__EOS_TOKEN__}
"""
PARAMETER stop "{__EOS_TOKEN__}"
SYSTEM """Below are some instructions that describe some tasks. Write responses that appropriately complete each request."""
'''
alpaca_eos_token = "eos_token"
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
pass
# =========================================== Gemma
# https://huggingface.co/google/gemma-7b-it
# Notice we must use |trim for lstrip and rstrip. <start_of_turn> maps to 106.
# <end_of_turn> maps to 107. user and model are normal 1 word tokens.
gemma_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{'<start_of_turn>user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\n'}}"\
"{% set loop_messages = messages[2:] %}"\
"{% endif %}"\
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{'<start_of_turn>user\n' + message['content'] | trim + '<end_of_turn>\n'}}"\
"{% elif message['role'] == 'assistant' %}"\
"{{'<start_of_turn>model\n' + message['content'] | trim + '<end_of_turn>\n' }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<start_of_turn>model\n' }}"\
"{% endif %}"
pass
# Ollama from https://www.ollama.com/library/gemma
gemma_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """<start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
"""
PARAMETER repeat_penalty 1
PARAMETER stop "<start_of_turn>"
PARAMETER stop "<end_of_turn>"
PARAMETER penalize_newline false
'''
gemma_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
pass
# =========================================== Gemma with ChatML instead
# We find using <eos> is still more appropriate!
gemma_chatml_template = "{{ bos_token }}" + chatml_template
pass
gemma_chatml_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
"""
PARAMETER repeat_penalty 1
PARAMETER stop "<|im_start|>"
PARAMETER stop "<|im_end|>"
PARAMETER penalize_newline false
'''
gemma_chatml_eos_token = (
{"<start_of_turn>" : "<|im_start|>", "<eos>" : "<|im_end|>"},
"<|im_end|>",
)
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
pass
# =========================================== Gemma 2
# Same as Gemma 1, but with sliding window attention!
# https://ollama.com/library/gemma2/blobs/6522ca797f47
gemma2_template = gemma_template
gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
gemma2_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
# =========================================== Gemma 2 with ChatML instead
gemma2_chatml_template = gemma_chatml_template
gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
gemma2_chatml_eos_token = gemma_chatml_eos_token
CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
pass
# =========================================== Llama-3
# Weirdly \n\n is needed?
llama3_template = \
"{{ bos_token }}"\
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
"{% else %}"\
"{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"\
"{% endif %}"
pass
# Ollama from https://www.ollama.com/library/llama3
llama3_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>"""
PARAMETER stop "<|start_header_id|>"
PARAMETER stop "<|end_header_id|>"
PARAMETER stop "<|eot_id|>"
'''
llama3_template_eos_token = "eos_token"
CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
pass
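# Illustrative rendering (not from the original source): with the template above and
# add_generation_prompt = True, the single turn [{"role": "user", "content": "Hi"}]
# renders (with bos_token = "<|begin_of_text|>") as:
#   <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
# i.e. every turn is wrapped as <|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>.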
# =========================================== Phi-3
phi3_template = \
"{{ bos_token }}"\
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{'<|user|>\n' + message['content'] + '<|end|>\n'}}"\
"{% elif message['role'] == 'assistant' %}"\
"{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}"\
"{% else %}"\
"{{'<|' + message['role'] + '|>\n' + message['content'] + '<|end|>\n'}}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|assistant|>\n' }}"\
"{% endif %}"
pass
# Ollama from https://www.ollama.com/library/phi3
phi3_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|>
{{ end }}<|assistant|>
{{ .Response }}<|end|>
"""
PARAMETER stop "<|end|>"
PARAMETER stop "<|user|>"
PARAMETER stop "<|assistant|>"
'''
phi3_template_eos_token = "<|end|>"
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
pass
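# A minimal usage sketch (not part of the original file): every CHAT_TEMPLATES entry is a
# tuple of (jinja_template, eos/stop token, whether the EOS token gets remapped, Ollama
# modelfile). The model name below is an assumption for illustration only.
def _example_render_registered_template(model_name = "unsloth/llama-3-8b-Instruct"):
    from transformers import AutoTokenizer
    template, stop_word, map_eos, ollama_modelfile = CHAT_TEMPLATES["llama-3"]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.chat_template = template
    messages = [{"role": "user", "content": "What is 2+2?"}]
    # Render the prompt string; add_generation_prompt appends the assistant header so the
    # model continues from there.
    return tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)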
def get_chat_template(
tokenizer,
chat_template = "chatml",
mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
map_eos_token = True,
system_message = None,
):
assert(type(map_eos_token) is bool)
old_tokenizer = tokenizer
IS_GEMMA = False
if tokenizer.__class__.__name__.startswith("Gemma"):
if chat_template == "chatml": chat_template = "gemma_chatml"
IS_GEMMA = True
pass
# We add a check for Llama-3
# if chat_template == "llama-3":
# tokenizer._using_llama3_template = True
# else:
# llama3_tokens = set(["<|end_header_id|>", "<|eot_id|>", "<|start_header_id|>"])
# check_llama3_tokens = llama3_tokens & set(str(x) for x in tokenizer.added_tokens_decoder.values())
# if len(check_llama3_tokens) == len(llama3_tokens):
# tokenizer._using_llama3_template = True
# pass
# pass
# We first check if the tokenizer is a fast one. If not, we cannot convert this!
is_fast_tokenizer = getattr(tokenizer, "is_fast", False)
old_padding_side = tokenizer.padding_side
same_padding_token = False
if type(chat_template) in (list, tuple,):
chat_template, stop_word = chat_template
assert(type(chat_template) is str)
assert(type(stop_word) is str)
ollama_modelfile = None
elif type(chat_template) is str:
chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
# Check mapping to eos_token
if not map_eos_token and yes_map_eos_token: map_eos_token = True
if not yes_map_eos_token and map_eos_token: map_eos_token = False
if type(stop_word) in (list, tuple,):
token_mapping, stop_word = stop_word
assert(type(token_mapping) is dict)
else:
token_mapping = None
assert(type(stop_word) is str)
# Check fast tokenizer
if not is_fast_tokenizer:
print(
f"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
"Please log a Github issue if you want this as a new feature!\n"\
"Your chat template will still work, but it won't add or edit tokens."
)
elif token_mapping is not None:
# token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
# For Gemma :)
string_vocab = tokenizer._tokenizer.to_str()
skipped = 0
for old_token, new_token in token_mapping.items():
old_count = string_vocab.count(f'"{old_token}"')
new_count = string_vocab.count(f'"{new_token}"')
if new_count != 0:
print(f"{new_token} is already a token. Skipping.")
skipped += 1
elif old_count == 0:
raise RuntimeError(f"{old_token} was not part of the tokenizer!")
else:
string_vocab = string_vocab.replace(f'"{old_token}"', f'"{new_token}"')
pass
pass
if map_eos_token and (not stop_word in token_mapping.values()):
# Do not map 107 = <|im_end|> and 1 = <|im_end|>. This will reduce the vocab size by 1
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
pass
if skipped != len(token_mapping):
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
# Careful on pad_token
old_pad_token = tokenizer.pad_token
if old_pad_token == tokenizer.eos_token:
old_pad_token = stop_word
same_padding_token = True
pass
if map_eos_token:
new_tokenizer = tokenizer.__class__(
tokenizer_object = new_tokenizer,
eos_token = stop_word,
pad_token = old_pad_token,
)
else:
new_tokenizer = tokenizer.__class__(
tokenizer_object = new_tokenizer,
pad_token = old_pad_token,
)
pass
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
else:
pass
elif map_eos_token and (stop_word != "eos_token"):
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
# Replaces the old EOS token with a new one.
# Useful for ChatML <|im_end|> for example.
# Usually we train 2 more tokens <|im_start|> and <|im_end|>
# But training the lm_head and embeddings are slow!
# This is a HACK!
# Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
old_bos_token = getattr(tokenizer, "bos_token", None)
old_eos_token = getattr(tokenizer, "eos_token", None)
old_pad_token = getattr(tokenizer, "pad_token", None)
old_unk_token = getattr(tokenizer, "unk_token", None)
string_vocab = tokenizer._tokenizer.to_str()
# First check if new stop_word is in the tokenizer
if stop_word in string_vocab:
# We shall swap them around
temporary_stop_token = "<|:__TEMP//STOP//TOKEN__:|>"
string_vocab = string_vocab.replace(old_eos_token, temporary_stop_token)
string_vocab = string_vocab.replace(stop_word, old_eos_token)
string_vocab = string_vocab.replace(temporary_stop_token, stop_word)
else:
string_vocab = string_vocab.replace(old_eos_token, stop_word)
pass
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
# Careful on pad_token
if old_pad_token == old_eos_token:
old_pad_token = stop_word
same_padding_token = True
pass
new_tokenizer = tokenizer.__class__(
tokenizer_object = new_tokenizer,
bos_token = old_bos_token,
eos_token = stop_word,
unk_token = old_unk_token,
pad_token = old_pad_token,
)
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
token_mapping = { old_eos_token : stop_word, }
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
pass
else:
raise TypeError(
f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
f"{CHAT_TEMPLATES.keys()}"
)
pass
# For ShareGPT role -> from and content -> value
chat_template = chat_template\
.replace("'role'", "'" + mapping["role"] + "'")\
.replace("'content'", "'" + mapping["content"] + "'")\
.replace("'user'", "'" + mapping["user"] + "'")\
.replace("'assistant'", "'" + mapping["assistant"] + "'")
# Careful on Gemma
# bos_token is a must or else losses become too high
if IS_GEMMA and not chat_template.startswith("{{ bos_token }}"):
chat_template = "{{ bos_token }}" + chat_template
pass
_, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
tokenizer.padding_side = old_padding_side
tokenizer.chat_template = chat_template
# Also fix up other tokens
old_pad_token = getattr(old_tokenizer, "pad_token", None)
old_bos_token = getattr(old_tokenizer, "bos_token", None)
old_unk_token = getattr(old_tokenizer, "unk_token", None)
new_pad_token = getattr(tokenizer, "pad_token", None)
new_bos_token = getattr(tokenizer, "bos_token", None)
new_unk_token = getattr(tokenizer, "unk_token", None)
if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token
if not same_padding_token:
if old_pad_token != new_pad_token: tokenizer.pad_token = old_pad_token
pass
# stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
# Patch saving functions
tokenizer = patch_saving_functions(tokenizer)
# Add Ollama
tokenizer._ollama_modelfile = ollama_modelfile
tokenizer._system_message = system_message
return tokenizer#, stopping_criteria
pass
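# Usage sketch (illustrative, not from the original source). `get_chat_template` swaps in one
# of the registered templates and, for fast tokenizers, optionally remaps the EOS token to the
# template's stop word. The model name is an assumption for demonstration only.
def _example_get_chat_template(model_name = "unsloth/mistral-7b-bnb-4bit"):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer = get_chat_template(
        tokenizer,
        chat_template = "chatml",   # a CHAT_TEMPLATES key, or a (template, eos_token) tuple
        map_eos_token = True,       # swap the existing EOS token for <|im_end|>
    )
    messages = [{"role": "user", "content": "Hello!"}]
    return tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)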
def remove_special_tokens(tokenizer, prompt):
# Removes double BOS token
if prompt.startswith(tokenizer.bos_token):
prompt = prompt[len(tokenizer.bos_token):]
pass
return prompt
pass
def _parse_combined_prompt(combined_prompt, dataset):
# Find {...}
possible_columns = re.findall(r"\{(.+?)\}", combined_prompt)
dataset_columns = set(dataset.column_names)
for column in possible_columns:
if column not in dataset_columns:
raise KeyError(
f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. "\
f"Only allowed columns are {list(dataset_columns)}"
)
pass
pass
# Find [[...]]
optional_prompts = list(re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE))
optional_prompts = [(x.span(), x.group(0)) for x in optional_prompts]
final_optional_prompts = []
if len(optional_prompts) != 0:
# Add left
left = optional_prompts[0]
l = left[0][0]
if l != 0: final_optional_prompts.append(combined_prompt[:l])
# Add in between
for left, right in zip(optional_prompts[:-1], optional_prompts[1:]):
l, r = left[0][-1], right[0][0]
final_optional_prompts.append(left)
if l != r: final_optional_prompts.append(combined_prompt[l : r])
pass
final_optional_prompts.append(optional_prompts[-1])
# Add right
right = optional_prompts[-1]
r = right[0][1]
if r != len(combined_prompt): final_optional_prompts.append(combined_prompt[r:])
else:
# Just add in the entire string
final_optional_prompts.append(combined_prompt)
pass
check_combined = "".join(x if type(x) is str else x[1] for x in final_optional_prompts)
assert(combined_prompt == check_combined)
return possible_columns, final_optional_prompts
pass
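# Illustrative sketch of the merged-prompt syntax parsed above (column names are assumptions):
# `{column}` references a dataset column, and `[[ ... ]]` marks an optional segment that is
# dropped whenever its first column is empty.
def _example_parse_combined_prompt():
    from datasets import Dataset
    dataset = Dataset.from_dict({
        "instruction" : ["Summarize the text."],
        "input"       : [""],
        "output"      : ["..."],
    })
    columns, pieces = _parse_combined_prompt("{instruction}[[\nContext: {input}]]", dataset)
    # columns -> ["instruction", "input"]; pieces keeps the plain strings plus the spans of the
    # [[...]] blocks so _create_formatter can generate a per-row formatting function.
    return columns, pieces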
def _create_formatter(possible_columns, final_optional_prompts, user_column_name):
# Start final prompt!
function = ["def __combined_prompt_processor__(examples):"]
columns = list(set(possible_columns))
for column in columns:
function.append(f"{' '*4}{column}__ = examples['{column}']")
function.append(f"{' '*4}texts = []")
function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):")
# Add optional tags as well!
final_prompt = ""
formatter = []
for j, optional_prompt in enumerate(final_optional_prompts):
if type(optional_prompt) is str:
columns = re.findall(r"\{(.+?)\}", optional_prompt)
formatter += columns
# Must escape \n \r
final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
else:
where, prompt = optional_prompt
# Strip [[...]]
# Must escape \n \r
prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
columns = re.findall(r"\{(.+?)\}", prompt)
x = f"__optional_{j}__"
prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
function.append(prompt)
formatter.append(x)
final_prompt += "{" + x + "}"
pass
pass
function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'")
function.append(f"{' '*8}texts.append("\
f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))")
function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
return "\n".join(function)
pass
def to_sharegpt(
dataset,
merged_prompt = "",
merged_column_name = "instruction",
output_column_name = "output",
remove_unsued_columns = True,
conversation_extension = 1,
random_state = 3407,
):
"""
Converts a dataset to ShareGPT style.
ShareGPT requires only 1 input and 1 output field.
This means one has to merge multiple columns into 1 for 1 input field.
Use `conversation_extension` to increase the length of each conversation by randomly
selecting a few rows and packing them into one.
merged_prompt = "", Prompt to merge columns into 1 input
merged_column_name = "instruction", Final column name for the input field
output_column_name = "output", Final column name for the output field
remove_unsued_columns = True,
conversation_extension = 1, Automatically combines `conversation_extension` convos into 1
random_state = 3407,
"""
if "conversations" in dataset.column_names:
convo = dataset[0]["conversations"]
if type(convo) is list:
raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!")
pass
pass
possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)
function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
exec(function, globals())
dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns")
def __convert_to_sharegpt__(examples):
users = examples[merged_column_name]
assistants = examples[output_column_name]
texts = [
[
{"from" : "user", "content" : str(user) },
{"from" : "assistant", "content" : str(assistant)},
] \
for user, assistant in zip(users, assistants)
]
return { "conversations" : texts, }
pass
dataset = dataset.map(
__convert_to_sharegpt__,
batched = True,
desc = "Converting to ShareGPT",
# Remove unused columns!
remove_columns = dataset.column_names if remove_unsued_columns else None,
)
# Randomly concatenate conversations to create a longer stream!
from datasets import concatenate_datasets
n_extensions = max(conversation_extension-1, 0)
if n_extensions == 0: return dataset
dataset = dataset.rename_columns({"conversations" : f"conversations0"})
all_shuffled = [dataset]
for j in range(1, n_extensions+1):
shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
all_shuffled.append(shuffled)
pass
dataset = concatenate_datasets(all_shuffled, axis = 1)
# Combine them into 1
function = "def __combine_conversations__(examples):\n"
n_extensions += 1
for j in range(n_extensions):
function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n"
function += f"{' '*4}convos = []\n"
function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\
f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
function += f"{' '*8}convos.append("\
f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
function += f"{' '*4}return " + "{ " + f"'conversations' : convos" + " }"
# Map function
exec(function, globals())
dataset = dataset.map(
__combine_conversations__,
batched = True,
desc = "Extending conversations",
# Remove unused columns!
remove_columns = dataset.column_names if remove_unsued_columns else None,
)
return dataset
pass
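# Usage sketch for `to_sharegpt` (illustrative; column names are assumptions). The input
# columns are merged into one instruction field and paired with the output column to form
# single-turn conversations.
def _example_to_sharegpt():
    from datasets import Dataset
    dataset = Dataset.from_dict({
        "instruction" : ["Add the numbers.", "Name the capital of France."],
        "input"       : ["2 and 2",          ""],
        "output"      : ["4",                "Paris"],
    })
    return to_sharegpt(
        dataset,
        merged_prompt = "{instruction}[[\nInput: {input}]]",
        output_column_name = "output",
        conversation_extension = 1,   # > 1 randomly packs several rows into one conversation
    )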
def standardize_sharegpt(
dataset,
aliases_for_system = ["system",],
aliases_for_user = ["user", "human", "input",],
aliases_for_assistant = ["gpt", "assistant", "output",],
):
"""
Standardizes ShareGPT and other formats to user/assistant Hugging Face format.
Provide aliases for the system, user and assistant roles;
these are mapped to "system", "user" and "assistant" respectively.
aliases_for_system = ["system",],
aliases_for_user = ["user", "human", "input",],
aliases_for_assistant = ["gpt", "assistant", "output",],
"""
import collections
import itertools
convos = dataset[:10]["conversations"]
uniques = collections.defaultdict(list)
for convo in convos:
for message in convo:
for key, value in message.items():
uniques[key].append(value)
pass
# Must be only 2 entries
assert(len(uniques.keys()) == 2)
keys = list(uniques.keys())
length_first = len(set(uniques[keys[0]]))
length_second = len(set(uniques[keys[1]]))
if length_first < length_second:
# Role is assigned to the first element
role_key = keys[0]
content_key = keys[1]
else:
role_key = keys[1]
content_key = keys[0]
pass
# Check roles are in aliases
all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant)
roles = set(uniques[role_key])
leftover_aliases = (all_aliases | roles) - all_aliases
if len(leftover_aliases) != 0:
raise TypeError(
f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases."
)
pass
# Mapping for aliases
aliases_mapping = {}
for x in aliases_for_system: aliases_mapping[x] = "system"
for x in aliases_for_user: aliases_mapping[x] = "user"
for x in aliases_for_assistant: aliases_mapping[x] = "assistant"
def _standardize_dataset(examples):
convos = examples["conversations"]
all_convos = []
for convo in convos:
new_convo = [
{ "role" : aliases_mapping[message[role_key]], "content" : message[content_key], }
for message in convo
]
all_convos.append(new_convo)
pass
return { "conversations" : all_convos, }
pass
return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format")
pass
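# Usage sketch (illustrative). A ShareGPT-style dataset with "from"/"value" keys and
# "human"/"gpt" roles becomes a list of {"role": ..., "content": ...} messages.
def _example_standardize_sharegpt():
    from datasets import Dataset
    dataset = Dataset.from_dict({
        "conversations" : [
            [{"from" : "human", "value" : "What is 2+2?"},
             {"from" : "gpt",   "value" : "It's 4."}],
            [{"from" : "human", "value" : "Name the capital of France."},
             {"from" : "gpt",   "value" : "Paris."}],
        ],
    })
    return standardize_sharegpt(dataset)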
def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
added_tokens_decoder = tokenizer.added_tokens_decoder.values()
added_tokens_decoder = [str(x) for x in added_tokens_decoder]
# Remove added_tokens_decoder duplicates
added_tokens_decoder = list(set(added_tokens_decoder) - set(extra_eos_tokens))
# Remove BOS
if getattr(tokenizer, "bos_token", None) is not None:
added_tokens_decoder = [x for x in added_tokens_decoder if x != tokenizer.bos_token]
pass
repeatted_tokens = []
# Join all vocab
joined_text = "\x01\x00".join(added_tokens_decoder)
for token in added_tokens_decoder:
n = len(token)
repeatted_counts = joined_text.count(token[:n//2])
# Try finding longer than 1/2 of the token in the rest
# For eg <|reserved_special_token_0|>, <|reserved_special_token_1|>
if repeatted_counts > 2:
for j in range(n//2+1, n):
if joined_text.count(token[:j]) < repeatted_counts:
j -= 1
# Remove repeated tokens to reduce the search space
joined_text = joined_text.replace(token[:j], "")
repeatted_tokens.append(token[:j])
break
pass
pass
pass
# Remove duplicates
splitted = joined_text.split("\x01\x00")
final_eos_tokens = []
for old, new in zip(added_tokens_decoder, splitted):
if old == new: final_eos_tokens.append(old)
pass
final_eos_tokens += extra_eos_tokens
final_eos_tokens += repeatted_tokens
# Remove new lines, spaces and HTML tags
filtered_eos_tokens = []
for token in final_eos_tokens:
if token.count("\n") == len(token): continue
elif token.count("▁") == len(token): continue
elif token.startswith("<") and len(token) <= 2: continue
elif token.startswith("</") and len(token) == 3: continue
filtered_eos_tokens.append(token)
pass
return filtered_eos_tokens
pass
def construct_chat_template( \
tokenizer = None,
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|>""",
default_system_message = \
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
extra_eos_tokens = None,
):
"""
Creates an Ollama Modelfile and a HF Jinja template from a custom
template. You must provide two examples of an input & output.
An optional system message is also supported.
{INPUT} and {OUTPUT} must each appear twice; {SYSTEM} is optional.
"""
# Strip only the left
chat_template = chat_template.lstrip()
assert(tokenizer is not None)
if extra_eos_tokens is None: extra_eos_tokens = []
elif type(extra_eos_tokens) is str: extra_eos_tokens = [extra_eos_tokens,]
vocab = tokenizer.get_vocab()
for extra_eos in extra_eos_tokens:
assert(type(extra_eos) is str)
if extra_eos not in vocab:
raise ValueError(f"Unsloth: `{extra_eos}` is not a singular token in the tokenizer.")
pass
pass
error_msg = \
"Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
"and the assistant output {OUTPUT}\n\n"\
"For example what is not allowed is just:\n"\
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"\
"What is required is 2x of this:\n"\
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
# Check for EOS after {OUTPUT}
if tokenizer.eos_token is not None:
extra_eos_tokens.insert(0, tokenizer.eos_token)
if len(extra_eos_tokens) == 0:
raise RuntimeError(
"Unsloth: Your tokenizer does not have an EOS token? Please provide one via extra_eos_tokens!"
)
pass
# Check tokenizer types
tokenizer_name = tokenizer.name_or_path.lower()
if tokenizer_name.startswith(("unsloth/llama-3-8b-instruct", "unsloth/llama-3-70b-instruct")):
# Add <|eot_id|>
extra_eos_tokens.append("<|eot_id|>")
elif ("<|eot_id|>" in extra_eos_tokens or "<|eot_id|>" in chat_template) and \
tokenizer_name.startswith(("unsloth/llama-3-8b", "unsloth/llama-3-70b")):
# Warn
logger.warning(
"Unsloth: Base llama-3 models did not train <|eot_id|>.\n"\
"Please use the instruct version or use <|end_of_text|>"
)
pass
extra_eos_tokens = list(set(extra_eos_tokens))
count_eos = 0
for eos in extra_eos_tokens:
count_eos += len(re.findall(r"{OUTPUT}" + re.escape(eos), chat_template))
pass
# This forces you to provide 2 input and outputs
final_combined_check = False
try:
# O(N^2) search to find 2 repeated pieces of text
j = len(chat_template)-1
at_least_one = False
while j > 0:
found = chat_template.rfind(chat_template[j:], 0, j)
if found == -1: break
j -= 1
at_least_one = True
pass
if j > 0: j += 1
else: raise RuntimeError(error_msg)
if not at_least_one: raise RuntimeError(error_msg)
# Must be equivalent to left
final_combined_check = True
# Repeated text
instruction_response = chat_template[j:]
if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
raise RuntimeError(error_msg)
pass
# 1st System, Instruction, Output pair
left = chat_template[:j]
# 2nd Instruction, Output pair
right = chat_template[j:]
final_combined_check = left if final_combined_check else chat_template
# Isolate input
extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens)
if len(extra_eos_tokens_regex) != 0:
find_end = f"(?:{extra_eos_tokens_regex})?"
else:
find_end = ""
find_end = r"\{INPUT\}[\s\n]{0,}" + find_end
input_end = list(re.finditer(find_end, right))
assert(len(input_end) == 1)
input_end = input_end[0]
input_end = input_end.span(0)[1]
input_part = right[:input_end]
# Isolate output
output_part = right[input_end:]
# Isolate system
where_system = left.find(input_part)
system_part = left[:where_system if where_system != -1 else len(left)]
# Check if the user provided a correct prompt
combined = system_part + input_part + output_part
if combined != final_combined_check:
combined_changed = combined .replace('\n', '\\n')
left_changed = final_combined_check.replace('\n', '\\n')
raise RuntimeError(
"Unsloth: The prompt template you provided isn't correct. You gave:\n"\
f"{combined_changed}\n\n"\
"But we require the following:\n"\
f"{left_changed}"
)
pass
except:
ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}"):]
ending = re.escape(ending)
find_text = "{INPUT}" + ending + "(.+?{OUTPUT}" + ending + ")"
response_part = re.findall(find_text, chat_template, flags = re.DOTALL | re.MULTILINE)
response_part = response_part[0]
for j in range(1, len(response_part)):
try_find = re.escape(response_part[:j])
try: found = next(re.finditer("(" + try_find + ").+?\{INPUT\}", chat_template, flags = re.DOTALL | re.MULTILINE))
except: break
pass
separator = found.group(1)
response_start = chat_template.find(response_part)
start_instruction = chat_template[:response_start].rfind(separator)
if start_instruction == -1: start_instruction = 0
instruction_part = chat_template[start_instruction:response_start]
combined = instruction_part + response_part
where = chat_template.find(combined)
system_part = chat_template[:where]
system_part, input_part, output_part = system_part, instruction_part, response_part
pass
if count_eos == 0:
logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
eos = extra_eos_tokens[0]
output_part = output_part + eos
pass
# Ollama modelfile parts
# Check bos_token is in system prompt
ollama_system = system_part
has_bos_token = False
always_bos_token = False
if tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None):
always_bos_token = True
if ollama_system.startswith(tokenizer.bos_token):
has_bos_token = True
ollama_system = ollama_system[len(tokenizer.bos_token):]
pass
pass
# Check system
if "{SYSTEM}" in ollama_system:
system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}"
else:
system_modelfile = ollama_system
pass
input_modelfile = "{{ if .Prompt }}" + input_part .replace("{INPUT}", "{{ .Prompt }}") + "{{ end }}"
output_modelfile = output_part.replace("{OUTPUT}", "{{ .Response }}")
# Ollama EOS
ollama_eos = get_ollama_eos_tokens(tokenizer, extra_eos_tokens)
ollama_eos = '\n'.join(f'PARAMETER stop "{eos}"' for eos in ollama_eos)
# Ollama modelfile
modelfile = 'FROM {__FILE_LOCATION__}\n\n'\
'TEMPLATE """' + system_modelfile + input_modelfile + output_modelfile + \
'"""\n\n' + ollama_eos
# HF Jinja Chat template
def process(part, which, content = "message['content']"):
if part.endswith(which):
part = "'" + part[:part.find(which)] + f"' + {content}"
elif part.startswith(which):
part = f"{content} + '" + part[part.find(which):] + "'"
else:
part = "'" + part.replace(which, f"' + {content} + '") + "'"
if part.startswith("'' + "): part = part[5:]
return part
pass
input_jinja = process(input_part, "{INPUT}")
output_jinja = process(output_part, "{OUTPUT}")
pass
jinja_template = \
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ " + input_jinja + " }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ " + output_jinja + " }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '" + output_part[:output_part.find("{OUTPUT}")] + "' }}"\
"{% endif %}"
pass
# Now add system prompt to jinja
if len(system_part) != 0:
partial_system = process(system_part, "{SYSTEM}", "messages[0]['content']")
partial_system = partial_system.replace("{SYSTEM}", "")
if "{SYSTEM}" in partial_system:
if default_system_message is None:
raise RuntimeError("Unsloth: Please specify a default system message!")
pass
# Separate the BOS
if has_bos_token:
partial_system = partial_system.replace(tokenizer.bos_token, "", 1)
system_part = system_part .replace(tokenizer.bos_token, "", 1)
pass
partial_system = \
"{% if messages[0]['role'] == 'system' %}"\
"{{ " + partial_system + " }}"\
"{% set loop_messages = messages[1:] %}"
if default_system_message is not None:
full_system = system_part.replace("{SYSTEM}", default_system_message)
if "{SYSTEM}" in system_part:
modelfile += '\nSYSTEM "' + default_system_message + '"'
pass
partial_system += "{% else %}"\
"{{ '" + full_system + "' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"
else:
partial_system += "{% endif %}"
pass
jinja_template = partial_system + jinja_template
if has_bos_token:
jinja_template = "{{ bos_token }}" + jinja_template
pass
# Fix missing loop_messages
if "{% set loop_messages = messages %}" not in jinja_template:
jinja_template = jinja_template.replace(
"{% for message in loop_messages %}",
"{% for message in messages %}",
1, # Only replace the first one
)
pass
# Check if system part is the same!
jinja_template = re.sub(
r"\{\% if messages\[0\]\['role'\] \=\= 'system' \%\}\{\{ '(.+?)' \}\}"\
r"\{\% set loop\_messages \= messages\[1\:\] \%\}"\
r"\{\% else \%\}\{\{ '\1' \}\}\{\% set loop\_messages \= messages \%\}\{\% endif \%\}"\
r"\{\% for message in loop\_messages \%\}",
r"{{ '\1' }}{% for message in messages %}",
jinja_template, flags = re.MULTILINE | re.DOTALL,
)
# Check the Jinja template for BOS
if always_bos_token:
if not jinja_template.startswith("{{ bos_token }}"):
jinja_template = "{{ bos_token }}" + jinja_template
pass
# Get instruction and output parts for train_on_inputs = False
input_part = input_part [:input_part .find("{INPUT}")]
output_part = output_part[:output_part.find("{OUTPUT}")]
return modelfile, jinja_template, input_part, output_part
pass
def test_construct_chat_template():
token = "hf_"
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token = token)
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|>"""
default_system_message = \
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
extra_eos_tokens = None
modelfile, jinja_template, _, _ = construct_chat_template(
tokenizer = tokenizer,
chat_template = chat_template,
extra_eos_tokens = extra_eos_tokens,
)
messages = [
{"role": "system", "content": "You are an assistant"},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "It's 4."},
{"role": "user", "content": "Ok!"},
{"role": "assistant", "content": "Anything else?"},
{"role": "user", "content": "What's 2x2?"},
]
correct_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
tokenizer.chat_template = jinja_template
new_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_output == new_output)
pass
pass
def apply_chat_template( \
dataset,
tokenizer = None,
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|>""",
default_system_message = \
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
extra_eos_tokens = None,
):
"""
Builds an Ollama Modelfile and a HF Jinja template from a custom
template (see `construct_chat_template`), then renders every conversation
in the dataset into a "text" column. You must provide two examples of an
input & output: {INPUT} and {OUTPUT} must each appear twice; {SYSTEM} is optional.
"""
modelfile, jinja_template, input_part, output_part = construct_chat_template(
tokenizer = tokenizer,
chat_template = chat_template,
default_system_message = default_system_message,
extra_eos_tokens = extra_eos_tokens,
)
def formatting_prompts_func(examples):
convos = examples["conversations"]
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
return { "text" : texts, }
pass
tokenizer.chat_template = jinja_template
tokenizer._ollama_modelfile = modelfile
tokenizer._unsloth_input_part = input_part
tokenizer._unsloth_output_part = output_part
return dataset.map(formatting_prompts_func, batched = True,)
pass
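# Usage sketch (illustrative; the custom template and system message below are assumptions).
# The dataset must have a "conversations" column of {"role", "content"} messages, e.g. the
# output of `standardize_sharegpt`. The rendered prompts are written to a "text" column.
def _example_apply_chat_template(dataset, tokenizer):
    custom_template = (
        "{SYSTEM}\n"
        "### Instruction:\n{INPUT}\n\n### Response:\n{OUTPUT}\n\n"
        "### Instruction:\n{INPUT}\n\n### Response:\n{OUTPUT}\n\n"
    )
    return apply_chat_template(
        dataset,
        tokenizer = tokenizer,
        chat_template = custom_template,
        default_system_message = "You are a helpful assistant.",
    )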
def train_on_responses_only(
trainer,
instruction_part = None,
response_part = None,
):
"""
Trains only on responses and not on the instruction by masking out
the labels with -100 for the instruction part.
"""
tokenizer = trainer.tokenizer
if not hasattr(tokenizer, "_unsloth_input_part") or \
not hasattr(tokenizer, "_unsloth_output_part"):
if instruction_part is None or response_part is None:
raise ValueError("Unsloth: instruction_part and response_part must be given!")
pass
elif (instruction_part is not None or response_part is not None) and \
(hasattr(tokenizer, "_unsloth_input_part") or hasattr(tokenizer, "_unsloth_output_part")):
raise ValueError("Unsloth: Your tokenizer already has instruction and response parts set - do not give custom ones!")
else:
instruction_part = tokenizer._unsloth_input_part
response_part = tokenizer._unsloth_output_part
pass
instruction_ids = tokenizer(instruction_part, add_special_tokens = False).input_ids
response_ids = tokenizer(response_part, add_special_tokens = False).input_ids
instruction_length = len(instruction_ids)
response_length = len(response_ids)
max_length = max(instruction_length, response_length)
def _train_on_responses_only(examples):
input_ids_ = examples["input_ids"]
all_labels = []
for input_ids in input_ids_:
labels = [-100] * len(input_ids)
m = len(input_ids) - max_length
first_response = response_ids[0]
first_instruction = instruction_ids[0]
j = 0
while j < m:
if input_ids[j] == first_response:
if input_ids[j : j+response_length] == response_ids:
j = j + response_length
start = j
while j < m:
if input_ids[j] == first_instruction and input_ids[j : j+instruction_length] == instruction_ids:
j = j + instruction_length
labels[start : j] = input_ids[start : j]
break
elif j == (m-1):
j = m
labels[start:] = input_ids[start:]
break
pass
j += 1
pass
pass
pass
j += 1
pass
all_labels.append(labels)
pass
return { "labels" : all_labels }
pass
trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True)
return trainer
pass
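# Illustrative sketch (toy token ids, not a real tokenizer) of the masking performed above:
# every label is -100 except the spans that follow a response marker, up to the next
# instruction marker (here there is none, so the mask runs to the end).
def _example_response_only_labels():
    instruction_ids = [10, 11]   # e.g. the tokenized "### Instruction:\n"
    response_ids    = [20, 21]   # e.g. the tokenized "### Response:\n"
    input_ids = [10, 11, 1, 2, 20, 21, 3, 4, 5]
    labels = [-100] * len(input_ids)
    start = input_ids.index(response_ids[0]) + len(response_ids)   # first token after the marker
    labels[start:] = input_ids[start:]                             # train only on 3, 4, 5
    return labels   # [-100, -100, -100, -100, -100, -100, 3, 4, 5]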
def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
class StoppingCriteriaSub(StoppingCriteria):
__slots__ = "stop_token", "single_match", "length",
def __init__(self, stops = "eos_token", device = "cuda", encounters = 1):
super().__init__()
if stops == "eos_token":
self.stop_token = torch.tensor(tokenizer.eos_token_id, device = "cuda")
self.length = 1
else:
self.stop_token = tokenizer(["\n" + stops], add_special_tokens = False, return_tensors = "pt")
self.stop_token = self.stop_token.input_ids.ravel()[1:].to("cuda")
self.length = self.stop_token.shape[0]
pass
self.single_match = self.length == 1
pass
def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
input_ids = input_ids.ravel()
last_token = input_ids[-1]
if self.single_match and (last_token == self.stop_token): return True
if input_ids.shape[0] >= self.length and \
(input_ids[-self.length:] == self.stop_token).all(): return True
return False
pass
pass
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])
return stopping_criteria
pass
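# Usage sketch (illustrative; the stop word is an assumption): the returned
# StoppingCriteriaList plugs straight into `model.generate`.
def _example_generate_with_stop(model, tokenizer, prompt):
    inputs = tokenizer(prompt, return_tensors = "pt").to(model.device)
    stopping_criteria = create_stopping_criteria(tokenizer, stop_word = "<|im_end|>")
    return model.generate(**inputs, max_new_tokens = 64, stopping_criteria = stopping_criteria)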
def test_chat_templates():
messages = [
{"role": "system","content": " You are a friendly chatbot.",},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "It's 4."},
{"role": "user", "content": " But 2+2 is equal to 5. "},
{"role": "assistant", "content": "No I'm sure its 4."},
{"role": "user", "content": " No it's 100% 5! "},
]
# Zephyr
from transformers import AutoTokenizer
template = zephyr_template
correct_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Chatml
template = chatml_template
correct_tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Mistral
template = mistral_template
correct_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Llama
template = llama_template
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Vicuna
try:
from fastchat.conversation import get_conv_template
except:
os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
from fastchat.conversation import get_conv_template
correct_prompt = get_conv_template("vicuna_v1.1")
for j in range(len(messages)-1):
correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
correct_prompt.append_message(correct_prompt.roles[1], "")
correct_prompt = correct_tokenizer.bos_token + correct_prompt.get_prompt()
template = vicuna_template
correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
try:
from fastchat.conversation import get_conv_template
except:
os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
from fastchat.conversation import get_conv_template
correct_prompt = get_conv_template("zero_shot")
for j in range(len(messages)-1):
correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
correct_prompt.append_message(correct_prompt.roles[1], "")
correct_prompt = correct_tokenizer.bos_token + correct_prompt.get_prompt()
template = vicuna_old_template
correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
# We add </s> ourselves
assert(correct_prompt == our_prompt.replace("</s>", ""))
# Gemma
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-7b-it")
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = gemma_template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(our_prompt == correct_prompt)
# Llama-3
template = llama3_template
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Phi-3
template = phi3_template
correct_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
pass
def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf"):
"""
Carefully compares GGUF's tokenization against HF's tokenization.
Useful for catching tokenization bugs.
"""
import subprocess
import re
messages = [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "It's 4."},
{"role": "user", "content": " But 2+2 is equal to 5. "},
{"role": "assistant", "content": "No I'm sure its 4."},
{"role": "user", "content": " No it's 100% 5! "},
]
prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}""".format(
"Describe the city given eloquently.", # instruction
"The lost city of Atlantis.", # input
"", # output - leave this blank for generation!
)
prompts = [ prompt, ]
if tokenizer.chat_template is not None:
prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
prompt = prompt.replace("'", "") # The shell command single-quotes the prompt, so strip any single quotes
prompt = remove_special_tokens(tokenizer, prompt)
prompts.append(prompt)
pass
for prompt in prompts:
command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
f"--check-tensors -p '{prompt}'"
datas = []
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
for line in sp.stdout:
datas.append(line.decode("utf-8", errors = "replace"))
pass
gguf_tokens = "".join(datas)
# Now extract GGUF tokenization attempt
gguf_tokenized = re.findall(r"([\d]{1,}) \-\> \'([^\']{1,})\'", gguf_tokens, flags = re.MULTILINE)
gguf_tokenized = [(int(x[0]), x[1],) for x in gguf_tokenized]
input_ids = tokenizer(prompt).input_ids
tokens = tokenizer.batch_decode(input_ids)
hf_tokenized = list(zip(input_ids, tokens))
# Compare to Huggingface
for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):
if (hf_token[0] != gguf_token[0]):
print("Failed GGUF != HF at", j)
print("HF =", hf_token)
print("GGUF =", gguf_token)
print(hf_tokenized)
print()
print(gguf_tokenized)
print()
raise RuntimeError("Failed comparing GGUF to HF.")
pass
pass
return True
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cross_entropy_loss import fast_cross_entropy_loss
from .rms_layernorm import fast_rms_layernorm
from .rope_embedding import fast_rope_embedding, inplace_rope_embedding
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
from .geglu import (
geglu_exact_forward_kernel,
geglu_exact_backward_kernel,
geglu_approx_forward_kernel,
geglu_approx_backward_kernel,
)
from .fast_lora import (
get_lora_parameters,
get_lora_parameters_bias,
apply_lora_mlp_swiglu,
apply_lora_mlp_geglu_exact,
apply_lora_mlp_geglu_approx,
apply_lora_qkv,
apply_lora_o,
)
from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora
from .flex_attention import HAS_FLEX_ATTENTION, slow_attention_softcapping
if HAS_FLEX_ATTENTION:
from .flex_attention import (
FLEX_ATTENTION_PADDING,
)
pass
try:
print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.")
except:
print("Unsloth: Will patch your computer to enable 2x faster free finetuning.")
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh
from transformers.models.llama.modeling_llama import logger
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
Pi = exp(xi) / sum(exp(xi))
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
= -y [ x - log[sum(exp(x))] ]
= y * (log[sum(exp(x))] - x)
If y == 0: CE_i = 0
If y == 1: CE_i = logsumexp - x
logsumexp is also stable
Take y = log[sum(exp(x))]
exp(y) = sum(exp(x))
exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
exp(y) = exp(c)*sum(exp(x - c))
y = log(exp(c)*sum(exp(x - c)))
y = c + log[sum(exp(x - c))]
This means we can set c = max(x) to make sure
exp(x - c) always is exp(x - max(x)).
This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
"""
row_idx = tl.program_id(0)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
loss_ptr += row_idx
logsumexp_ptr += row_idx
labels_ptr += row_idx
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
if label_idx != -100:
x = tl.load(logits_ptr + label_idx)
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = logsumexp - x.to(tl.float32)
else:
loss = 0.0
tl.store(logsumexp_ptr, logsumexp)
tl.store(loss_ptr, loss)
pass
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
256K vocab divided in 4 chunks
|-65536-| |-65536-| |-65536-| |-65536-|
|-------| |-------| |-------| |-------|
|-------| |-------| |-------| |-------|
If y == 0: CE_i = 0
If y == 1: CE_i = logsumexp - x
Notice we can do logsumexp for each chunk and then
logsumexp[chunk_sum(logsumexp)] == logsumexp
chunk_sum = log[chunk_sum(logsumexp)]
= log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
= log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
= log[sum(exp(a)) + ... + sum(exp(z))]
= logsumexp(x)
This means we can perform a logsumexp for each chunk, then do a
final logsumexp reduction!
Ie do: logsumexp(chunked_logsumexp) - x
"""
row_idx = tl.program_id(0)
chunk_idx = tl.program_id(1)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
loss_ptr += row_idx
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
labels_ptr += row_idx
col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
if chunk_idx == 0:
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = -1.0 * x.to(tl.float32)
else:
loss = 0.0
tl.store(loss_ptr, loss)
pass
tl.store(logsumexp_ptr, logsumexp)
pass
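# Quick sanity sketch (illustrative) of the identity used above: the logsumexp of per-chunk
# logsumexps equals the logsumexp over the whole row.
def _example_chunked_logsumexp(vocab_size = 8, n_chunks = 2):
    x = torch.randn(vocab_size)
    chunked = torch.stack([chunk.logsumexp(dim = 0) for chunk in x.chunk(n_chunks)])
    return torch.allclose(chunked.logsumexp(dim = 0), x.logsumexp(dim = 0))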
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_backward(
logits_ptr, logits_row_stride,
dloss_ptr, dloss_row_stride,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
From https://en.wikipedia.org/wiki/LogSumExp
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
If y == 0: dC/dx = 0
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
"""
row_idx = tl.program_id(0)
block_idx = tl.program_id(1)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
dloss_ptr += row_idx * dloss_row_stride
col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
if label_idx != -100:
dloss = tl.load(dloss_ptr)
else:
dloss = 0.0
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
partial = triton_tanh(x / SOFTCAP)
x = SOFTCAP * partial
pass
logsumexp = tl.load(logsumexp_ptr + row_idx)
y = tl.exp(x.to(tl.float32) - logsumexp)
y = tl.where(
col_offsets == label_idx,
y - 1.0, # exp(x - logsumexp) - 1
y, # exp(x - logsumexp)
)
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
y = y * (1.0 - partial*partial)
pass
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
pass
MAX_FUSED_SIZE = 65536 # 2**16
class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, logit_softcapping = 0):
n_rows, vocab_size = logits.shape
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks = div + (mod != 0)
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
DO_SOFTCAPPING = (logit_softcapping != 0)
if n_chunks == 1:
# For small vocabs <= 65536 like Llama, Mistral
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
_cross_entropy_forward[(n_rows,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = num_warps,
)
else:
# For large vocabs > 65536 like Gemma 256K
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")
_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = 32,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
losses += logsumexp
losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
pass
ctx.save_for_backward(logits, logsumexp, labels)
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
ctx.logit_softcapping = logit_softcapping
return losses
pass
@staticmethod
def backward(ctx, dlosses):
logits, logsumexp, labels = ctx.saved_tensors
n_rows, vocab_size = logits.shape
BLOCK_SIZE = 4096
div, mod = divmod(vocab_size, BLOCK_SIZE)
n_blocks = div + (mod != 0)
_cross_entropy_backward[(n_rows, n_blocks,)](
logits, logits.stride(0),
dlosses, dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
num_warps = 8,
)
return logits, None, None,
pass
pass
def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
labels: (batch, seq_len,)
Returns:
losses: float
"""
batch, seq_len, d = logits.shape
assert(labels.shape == (batch, seq_len))
loss = Fast_CrossEntropyLoss.apply(
logits.view(batch*seq_len, d),
labels.view(-1),
logit_softcapping,
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
pass
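# Sanity-check sketch (illustrative; requires a CUDA device since the Triton kernels allocate
# on "cuda:0"). With logit_softcapping = 0 the fused loss should match PyTorch's
# cross_entropy with ignore_index = -100, averaged over the non-masked tokens.
def _example_check_fast_cross_entropy_loss(batch = 2, seq_len = 8, vocab_size = 128):
    import torch.nn.functional as F
    logits = torch.randn(batch, seq_len, vocab_size, device = "cuda")
    labels = torch.randint(0, vocab_size, (batch, seq_len), device = "cuda")
    labels[:, 0] = -100   # mask a few positions, like shifted labels in practice
    fused = fast_cross_entropy_loss(logits, labels)
    reference = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), ignore_index = -100)
    return torch.allclose(fused, reference, atol = 1e-3)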
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .utils import (
fast_dequantize,
QUANT_STATE,
get_lora_parameters,
get_lora_parameters_bias,
matmul_lora,
torch_amp_custom_fwd,
torch_amp_custom_bwd,
)
class LoRA_MLP(torch.autograd.Function):
"""
### LoRA weights
G = G + Ag @ Bg
U = U + Au @ Bu
W = W + Aw @ Bw
### SwiGLU(X)
e = X @ G
f = e * sigmoid(e)
g = X @ U
h = f * g
i = h @ W
### Backpropagation chain rule
See our blog post for more details
df = sigmoid(e) * (1 - f) + f
dC/dW = h.T @ dY
dC/dU = X.T @ (D @ W.T * f)
dC/dG = X.T @ (D @ W.T * df * g)
### Down projection LoRA weights
dC/dAw = dC/dW @ B.T
dC/dBw = A.T @ dC/dW
dC/dAw = h.T @ dY @ B.T
dC/dBw = A.T @ h.T @ dY
### Up projection LoRA weights
dC/dAu = X.T @ (D @ W.T * f) @ B.T
dC/dBu = A.T @ X.T @ (D @ W.T * f)
### Gate projection LoRA weights
dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
Don't forget to see our blog post for more details!
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, X : torch.Tensor,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
_forward_function, _backward_function,):
dtype = X.dtype
e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
h = _forward_function(e, g)
i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
ctx.custom_saved_tensors = (
gateW, gateW_quant, gateS,
upW, upW_quant, upS,
downW, downW_quant, downS,
_backward_function,
)
ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
X, e, g)
return i
pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dY : torch.Tensor):
gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
_backward_function = ctx.custom_saved_tensors
gateA, gateB, upA, upB, downA, downB, \
X, e, g = ctx.saved_tensors
gateA, gateB, upA, upB, downA, downB = \
gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
batch, seq_len, hd = X.shape
dY = dY.view(-1, dY.shape[-1])
X = X .view(-1, X .shape[-1])
e = e .view(-1, e .shape[-1])
g = g .view(-1, g .shape[-1])
dtype = X.dtype
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
DW, e, g = _backward_function(DW, e, g)
h, df, de = DW, e, g
# Down projection LoRA weights
d_downA = h.t() @ (dY @ downB.t())
d_downB = (downA.t() @ h.t()) @ dY
d_downA *= downS
d_downB *= downS
# Up projection LoRA weights
d_upA = X.t() @ (df @ upB.t())
d_upB = (upA.t() @ X.t()) @ df
d_upA *= upS
d_upB *= upS
# Gate projection LoRA weights
d_gateA = X.t() @ (de @ gateB.t())
d_gateB = (gateA.t() @ X.t()) @ de
d_gateA *= gateS
d_gateB *= gateS
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X)
del upW
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
gateW = fast_dequantize(gateW.t(), gateW_quant)
dX += de @ gateW.t()
del gateW
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
# gateW, gateW_quant, gateA, gateB, gateS,
# upW, upW_quant, upA, upB, upS,
# downW, downW_quant, downA, downB, downS,
return dX.view(batch, seq_len, hd), \
None, None, d_gateA.t(), d_gateB.t(), None, \
None, None, d_upA.t(), d_upB.t(), None, \
None, None, d_downA.t(), d_downB.t(), None, \
None, None, # _backward and _forward
pass
pass
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
def apply_lora_mlp_swiglu(self, X):
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,)
return out
pass
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
def apply_lora_mlp_geglu_exact(self, X):
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
geglu_exact_forward_kernel, geglu_exact_backward_kernel,)
return out
pass
from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
def apply_lora_mlp_geglu_approx(self, X):
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
return out
pass
class LoRA_QKV(torch.autograd.Function):
"""
### LoRA weights
Wq = Wq + Aq @ Bq
Wk = Wk + Ak @ Bk
Wv = Wv + Av @ Bv
Q = X @ Wq = X @ Wq + X @ Aq @ Bq
K = X @ Wk = X @ Wk + X @ Ak @ Bk
V = X @ Wv = X @ Wv + X @ Av @ Bv
### Backpropagation chain rule
See our blogpost for more details.
dC/dWq = X.T @ D(Wq)
dC/dWk = X.T @ D(Wk)
dC/dWv = X.T @ D(Wv)
    We then sum them all to find dC/dX
### Q projection LoRA weights
dC/dAq = X.T @ D(Wq) @ B.T
dC/dBq = A.T @ X.T @ D(Wq)
### K projection LoRA weights
dC/dAk = X.T @ D(Wk) @ B.T
dC/dBk = A.T @ X.T @ D(Wk)
### V projection LoRA weights
dC/dAv = X.T @ D(Wv) @ B.T
dC/dBv = A.T @ X.T @ D(Wv)
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, X : torch.Tensor,
QW, QW_quant, QA, QB, QS,
KW, KW_quant, KA, KB, KS,
VW, VW_quant, VA, VB, VS,):
dtype = X.dtype
Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
ctx.custom_saved_tensors = (
QW, QW_quant, QS,
KW, KW_quant, KS,
VW, VW_quant, VS,
)
ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
return Q, K, V
pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dQ, dK, dV):
QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
ctx.custom_saved_tensors
X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
QA, QB, KA, KB, VA, VB = \
QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
batch, seq_len, hd = X.shape
dQ = dQ.view(-1, dQ.shape[-1])
dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
dV = dV.view(-1, dV.shape[-1])
X = X .view(-1, X .shape[-1])
dtype = X.dtype
### Weight projection LoRA weights
# See our blogpost for more details.
# Q Projection
d_QA = X.t() @ (dQ @ QB.t())
d_QB = (QA.t() @ X.t()) @ dQ
d_QA *= QS
d_QB *= QS
# K Projection
d_KA = X.t() @ (dK @ KB.t())
d_KB = (KA.t() @ X.t()) @ dK
d_KA *= KS
d_KB *= KS
# V Projection
d_VA = X.t() @ (dV @ VB.t())
d_VB = (VA.t() @ X.t()) @ dV
d_VA *= VS
d_VB *= VS
# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X)
del QW
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
# dK
KW = fast_dequantize(KW.t(), KW_quant)
dX += dK @ KW.t()
del KW
dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
# dV
VW = fast_dequantize(VW.t(), VW_quant)
dX += dV @ VW.t()
del VW
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
# QW, QW_quant, QA, QB, QS,
# KW, KW_quant, KA, KB, KS,
# VW, VW_quant, VA, VB, VS,
return dX.view(batch, seq_len, hd), \
None, None, d_QA.t(), d_QB.t(), None, \
None, None, d_KA.t(), d_KB.t(), None, \
None, None, d_VA.t(), d_VB.t(), None
pass
pass
def apply_lora_qkv(self, X):
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(X,
QW, QW_quant, QA, QB, QS,
KW, KW_quant, KA, KB, KS,
VW, VW_quant, VA, VB, VS,
)
return Q, K, V
pass
class LoRA_W(torch.autograd.Function):
"""
    ### LoRA weights
    W = W + A @ B
    Y = X @ W = X @ W + X @ A @ B
    ### Backpropagation chain rule
    dC/dW = X.T @ dY
    ### LoRA weight gradients
    dC/dA = X.T @ dY @ B.T
    dC/dB = A.T @ X.T @ dY
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, X : torch.Tensor,
W, W_quant, A, B, S):
dtype = X.dtype
XW = matmul_lora(X, W, W_quant, A, B, S)
ctx.custom_saved_tensors = (W, W_quant, S,)
ctx.save_for_backward(A, B, X)
return XW
pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dY : torch.Tensor):
W, W_quant, S = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
A, B = A.t(), B.t()
batch, seq_len, hd = X.shape
dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
X = X .reshape(-1, X .shape[-1]) # Must be reshape
dtype = X.dtype
### Weight projection LoRA weights
# Weight projection
d_A = X.t() @ (dY @ B.t())
d_B = (A.t() @ X.t()) @ dY
d_A *= S
d_B *= S
# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), \
None, None, d_A.t(), d_B.t(), None
pass
pass
def apply_lora_o(self, X):
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
pass
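# --- Illustrative sketch (not part of the original file) ---
# A small autograd check of the per-projection LoRA gradients used in
# LoRA_W.backward (the same formulas cover Q, K and V in LoRA_QKV).
# A and B below are in the transposed layout used inside backward(); the
# helper name and shapes are made up for illustration only.
def _reference_lora_w_grad_check(dtype = torch.float64):
    torch.manual_seed(0)
    N, in_dim, out_dim, r, S = 6, 8, 10, 4, 0.5
    X = torch.randn(N, in_dim, dtype = dtype)
    W = torch.randn(in_dim, out_dim, dtype = dtype)   # dense weight, already transposed
    A = torch.randn(in_dim, r,  dtype = dtype, requires_grad = True)
    B = torch.randn(r, out_dim, dtype = dtype, requires_grad = True)
    Y = X @ W + S * (X @ A) @ B
    dY = torch.randn_like(Y)
    Y.backward(dY)
    # Closed forms from backward(): d_A = X.T @ (dY @ B.T) * S, d_B = (A.T @ X.T) @ dY * S
    d_A = X.t() @ (dY @ B.t()) * S
    d_B = (A.t() @ X.t()) @ dY * S
    assert torch.allclose(A.grad, d_A.detach())
    assert torch.allclose(B.grad, d_B.detach())
pass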
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from functools import lru_cache
from transformers.models.llama.modeling_llama import logger
torch_compile_options = {
"epilogue_fusion" : True,
"max_autotune" : True,
"shape_padding" : True,
"trace.enabled" : False, # Output Triton kernel outputs!
"triton.cudagraphs" : False,
}
# Flex Attention supported from torch 2.5 onwards only
import torch.nn
if hasattr(torch.nn, "attention"):
import torch.nn.attention
if hasattr(torch.nn.attention, "flex_attention"):
import torch.nn.attention.flex_attention
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask
FLEX_ATTENTION_PADDING = getattr(
torch.nn.attention.flex_attention,
"_DEFAULT_SPARSE_BLOCK_SIZE",
1,
)
flex_attention = torch.compile(flex_attention, dynamic = False)
HAS_FLEX_ATTENTION = True
else:
HAS_FLEX_ATTENTION = False
pass
else:
HAS_FLEX_ATTENTION = False
pass
# Logit softcapping
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
head_dim = self.head_dim
n_kv_heads = self.num_key_value_heads
n_groups = self.num_key_value_groups
# Grouped query attention
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
K = K.reshape(bsz, n_heads, q_len, head_dim)
V = V.reshape(bsz, n_heads, q_len, head_dim)
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
    # Gemma 9b should use 256 and not 224 (hidden_size / num_attention_heads). 27b uses the below
# We default to using the config file itself
# s = self.config.hidden_size // self.config.num_attention_heads
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch.matmul(Q, K.transpose(2, 3))
A = t * torch.tanh(A / t) # Logit softcapping
A += causal_mask[:q_len, :q_len]
# Much slower in torch compile!
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch.matmul(A, V)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
pass
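# --- Illustrative sketch (not part of the original file) ---
# The key step above is Gemma-2 style logit softcapping: attention scores are
# squashed into (-cap, cap) via cap * tanh(scores / cap) before masking and
# softmax. A stand-alone restatement in plain PyTorch (names and shapes are
# assumptions for illustration):
def _softcapped_scores_sketch(Q, K, softcap = 50.0, scale = None):
    # Q, K: (batch, heads, seq, head_dim) -> scores: (batch, heads, seq, seq)
    if scale is None: scale = Q.shape[-1] ** -0.5
    scores = (Q * scale) @ K.transpose(-2, -1)
    return softcap * torch.tanh(scores / softcap)
pass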
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings, triton_tanh
@triton.jit
def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
# h = f * up
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
f_row = f_row.to(g_row.dtype) # Exact copy from HF
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
pass
def geglu_exact_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
pass
@triton.jit
def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
h = f * up
df/de (with help of Wolfram :)
df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
Reuse via
f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
"""
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# Break e_row away for re-use
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
f_row = f_partial_row * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
# df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
t = 0.3989422804014327 # 1/sqrt(2*pi)
df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
de_row = dg_row.to(tl.float32) * df_de
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
pass
def geglu_exact_backward_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
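# --- Illustrative sketch (not part of the original file) ---
# A quick autograd check of the exact GELU derivative used above:
# df/de = 1/2 * (1 + erf(e/sqrt(2))) + e/sqrt(2*pi) * exp(-e^2/2).
# The helper name and the sample grid are made up for illustration only.
def _reference_gelu_exact_grad_check():
    e = torch.linspace(-4, 4, 257, dtype = torch.float64, requires_grad = True)
    f = 0.5 * e * (1.0 + torch.erf(e * (2.0 ** -0.5)))
    f.sum().backward()
    df_de = 0.5 * (1.0 + torch.erf(e * (2.0 ** -0.5))) \
        + 0.3989422804014327 * e * torch.exp(-0.5 * e * e)
    assert torch.allclose(e.grad, df_de.detach())
pass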
@triton.jit
def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
# h = f * up
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
f_row = 0.5 * e_row * (
triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
+ 1.0
)
f_row = f_row.to(g_row.dtype) # Exact copy from HF
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
pass
def geglu_approx_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
pass
@triton.jit
def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
h = f * up
df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )
Notice sech^2(x) = 1 - tanh^2(x)
So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )
See https://www.desmos.com/calculator/nqprfoni6x
"""
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# See https://www.desmos.com/calculator/nqprfoni6x
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
a = s * e_row # a = sqrt(2 / pi) * x
b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
T = 1.0 + triton_tanh(a + b)
T2 = 0.5 * T
# Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
df_de = T2 + Q2 # 1/2 * (T + Q)
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
f_row = T2 * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
de_row = dg_row.to(tl.float32) * df_de
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
pass
def geglu_approx_backward_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
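# --- Illustrative sketch (not part of the original file) ---
# The approximate forward above is the tanh GELU approximation, so it should
# agree with torch.nn.functional.gelu(..., approximate = "tanh"). A quick check
# (helper name and grid are made up for illustration):
def _reference_gelu_tanh_check():
    import torch.nn.functional as F
    e = torch.linspace(-4, 4, 257, dtype = torch.float32)
    s = 0.7978845608028654  # sqrt(2 / pi)
    f = 0.5 * e * (1.0 + torch.tanh(s * e * (1.0 + 0.044715 * e * e)))
    assert torch.allclose(f, F.gelu(e, approximate = "tanh"), atol = 1e-6)
pass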
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
@triton.jit
def _rms_layernorm_forward(
Y, Y_row_stride,
X, X_row_stride,
W, W_row_stride,
r, r_row_stride,
n_cols, eps,
BLOCK_SIZE : tl.constexpr
):
"""
Fast RMS Layernorm kernel
Inspiration from a Triton tutorial:
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
"""
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
Y += row_idx * Y_row_stride
X += row_idx * X_row_stride
r += row_idx * r_row_stride
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
normed = normed.to(W_row.dtype) # Exact copy from HF
output = normed * W_row
tl.store(Y + col_offsets, output, mask = mask)
pass
@triton.heuristics({"GEMMA": lambda args: args["GEMMA"],})
@triton.jit
def _rms_layernorm_backward(
dY, dY_row_stride,
X, X_row_stride,
W, W_row_stride,
r, r_row_stride,
dW, dW_row_stride,
n_cols, eps,
GEMMA : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
):
"""
Fast RMS Layernorm kernel for the backward pass
Inspiration from a Triton tutorial:
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
"""
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dY += row_idx * dY_row_stride
X += row_idx * X_row_stride
r += row_idx * r_row_stride
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
# Get saved row variance
inv_var = tl.load(r).to(tl.float32)
normed = X_row * inv_var
if GEMMA: dY_W = dY_row * (W_row + 1.0)
else: dY_W = dY_row * W_row
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
tl.store(dY + col_offsets, output, mask = mask)
pass
@triton.jit
def _gemma_rms_layernorm_forward(
Y, Y_row_stride,
X, X_row_stride,
W, W_row_stride,
r, r_row_stride,
n_cols, eps,
BLOCK_SIZE : tl.constexpr,
):
# Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
# and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
# exactly. Essentially all in float32!
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
Y += row_idx * Y_row_stride
X += row_idx * X_row_stride
r += row_idx * r_row_stride
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
output = normed * (W_row + 1.0)
tl.store(Y + col_offsets, output, mask = mask)
pass
class Fast_RMS_Layernorm(torch.autograd.Function):
@staticmethod
def forward(ctx, X, W, eps, gemma = False):
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
fx[(n_rows,)](
Y, Y.stride(0),
X, X.stride(0),
W, W.stride(0),
r, r.stride(0),
n_cols, eps,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.GEMMA = gemma
ctx.save_for_backward(X, W, r)
return Y.view(*shape)
pass
@staticmethod
def backward(ctx, dY):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
X, W, r = ctx.saved_tensors
n_rows, n_cols = dY.shape
dW = X
_rms_layernorm_backward[(n_rows,)](
dY, dY.stride(0),
X, X .stride(0),
W, W .stride(0),
r, r .stride(0),
dW, dW.stride(0),
n_cols, ctx.eps,
GEMMA = ctx.GEMMA,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
dX = dY.view(*shape)
return dX, None, None, None
pass
pass
def fast_rms_layernorm(layernorm, X, gemma = False):
W = layernorm.weight
eps = layernorm.variance_epsilon
out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
return out
pass
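# --- Illustrative sketch (not part of the original file) ---
# A plain PyTorch restatement of what the kernels above compute, useful as a
# reference when debugging: normalise in float32 by rsqrt(mean(x^2) + eps),
# then scale by W (or W + 1 for Gemma). The helper name is an assumption.
def _reference_rms_layernorm(X, W, eps, gemma = False):
    X32 = X.float()
    inv_var = torch.rsqrt(X32.pow(2).mean(dim = -1, keepdim = True) + eps)
    normed = X32 * inv_var
    if gemma: return (normed * (W.float() + 1.0)).to(X.dtype)
    return (normed.to(W.dtype) * W).to(X.dtype)
pass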
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
ROPE_GROUP_SIZE = 4
@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
@triton.jit
def _rope_embedding(
Q, Q_row_stride,
cos, cos_row_stride,
sin, sin_row_stride,
seqlen,
head_dim : tl.constexpr,
n_heads : tl.constexpr,
BACKWARD_PASS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
):
"""
Calculates the RoPE Embedding quickly
RoPE is Q * cos + rotate_half(Q) * sin
See our blog post for more info
"""
ROPE_GROUP_SIZE = 4
row_position = tl.program_id(0)
group_head_position = tl.program_id(1)
col_offsets = tl.arange(0, BLOCK_SIZE)
half_head_dim = head_dim // 2
mask = col_offsets < half_head_dim
sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
half_head_dim*0 + col_offsets, mask = mask, other = 0)
cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
half_head_dim*0 + col_offsets, mask = mask, other = 0)
if BACKWARD_PASS:
# See our blog post for more info.
sin1 = -sin1
pass
# [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
head_start = group_head_position * ROPE_GROUP_SIZE
head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)
# 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
for k in range(head_start, head_end):
offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
pass
pass
class Fast_RoPE_Embedding(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, cos, sin):
cos, sin = cos.squeeze(), sin.squeeze()
batch, seq_len, n_heads, head_dim = Q.shape
Q = Q.view(batch*seq_len, n_heads*head_dim)
n_rows, n_cols = Q.shape
assert(seq_len <= cos.shape[0])
        # [TODO] Changing blocksize to head_dim//2 seems to have
        # some concurrency / non-deterministic issues.
BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
# group_size = 4 # 4 or 8, too large group_size can hurt performance.
div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
n_groups = div + (mod != 0)
_rope_embedding[(n_rows, n_groups, )](
Q, Q.stride(0),
cos, cos.stride(0),
sin, sin.stride(0),
seq_len,
head_dim, n_heads,
BACKWARD_PASS = False,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.n_groups = n_groups
ctx.cos = cos
ctx.sin = sin
return Q.view(batch, seq_len, n_heads, head_dim)
pass
@staticmethod
def backward(ctx, dY):
batch, seq_len, n_heads, head_dim = dY.shape
dY = dY.reshape(batch*seq_len, n_heads*head_dim)
# Must be reshape not view
n_rows, n_cols = dY.shape
cos = ctx.cos
sin = ctx.sin
_rope_embedding[(n_rows, ctx.n_groups, )](
dY, dY .stride(0),
cos, cos.stride(0),
sin, sin.stride(0),
seq_len, head_dim, n_heads,
BACKWARD_PASS = True,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
dY = dY.view(batch, seq_len, n_heads, head_dim)
return dY, None, None,
pass
pass
def fast_rope_embedding(Q, K, cos, sin):
Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
return Q, K
pass
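# --- Illustrative sketch (not part of the original file) ---
# The kernel above updates each (Q1, Q2) half-pair as Q1*cos - Q2*sin and
# Q2*cos + Q1*sin, which is exactly Q * cos + rotate_half(Q) * sin. A plain
# PyTorch restatement, assuming HF-style cos/sin of shape (seq_len, head_dim)
# whose second half repeats the first (the helper name is an assumption):
def _reference_rope(Q, cos, sin):
    # Q: (batch, seq_len, n_heads, head_dim)
    half = Q.shape[-1] // 2
    c = cos[:, :half][None, :, None, :]
    s = sin[:, :half][None, :, None, :]
    Q1, Q2 = Q[..., :half], Q[..., half:]
    return torch.cat((Q1 * c - Q2 * s, Q2 * c + Q1 * s), dim = -1)
pass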
class Slow_RoPE_Embedding(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, cos, sin, position_ids):
if position_ids is not None:
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
# Q * cos + rotate_half(Q) * sin
half = Q.shape[-1]//2
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
Q *= cos
Q.addcmul_(RH_Q, sin)
# RH_Q *= sin
# Q += RH_Q
ctx.save_for_backward(cos, sin)
return Q
pass
@staticmethod
def backward(ctx, dY):
cos, sin = ctx.saved_tensors
# Q * cos + rotate_half.T(Q) * sin
half = dY.shape[-1]//2
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
dY *= cos
dY.addcmul_(RH_dY, sin)
# RH_dY *= sin
# dY += RH_dY
return dY, None, None, None
pass
pass
def inplace_rope_embedding(Q, K, cos, sin, position_ids):
Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
return Q, K
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
@triton.jit
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# f = e * sigmoid(e)
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
f_row = f_row.to(g_row.dtype) # Exact copy from HF
# h = f * g
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
pass
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
return h
pass
@triton.jit
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
e = e.float()
se = 1.0 / (1.0 + torch.exp(-e))
f = (se * e).to(dtype)
h = f * g
df = DW * f
dg = DW * g
de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
"""
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
# f = (se * e).to(dtype)
f_row = se_row * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
# de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
pass
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
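# --- Illustrative sketch (not part of the original file) ---
# The fused kernels above implement SwiGLU: h = silu(e) * g, with the backward
# identity de = DW * g * sigmoid(e) * (1 + e * (1 - sigmoid(e))). A quick
# autograd check (helper name and sizes are made up for illustration):
def _reference_swiglu_grad_check():
    import torch.nn.functional as F
    e  = torch.randn(64, dtype = torch.float64, requires_grad = True)
    g  = torch.randn(64, dtype = torch.float64)
    DW = torch.randn(64, dtype = torch.float64)
    h = F.silu(e) * g          # same as f * g in _fg_kernel
    h.backward(DW)
    se = torch.sigmoid(e)
    de = DW * g * se * (1.0 + e * (1.0 - se))
    assert torch.allclose(e.grad, de.detach())
pass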
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass
# tl.math.tanh now is libdevice.tanh
from packaging.version import Version
import triton
if Version(triton.__version__) >= Version("3.0.0"):
from triton.language.extra import libdevice
triton_tanh = libdevice.tanh
else:
import triton.language as tl
triton_tanh = tl.math.tanh
pass
def calculate_settings(n):
BLOCK_SIZE = next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
num_warps = 4
if BLOCK_SIZE >= 32768: num_warps = 32
elif BLOCK_SIZE >= 8192: num_warps = 16
elif BLOCK_SIZE >= 2048: num_warps = 8
return BLOCK_SIZE, num_warps
pass
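# For example (illustrative values): calculate_settings(4096) returns (4096, 8),
# since next_power_of_2(4096) == 4096 and 2048 <= 4096 < 8192 selects 8 warps,
# while calculate_settings(1000) rounds up to a block size of 1024 with 4 warps.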
import bitsandbytes as bnb
get_ptr = bnb.functional.get_ptr
import ctypes
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
def QUANT_STATE(W):
return getattr(W, "quant_state", None)
pass
def get_lora_parameters(proj):
# For DPO or disabled adapters
base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
return W, QUANT_STATE(W), None, None, None
pass
active_adapter = proj.active_adapters[0] if \
hasattr(proj, "active_adapters") else proj.active_adapter
A = proj.lora_A [active_adapter].weight
B = proj.lora_B [active_adapter].weight
s = proj.scaling[active_adapter]
return W, QUANT_STATE(W), A, B, s
pass
def get_lora_parameters_bias(proj):
# For DPO or disabled adapters
base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
bias = base_layer.bias
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
return W, QUANT_STATE(W), None, None, None, bias
pass
active_adapter = proj.active_adapters[0] if \
hasattr(proj, "active_adapters") else proj.active_adapter
A = proj.lora_A [active_adapter].weight
B = proj.lora_B [active_adapter].weight
s = proj.scaling[active_adapter]
return W, QUANT_STATE(W), A, B, s, bias
pass
def fast_dequantize(W, quant_state = None, out = None):
if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
# Old quant_state as a list of lists
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
# Create weight matrix
if out is None:
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
else:
assert(out.shape == shape)
assert(out.dtype == dtype)
# NF4 dequantization of statistics
n_elements_absmax = absmax.numel()
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax)
)
out_absmax += offset
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
cdequantize_blockwise_bf16_nf4
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
ctypes.c_int(blocksize), ctypes.c_int(out.numel()))
# Careful returning transposed data
    is_transposed = (W.shape[0] == 1)
return out.t() if is_transposed else out
pass
def fast_gemv(X, W, quant_state, out = None):
if quant_state is None: return torch.matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
# assert(q_len == 1)
if type(quant_state) is not list:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
stats = quant_state.code
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
# assert(dtype == X.dtype)
bout = shape[0]
if out is None:
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
# else:
# assert(out.shape == (1, 1, bout,))
# pass
n = 1
m = shape[0]
k = shape[1]
lda = shape[0]
ldc = shape[0]
ldb = (hd+1)//2
m = ctypes.c_int32(m)
n = ctypes.c_int32(n)
k = ctypes.c_int32(k)
lda = ctypes.c_int32(lda)
ldb = ctypes.c_int32(ldb)
ldc = ctypes.c_int32(ldc)
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
)
df += offset
absmax = df
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
cgemm_4bit_inference_naive_bf16
blocksize = ctypes.c_int32(blocksize)
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
lda, ldb, ldc, blocksize)
return out
pass
def fast_linear_forward(proj, X, temp_lora = None, out = None):
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
bsz, q_len, in_dim = X.shape
if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
if W_quant is None:
out = torch.matmul(X, W.t(), out = out)
elif bsz == 1 and q_len == 1:
out = fast_gemv(X, W, W_quant, out = out)
else:
W = fast_dequantize(W.t(), W_quant)
out = torch.matmul(X, W, out = out)
pass
# Add in LoRA weights
if lora_A is not None:
out_dim = out.shape[2]
dtype = X.dtype
if not hasattr(lora_A, "_fast_lora"):
lora_A._fast_lora = lora_A.to(dtype)
lora_B._fast_lora = lora_B.to(dtype)
pass
if bsz == 1:
out = out.view(out_dim)
temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
else:
out = out.view(bsz, out_dim)
temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
pass
out = out.view(bsz, 1, out_dim)
pass
if bias is not None: out += bias
return out
pass
def matmul_lora(X, W, W_quant, A, B, s, out = None):
dtype = X.dtype
W = fast_dequantize(W.t(), W_quant)
if X.dim() == 3:
batch, seq_len, d = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
pass
out = torch.matmul(X, W, out = out)
if W_quant is not None: del W
if A is not None:
# LoRA is enabled
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
pass
return out.view(batch, seq_len, -1) if reshape else out
pass
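# --- Illustrative sketch (not part of the original file) ---
# For an unquantized weight, matmul_lora(X, W, None, A, B, s) should equal a
# merged dense matmul X @ (W + s * B @ A).T, where A is lora_A.weight (r, in)
# and B is lora_B.weight (out, r). The helper name and shapes are assumptions.
def _reference_matmul_lora_check(dtype = torch.float64):
    torch.manual_seed(0)
    N, in_dim, out_dim, r, s = 5, 8, 12, 4, 2.0
    X = torch.randn(N, in_dim, dtype = dtype)
    W = torch.randn(out_dim, in_dim, dtype = dtype)
    A = torch.randn(r, in_dim, dtype = dtype)
    B = torch.randn(out_dim, r, dtype = dtype)
    merged = X @ (W + s * B @ A).t()
    assert torch.allclose(matmul_lora(X, W, None, A, B, s), merged)
pass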
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .loader import FastLanguageModel
from .llama import FastLlamaModel
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from .dpo import PatchDPOTrainer
from ._utils import is_bfloat16_supported
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "2024.8"
__all__ = [
"prepare_model_for_kbit_training",
"xformers",
"xformers_attention",
"xformers_version",
"__version__",
"HAS_FLASH_ATTENTION",
"PRE_CHECK",
"platform_system",
"patch_tokenizer",
"get_statistics",
"Unsloth_Offloaded_Gradient_Checkpointer",
"offload_to_disk",
"offload_input_embeddings",
"offload_output_embeddings",
"is_bfloat16_supported",
"unsloth_offloaded_gradient_checkpoint",
"torch_compile_options",
"patch_linear_scaling",
"patch_llama_rope_scaling",
"check_nvidia",
"create_boolean_mask",
"torch_amp_custom_fwd",
"torch_amp_custom_bwd",
]
import torch
from typing import Union, Optional, List, Any, Callable, Tuple
from platform import system as platform_system
platform_system = platform_system()
import numpy as np
import warnings, subprocess, re, inspect, psutil, os, math
from packaging.version import Version
# =============================================
# Disable some warnings which can get annoying
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
# Stop "Special tokens have been added in the vocabulary, ..."
import logging
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
# =============================================
# =============================================
# Edits all Config files to enable RoPE Scaling for all models
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_config(config):
if "head_dim (" not in config:
add_head_dim = "If it is not specified, will default to `8`.\n"\
" head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"\
" The attention head dimension."
config = config.replace("If it is not specified, will default to `8`.", add_head_dim)
add_head_dim = "num_key_value_heads=8,\n head_dim=None,"
config = config.replace("num_key_value_heads=8,", add_head_dim)
add_head_dim = "self.sliding_window = sliding_window\n self.head_dim = head_dim or hidden_size // num_attention_heads\n"
config = config.replace("self.sliding_window = sliding_window", add_head_dim)
pass
return config
pass
from transformers import __version__ as transformers_version
from transformers import PretrainedConfig
model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2",]
for model_name in model_architectures:
config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
config_filename = f"{model_name.title()}Config"
exec(f"from {config_filepath} import {config_filename}", globals())
try:
config = inspect.getsource(eval(config_filename))
except:
continue
if "rope_scaling" in config: continue
config = re.sub(
r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
r"rope_scaling=None,"\
r"\n **kwargs):\n"\
r"\n self.rope_scaling = rope_scaling\n",
config,
)
# Just for Mistral Nemo
if model_name == "mistral":
if Version(transformers_version) <= Version("4.42.4"):
config = patch_mistral_nemo_config(config)
pass
exec(config, globals())
exec(f"import {config_filepath}", globals())
exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
pass
# =============================================
# =============================================
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass
# =============================================
# =============================================
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
import bitsandbytes as bnb
from transformers.models.llama.modeling_llama import logger
from transformers import AutoTokenizer
major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = False
if major_version >= 8:
SUPPORTS_BFLOAT16 = True
try:
from flash_attn import flash_attn_func
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
try:
from flash_attn.flash_attn_interface import flash_attn_cuda
HAS_FLASH_ATTENTION = True
except:
logger.warning_once(
"Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
"A possible explanation is you have a new CUDA version which isn't\n"\
"yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
"We shall now use Xformers instead, which gets a 0.01% performance hit.\n"\
"We found this negligible impact by benchmarking on 1x A100."
)
HAS_FLASH_ATTENTION = False
except:
HAS_FLASH_ATTENTION = False
else:
# Tri Dao's benchmark shows xformers is faster for now.
HAS_FLASH_ATTENTION = False
pass
import xformers.ops.fmha as xformers
xformers_attention = xformers.memory_efficient_attention
from xformers import __version__ as xformers_version
# Temporarily disable 0.0.27 and higher - inference issues
if Version(xformers_version) >= Version("0.0.27"):
raise ImportError(
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
"then press Disconnect Runtime and then Restart it.\n"\
"\n"\
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes\n'\
'\n'\
f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"\
'Please downgrade xformers via `pip install --force-reinstall "xformers<0.0.27"'
)
pass
# Check TRL version
from trl import __version__ as trl_version
if Version(xformers_version) >= Version("0.9.0"):
raise ImportError(
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
"then press Disconnect Runtime and then Restart it.\n"\
"\n"\
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes\n'\
'\n'\
f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"\
'Please downgrade TRL via `pip install --force-reinstall "trl<0.9.0"'
)
pass
# =============================================
# =============================================
# Torch compile settings
# Just remove max_autotune_gemm warning
import functools
@functools.lru_cache(None)
def is_big_gpu(index):
sms = torch.cuda.get_device_properties(index).multi_processor_count
if sms < 80: # V100
# log.warning("not enough SMs to use max_autotune_gemm mode")
return False
return True
import torch._inductor.utils
torch._inductor.utils.is_big_gpu = is_big_gpu
# Torch compile arguments
torch_compile_arguments = [
"config.dce = True",
"config.memory_planning = True",
"config.memory_pool = 'combined'",
"config.coordinate_descent_tuning = True",
"config.max_autotune_gemm = False", # GEMM is unnecessary
"config.autotune_multi_device = False",
"config.max_autotune_gemm_backends = 'ATEN'", # Not much faster
"config.aggressive_fusion = False", # Careful changes results!
"config.cuda.enable_cuda_lto = True",
"config.cuda.use_fast_math = True",
"config.cuda.compile_opt_level = '-O2'",
]
# Torch dynamo arguments
torch_dynamo_arguments = [
"config.accumulated_cache_size_limit = 512", # Bump up a bit from 256
"config.suppress_errors = True", # Supress errors for now
"config.do_not_emit_runtime_asserts = True",
]
import torch._inductor.config as config
for _try_compile_argument in torch_compile_arguments:
try: exec(_try_compile_argument)
except: pass
pass
import torch._dynamo.config as config
for _try_dynamo_argument in torch_dynamo_arguments:
try: exec(_try_dynamo_argument)
except: pass
pass
torch_compile_options = {
"epilogue_fusion" : True,
"max_autotune" : True,
"shape_padding" : True,
"trace.enabled" : False, # Output Triton kernel outputs!
"triton.cudagraphs" : False,
}
# =============================================
def prepare_model_for_kbit_training(
model : Any,
use_gradient_checkpointing : Optional = True,
use_reentrant : Optional[bool] = True,
) -> Any:
"""
Calculates where to place the gradient checkpoints given n_layers.
    We also freeze the gradients of all non-LoRA parameters.
Args:
model: Any LlamaModel with layers.
use_gradient_checkpointing (`bool`, *optional*):
            Enabled by default. Saves memory by storing only some activations
            and recomputing the rest during the backward pass.
use_reentrant (`bool`, *optional*):
https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
Optimal gradient checkpointing algorithm which will be the default in
future Pytorch versions.
"""
# Freeze all parameters except LoRA
import re
with torch.no_grad():
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
param.requires_grad_(True)
# Also must be in float32!
if param.dtype != torch.float32:
name = name.replace("base_model", "model", 1)
layer_number = re.search(r"\.[\d]{1,}\.", name).group(0)
name = name.replace(layer_number, f"[{layer_number[1:-1]}].")
name = name.replace(".weight", "", 1)
exec(f"{name}.to(torch.float32)")
pass
else:
param.requires_grad_(False)
pass
pass
# Gradient checkpointing!
if use_gradient_checkpointing == "unsloth":
# Saves VRAM!
original_model = model
while hasattr(original_model, "model"):
original_model._offloaded_gradient_checkpointing = True
original_model = original_model.model
pass
original_model._offloaded_gradient_checkpointing = True
model.gradient_checkpointing_enable()
elif use_gradient_checkpointing == True:
model.gradient_checkpointing_enable()
pass
    # If use_reentrant = True (the Pytorch default), we just make the inputs require gradients.
if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model
pass
def patch_tokenizer(model, tokenizer):
"""
Phi3's pad_token isn't set. We set it to <|placeholder...
Llama-3 is <|reserved...
Llama-2 is <unk>
Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
Fixes https://github.com/unslothai/unsloth/issues/5
"""
possible_reserved_tokens = (
"<|reserved", # Llama-3
"<|placeholder", # Phi-3
"[control", # Mistral type models
"<pad>", # Mistral Nemo
"<|finetune_right_pad_id|>", # Llama-3.1
)
if model is not None:
model.config.update({"unsloth_version" : __version__})
bad_pad_token = False
if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None:
# Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
bad_pad_token = tokenizer.eos_token == tokenizer.pad_token
elif hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
bad_pad_token = True
else:
bad_pad_token = False
pass
if bad_pad_token:
# Find a better pad token
added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
possible_pad_token = None
n_possible_pad_tokens = 0
for added_token in added_tokens[::-1]:
if added_token.startswith(possible_reserved_tokens):
if possible_pad_token is None: possible_pad_token = added_token
n_possible_pad_tokens += 1
# We must see at least 3 of the reserved tokens
if n_possible_pad_tokens >= 3: break
pass
pass
if n_possible_pad_tokens < 3: possible_pad_token = None
if possible_pad_token is None:
# Try unk_token
possible_pad_token = tokenizer.unk_token
pass
if possible_pad_token is None:
# Failure to find a good replacement!! We shall manually add one!
new_pad_token = "<|PAD_TOKEN|>"
while new_pad_token in tokenizer.get_vocab():
new_pad_token += "#"
pass
possible_pad_token = new_pad_token
pass
name = model.config._name_or_path if model is not None else "Model"
logger.warning_once(
f"{name} does not have a padding token! Will use pad_token = {possible_pad_token}."
)
# Edit pad_token
tokenizer.add_special_tokens({"pad_token" : possible_pad_token})
tokenizer.pad_token = possible_pad_token
if model is not None:
config = model.config.update({"pad_token_id" : tokenizer.pad_token_id})
else:
if model is not None:
if model.config.pad_token_id is None:
config = model.config.update({"pad_token_id" : tokenizer.pad_token_id})
return model, tokenizer
pass
# =============================================
# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
# For mixed precision, we need it to be in float32 not float16.
from peft import __version__ as peft_version
if Version(peft_version) < Version("0.12.0"):
from peft.tuners.lora.layer import LoraLayer
import inspect, re
try:
source = inspect.getsource(LoraLayer.update_layer)
text = "if weight is not None:\n"
start = source.find(text) + len(text)
end = source.find("self.to(weight.device)", start)
spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
source = source.replace(source[start : end], spaces)
spaces = len(re.match(r"[\s]{1,}", source).group(0))
lines = source.split("\n")
source = "\n".join(x[spaces:] for x in lines)
source = re.sub("([^\.])nn\.", r"\1torch.nn.", source)
source = source.replace("def update_layer", "def LoraLayer_update_layer")
exec(source, globals())
# Fix up incorrect downcasting of LoRA weights
from peft.tuners.lora.layer import LoraLayer
LoraLayer.update_layer = LoraLayer_update_layer
from peft.tuners.lora import LoraLayer
LoraLayer.update_layer = LoraLayer_update_layer
except:
logger.warning_once(
"Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"\
"Luckily, your training run will still work in the meantime!"
)
pass
pass
# =============================================
import psutil
def _get_statistics(statistics = None, force_download = True):
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
try:
n_cpus = psutil.cpu_count(logical = False)
keynames = "\n" + "\n".join(os.environ.keys())
if statistics is not None: pass
elif "\nCOLAB_" in keynames and n_cpus == 1: statistics = "colab"
elif "\nCOLAB_" in keynames: statistics = "colabpro"
elif "\nKAGGLE_" in keynames: statistics = "kaggle"
elif "\nRUNPOD_" in keynames: statistics = "runpod"
elif "\nAWS_" in keynames: statistics = "aws"
elif "\nAZURE_" in keynames: statistics = "azure"
elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
elif "\nINVOCATION_ID" in keynames: statistics = "lambda"
else: statistics = "other"
if statistics is not None:
from transformers import AutoModelForCausalLM
stats_model = AutoModelForCausalLM.from_pretrained(
f"unslothai/{statistics}",
force_download = force_download,
)
del stats_model
pass
except:
pass
pass
def get_statistics():
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
disabled = False
if not are_progress_bars_disabled():
disable_progress_bars()
disabled = True
pass
_get_statistics(None)
_get_statistics("repeat", force_download = False)
try:
vram = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024
if vram <= 8 : vram = 8
elif vram <= 16: vram = 16
elif vram <= 20: vram = 20
elif vram <= 24: vram = 24
elif vram <= 40: vram = 40
elif vram <= 48: vram = 48
elif vram <= 80: vram = 80
else: vram = 96
_get_statistics(f"vram-{vram}")
except:
pass
pass
try:
devices = torch.cuda.device_count()
_get_statistics(f"{devices if devices <= 8 else 9}")
except:
pass
if disabled: enable_progress_bars()
pass
def _calculate_n_gradient_checkpoints(
n_layers : int,
method : Optional[Union[str, int]] = "sqrt",
) -> List[int]:
assert(type(n_layers) is int and n_layers > 0)
if method is None: method = "sqrt"
if method == "sqrt":
n_checkpoints = int(n_layers**0.5)
elif type(method) is int and method > 0:
n_checkpoints = int(np.ceil(n_layers / method))
else:
raise ValueError("method must be 'sqrt' or an int >0 and <= n_layers.")
size = n_layers // n_checkpoints
sizes = np.full(n_checkpoints, size, dtype = int)
leftovers = n_layers % n_checkpoints
# We append leftovers from the right
for k in range(leftovers):
sizes[n_checkpoints-1-k] += 1
boundaries = np.hstack((0, np.cumsum(sizes)))
boundaries = boundaries.tolist()
return boundaries
pass
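# Worked example (illustrative): for n_layers = 32 with the default "sqrt" method,
# n_checkpoints = int(32 ** 0.5) = 5, the base group size is 32 // 5 = 6, and the
# 2 leftover layers are appended from the right, giving group sizes
# [6, 6, 6, 7, 7] and boundaries [0, 6, 12, 18, 25, 32].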
def calculate_n_gradient_checkpoints(
n_layers : int,
layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
) -> List[int]:
assert(type(n_layers) is int and n_layers > 0)
if layers_per_checkpoint is None or layers_per_checkpoint == 1:
return None
boundaries = _calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
assert(boundaries[0] == 0 and boundaries[-1] == n_layers)
assert(min(boundaries) == 0 and max(boundaries) == n_layers)
assert(np.diff(boundaries).min() >= 0)
return boundaries
pass
def prepare_n_gradient_checkpoints(
model : Any,
layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
use_reentrant : Optional[bool] = True,
) -> None:
"""
Calculates where to place the gradient checkpoints given n_layers.
Args:
model: Any LlamaModel with layers.
layers_per_checkpoint (`Union[str, int]`, *optional*):
Can either be `sqrt` or an integer for how many layers per checkpoint you want.
The more, the less memory usage, but can be slower. Default is `sqrt`.
Choose 1 for Pytorch gradient checkpointing. 2 to wrap 2 layers in 1 module etc.
use_reentrant (`bool`, *optional*):
https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
            The non-reentrant algorithm (`use_reentrant=False`) will become the
            default in future Pytorch versions, but it does not seem to work here,
            so we force `use_reentrant=True`.
"""
_model = None
if hasattr(model, "layers"):
_model = model
elif hasattr(model, "model"):
if hasattr(model.model, "layers"):
_model = model.model
if _model is None:
raise TypeError("`model` or `model.model` does not have attribute `layers`. Are you sure this is a model?")
pass
if use_reentrant is False:
use_reentrant = True
pass
n_layers = len(_model.layers)
boundaries = calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
_model._gradient_checkpointing_boundaries = boundaries
_model._gradient_checkpointing_use_reentrant = use_reentrant
pass
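# Minimal usage sketch (commented out; assumes `model` wraps a decoder with
# `model.model.layers`, as produced by the model loaders elsewhere in this repo):
# prepare_n_gradient_checkpoints(model, layers_per_checkpoint = "sqrt", use_reentrant = True)
# print(model.model._gradient_checkpointing_boundaries)   # e.g. [0, 6, 12, 18, 25, 32]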
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
"""
Saves VRAM by smartly offloading to RAM.
Tiny hit to performance, since we hide the data movement via non-blocking calls.
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda:0", non_blocking = True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (None, hidden_states.grad,) + (None,)*len(ctx.args)
pass
pass
@torch._disable_dynamo
def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs):
return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args)
pass
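# Rough usage sketch (commented out; `decoder_layer`, `causal_mask` etc. are placeholders
# for whatever the surrounding training loop provides). The wrapped function's output is
# returned unchanged, so a decoder layer returning `(hidden_states,)` is indexed with [0]:
# layer_outputs = Unsloth_Offloaded_Gradient_Checkpointer.apply(
#     decoder_layer,       # forward_function - called as decoder_layer(hidden_states, *args)
#     hidden_states,       # copied to CPU with non_blocking = True for the backward pass
#     causal_mask,
#     attention_mask,
#     position_ids,
# )
# hidden_states = layer_outputs[0]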
# =============================================
# Fixes Bitsandbytes to remove missing warnings
from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
from inspect import getsource
from accelerate.utils.dataclasses import DistributedType
import re
BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = re.sub(
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
"",
BitsAndBytesConfig__init__,
flags = re.MULTILINE,
)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
"__init__",
"_BitsAndBytesConfig__init__",
)
def _prepare_backend(
self, cpu: bool = False, sagemaker_dp = False, backend: str = None,
) -> tuple[str, DistributedType]:
return None, DistributedType.NO
pass
import accelerate.state
accelerate.state.PartialState._prepare_backend = _prepare_backend
import accelerate.accelerator
prepare = inspect.getsource(accelerate.accelerator.Accelerator.prepare)
prepare = prepare.split("\n")
spaces = prepare[0].find("def")
prepare = "\n".join(x[spaces:] for x in prepare)
x = "for obj in args:"
s = " "*spaces
prepare = prepare.replace(x, f'self.state.distributed_type = DistributedType.NO\n{s}{x}', 1)
exec(prepare, globals())
accelerate.accelerator.Accelerator.prepare = prepare
exec(BitsAndBytesConfig__init__, globals())
import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
# =============================================
# Offloading to disk for modules (lm_head, embed_tokens)
import pickle
def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_temporary_saved_buffers"):
file_location = os.path.join(temporary_location, model.config._name_or_path)
if not os.path.exists(file_location):
os.makedirs(file_location)
pass
filename = os.path.join(file_location, f"{name}.pt")
W = W.weight if hasattr(W, "weight") else W
torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
offloaded_W = torch.load(filename, map_location = "cpu", mmap = True)
offloaded_W._offloaded_file_location = filename
return offloaded_W
pass
def offload_input_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
offloaded_W = offload_to_disk(model.get_input_embeddings(), model, "input_embeddings", temporary_location)
new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W)
new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
model.set_input_embeddings(new_input_embeddings)
return
pass
def offload_output_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
offloaded_W = offload_to_disk(model.get_output_embeddings(), model, "output_embeddings", temporary_location)
new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
del new_output_embeddings.weight
new_output_embeddings.weight = offloaded_W
new_output_embeddings.in_features = offloaded_W.shape[1]
new_output_embeddings.out_features = offloaded_W.shape[0]
new_output_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
model.set_output_embeddings(new_output_embeddings)
return
pass
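# Commented-out usage sketch (assuming `model` is a causal LM whose config has
# `_name_or_path` set): both calls save the weight to disk under
# _unsloth_temporary_saved_buffers/<model name>/ and replace the module's weight with a
# memory-mapped CPU tensor, so the original in-memory copy can be garbage collected:
# offload_input_embeddings (model)
# offload_output_embeddings(model)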
# Fixes a weird Torch 2.3 bug which says T4s have bfloat16
def is_bfloat16_supported():
return SUPPORTS_BFLOAT16
pass
# Patches models to add RoPE Scaling
def patch_linear_scaling(
model_name = "gemma2",
rope_module = None,
scaled_rope_module = None,
attention_module = None,
):
assert(rope_module is not None and scaled_rope_module is not None)
assert(attention_module is not None)
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
exec_code = \
f"import torch.nn as nn\n"\
f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
f"from {model_filepath} import logger, "\
f"{model_name.title()}Attention, {model_name.title()}Config"
try:
function = inspect.getsource(attention_module.__init__)
except:
# Most likely already patched!
return None, None
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
init_name = f"{model_name.title()}Attention__init__"
function = function.replace("def __init__", f"def {init_name}")
function = function.replace(
"super().__init__()",
f"super({model_name.title()}Attention, self).__init__()",
)
fix_rope_function = """
if getattr(self.config, "rope_scaling", None) is None:
self.rotary_emb = {rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = {scaled_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
)
rotary_emb = re.findall(
"self.rotary_emb = .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0: return None, function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
pass
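# Sketch of how the returned (init_name, function) pair is meant to be consumed,
# mirroring FastGemmaModel.pre_patch further below (commented out here):
# init_name, function = patch_linear_scaling(
#     model_name         = "gemma",
#     rope_module        = GemmaFixedRotaryEmbedding,
#     scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
#     attention_module   = GemmaAttention,
# )
# if init_name is not None:
#     exec(function, globals())                 # defines e.g. GemmaAttention__init__
#     GemmaAttention.__init__ = eval(init_name) # swap in the patched constructor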
# Patches for Llama-3 LlamaExtendedRotaryEmbedding
def patch_llama_rope_scaling(
model_name = "llama",
rope_module = None,
scaled_rope_module = None,
extended_rope_module = None,
attention_module = None,
):
assert(\
rope_module is not None and \
scaled_rope_module is not None and \
extended_rope_module is not None
)
assert(attention_module is not None)
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
exec_code = \
f"import torch.nn as nn\n"\
f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
f"from {model_filepath} import logger, "\
f"{model_name.title()}Attention, {model_name.title()}Config"
try:
function = inspect.getsource(attention_module.__init__)
except:
# Most likely already patched!
return None, None
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
init_name = f"{model_name.title()}Attention__init__"
function = function.replace("def __init__", f"def {init_name}")
function = function.replace(
"super().__init__()",
f"super({model_name.title()}Attention, self).__init__()",
)
fix_rope_function = """
if getattr(self.config, "rope_scaling", None) is None:
self.rotary_emb = {rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type1 = self.config.rope_scaling.get("type", None)
scaling_type2 = self.config.rope_scaling.get("rope_type", None)
scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
scaling_factor = self.config.rope_scaling.get("factor")
if scaling_type == "linear":
self.rotary_emb = {scaled_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "llama3":
self.rotary_emb = {extended_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
extended_rope_function = extended_rope_module.__name__,
)
rotary_emb = re.findall(
"self.rotary_emb = .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0: return None, function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
pass
def check_nvidia():
# Unsloth doesn't work yet on AMD devices - we're working on it!
output = np.array([0,])
try:
output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True)
output = re.findall(rb'([\d]{1,})[\s]{1,}M', output)
output = np.array([int(x.decode('utf-8'))/1024 for x in output])
except:
if not torch.cuda.is_available():
raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!")
return output
pass
PRE_CHECK = check_nvidia()
def create_boolean_mask(n = 4096, sliding_window = 2048):
# Creates a boolean mask for attention
mask = torch.ones(n, n, dtype = torch.bool)
if sliding_window == 0:
return torch.triu(mask, diagonal = 1, out = mask)
pass
torch.triu(mask, diagonal = 0, out = mask)
torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
mask = mask.T
torch.logical_not(mask, out = mask)
return mask
pass
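# Hedged illustration of the expected output (True marks positions that are masked out).
# With sliding_window = 0 the result is a plain causal mask, i.e. torch.triu(ones, diagonal = 1):
# create_boolean_mask(n = 4, sliding_window = 0)
#   [[False,  True,  True,  True],
#    [False, False,  True,  True],
#    [False, False, False,  True],
#    [False, False, False, False]]
# With a sliding window, positions too far into the past should additionally be masked,
# matching AttentionMaskConverter as verified by test_mask_creation below - for example
# create_boolean_mask(n = 4, sliding_window = 2) is expected to also mask position [3, 0].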
def test_mask_creation():
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
for n in range(2, 23):
for s in range(1, 23):
correct_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = s,
).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
correct_mask = (correct_mask == correct_mask.min())
our_mask = create_boolean_mask(n = n, sliding_window = s)
assert(torch.all(correct_mask == our_mask))
pass
correct_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = None,
).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
correct_mask = (correct_mask == correct_mask.min())
our_mask = create_boolean_mask(n = n, sliding_window = 0)
assert(torch.all(correct_mask == our_mask))
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
from transformers.utils.notebook import (
IntervalStrategy,
NotebookTrainingTracker,
NotebookProgressCallback,
)
HAS_NOTEBOOK = True
except:
HAS_NOTEBOOK = False
pass
DPOTrainer_metrics = [
"rewards/chosen",
"rewards/rejected",
"rewards/accuracies",
"rewards/margins",
"logps/rejected",
"logps/chosen",
"logits/rejected",
"logits/chosen",
]
set_DPOTrainer_metrics = frozenset(DPOTrainer_metrics)
def NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs):
self.first_column = "Epoch" if args.evaluation_strategy == IntervalStrategy.EPOCH else "Step"
self.training_loss = 0
self.last_log = 0
column_names = [self.first_column] + ["Training Loss"]
if args.evaluation_strategy != IntervalStrategy.NO:
column_names.append("Validation Loss")
column_names += [x.replace("/", " / ") for x in DPOTrainer_metrics]
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
pass
def NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs):
# Only for when there is no evaluation
if args.evaluation_strategy == IntervalStrategy.NO and "loss" in logs:
values = {"Training Loss": logs["loss"]}
for metric in DPOTrainer_metrics:
values[metric.replace("/", " / ")] = logs[metric]
pass
# First column is necessarily Step since we're not in epoch eval strategy
values["Step"] = state.global_step
self.training_tracker.write_line(values)
pass
pass
def NotebookTrainingTracker_write_line(self, values):
"""
Write the values in the inner table.
Args:
values (`Dict[str, float]`): The values to display.
"""
if self.inner_table is None:
self.inner_table = [list(values.keys()), list(values.values())]
else:
columns = self.inner_table[0]
new_values = {}
for key, value in values.items():
lowered = key.lower()
if lowered in set_DPOTrainer_metrics:
new_values[lowered.replace("/", " / ")] = value
else:
new_values[key] = value
pass
values = new_values
self.inner_table[0] = columns
if len(self.inner_table) > 1:
last_values = self.inner_table[-1]
first_column = self.inner_table[0][0]
if last_values[0] != values[first_column]:
# write new line
self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
else:
# update last line
new_values = values
for c in columns:
if c not in new_values.keys():
new_values[c] = last_values[columns.index(c)]
self.inner_table[-1] = [new_values[c] for c in columns]
else:
# Edit for evaluation purposes
self.inner_table.append([values[c] if c in values else 0 for c in columns])
pass
pass
pass
def PatchDPOTrainer():
if HAS_NOTEBOOK:
from transformers.trainer import is_in_notebook
if is_in_notebook():
# Patch DPO notebook printing
NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line
from transformers.trainer import DEFAULT_PROGRESS_CALLBACK
DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin
DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log
pass
pass
pass
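# Commented-out usage sketch: call this once before constructing a TRL DPOTrainer inside a
# notebook, so the DPO reward / logp / logit metrics listed above are rendered in the
# notebook progress table instead of being dropped:
# PatchDPOTrainer()
# trainer = DPOTrainer(...)   # placeholder for the usual TRL DPOTrainer setup
# trainer.train()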
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from ._utils import __version__
try:
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaDecoderLayer,
GemmaModel,
GemmaForCausalLM,
GemmaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
except:
from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.38"):
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
f"The minimum required version is 4.38.\n"\
f'Try `pip install --upgrade "transformers>=4.38"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
pass
pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
from transformers.models.gemma.modeling_gemma import (
GemmaSdpaAttention,
GemmaFlashAttention2,
)
except:
GemmaSdpaAttention = GemmaAttention
GemmaFlashAttention2 = GemmaAttention
pass
torch_nn_functional_gelu = torch.nn.functional.gelu
def fast_geglu_inference(self, X):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
bsz, _, hd = X.shape
# mlp_size = self.config.intermediate_size
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
up = fast_linear_forward(self. up_proj, X)#, out = temp[1])
gate = torch_nn_functional_gelu(gate, approximate = "tanh")
gate *= up
# X = self.down_proj(gate)
down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
return down
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def GemmaDecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
):
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
from math import sqrt as math_sqrt
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def GemmaModel_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
)
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
# Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
# Formulates cos and sin differently from Llama!
class GemmaFixedRotaryEmbedding(torch.nn.Module):
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None: return # [TODO] Hack to pass in config - need to remove later
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
# The difference is we do the division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = int(round(seq_len / 8192)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
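# In the notation of _set_cos_sin_cache above, for head dimension d the cached angles are
#     timescale_i   = base ** (2 * i / d),  i = 0 .. d/2 - 1
#     radians[p, i] = p / timescale_i
# and cos_cached / sin_cached hold cos/sin of radians concatenated with itself along the
# last axis (shape [current_rope_size, d]), computed in float32 on CPU then moved to the GPU.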
class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
# The difference is we do the division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
positions = positions / self.scaling_factor
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
pass
pass
class FastGemmaModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "gemma",
rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
attention_module = GemmaAttention,
)
if init_name is not None:
exec(function, globals())
GemmaAttention.__init__ = eval(init_name)
pass
GemmaAttention .forward = LlamaAttention_fast_forward
GemmaSdpaAttention .forward = LlamaAttention_fast_forward
GemmaFlashAttention2.forward = LlamaAttention_fast_forward
GemmaDecoderLayer .forward = GemmaDecoderLayer_fast_forward
GemmaModel .forward = LlamaModel_fast_forward
GemmaForCausalLM .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(GemmaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA-graphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma.modeling_gemma
transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
return
pass
@staticmethod
def post_patch(model):
# Patch model for Gemma
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.2
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Gemma has tied weights! This means lm_head == embed_tokens
if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
pass
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
# RoPE must be done in float32 for Gemma
# if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
# and (module.cos_cached.dtype != correct_dtype):
# module.cos_cached = module.cos_cached.to(correct_dtype)
# module.sin_cached = module.sin_cached.to(correct_dtype)
# pass
# pass
pass
# Add 1 to weight
# return output * (1 + self.weight)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
# Freeze all parameters except LoRA
# We do this first since the in-place += 1 does not seem to play well with requires_grad = True
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
pass
# Patch RMS Layernorm
for name, module in model.named_modules():
if isinstance(module, GemmaRMSNorm):
# Must be in float32
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
# module = module.to(torch.float32)
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
pass
# Clear deleted GPU items
import gc
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from ._utils import __version__
from .gemma import (
GemmaFixedRotaryEmbedding,
GemmaFixedLinearScalingRotaryEmbedding,
fast_geglu_inference,
)
try:
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2DecoderLayer,
Gemma2Model,
Gemma2ForCausalLM,
Gemma2RotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
except:
from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.42"):
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
f"The minimum required version is 4.42.3.\n"\
f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
pass
pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2SdpaAttention,
Gemma2FlashAttention2,
)
except:
Gemma2SdpaAttention = Gemma2Attention
Gemma2FlashAttention2 = Gemma2Attention
pass
# [TODO] For some reason we must use torch.compile here?
# I checked the gradients and formulas and I'm sure it's correct.
# I'm stumped :(
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
old_dtype = X.dtype
X = X.float()
X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + layernorm.eps) * \
(1.0 + layernorm.weight.float())
return X.to(old_dtype)
pass
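# The compiled kernel above computes Gemma-style RMSNorm in float32,
#     y = x * rsqrt(mean(x^2, dim = -1) + eps) * (1 + weight),
# before casting back to the input dtype; note the Gemma-specific "+ 1" on the weight,
# unlike standard Llama RMSNorm which multiplies by the raw weight.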
# Logit softcapping
def Gemma2Attention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len)
A = self.apply_o(self, A)
return A, None, past_key_value
pass
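# Note on Gemma2 logit softcapping, applied via slow_attention_softcapping in the forward
# above and via the in-place sequence `A *= 1/t; tanh(A); A *= t` in
# Gemma2Attention_fast_forward_inference below:
#     softcap(x) = t * tanh(x / t),   t = config.attn_logit_softcapping
# which smoothly clamps attention scores to the range (-t, t) before the softmax.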
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def Gemma2DecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
):
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = self.mlp(hidden_states)
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
def Gemma2Attention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
use_sliding_window = False,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Only for Gemma2
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
# Gemma 9b should use 256 and not 224 (hidden_size / num_attention_heads). 27b uses the value below.
# We default to using the config file itself
# s = self.config.hidden_size // self.config.num_attention_heads
self.scalar = 1.0 / math_sqrt(self.config.query_pre_attn_scalar)
# self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
self.half_head_dim = head_dim // 2
self. t = self.config.attn_logit_softcapping
self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
elif kv_seq_len >= self.paged_attention.shape[0]:
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
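# The block below applies rotary embeddings in-place via the rotate-half trick:
# RH_Q is filled with rotate_half(Q) = [-Q2, Q1] (second half negated into the first half),
# so Qn = Qn * cos + rotate_half(Qn) * sin, and likewise for Kn via the RH_K view.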
h = self.half_head_dim
RH_Q = self.RH_Q
RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:,:,:,:h] = Kn[:,:,:,h:]
RH_K[:,:,:,h:] = Kn[:,:,:,:h]
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
# Handle sliding windows
sliding_window = self.config.sliding_window
if use_sliding_window and kv_seq_len > sliding_window:
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
slicing_tokens = 1 - sliding_window
Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
pass
# else:
# Knn, Vnn = Knn, Vnn
# pass
# Attention
# if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A *= self.reciprocal_t; torch.tanh(A, out = A); A *= self.t; # Logit softcapping
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch.matmul(A, Vnn, out = Qn)
# else:
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
# pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def Gemma2Model_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
SWA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = self.config.sliding_window,
)
GA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
)
else:
SWA = attention_mask
GA = attention_mask
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
use_sliding_window = idx % 2 == 0
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = SWA if use_sliding_window else GA,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
use_sliding_window = use_sliding_window,
)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
class FastGemma2Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "gemma2",
rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
attention_module = Gemma2Attention,
)
if init_name is not None:
exec(function, globals())
Gemma2Attention.__init__ = eval(init_name)
pass
Gemma2Attention .forward = Gemma2Attention_fast_forward
Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward
Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward
Gemma2Model .forward = LlamaModel_fast_forward
Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(Gemma2ForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA-graphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma2.modeling_gemma2
transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding
return
pass
@staticmethod
def post_patch(model):
# Patch model for Gemma
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.2
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Gemma has tied weights! This means lm_head == embed_tokens
if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
pass
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
# RoPE must be done in float32 for Gemma
# if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
# and (module.cos_cached.dtype != correct_dtype):
# module.cos_cached = module.cos_cached.to(correct_dtype)
# module.sin_cached = module.sin_cached.to(correct_dtype)
# pass
# pass
pass
# Add 1 to weight
# return output * (1 + self.weight)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
# Freeze all parameters except LoRA
# We do this first since the in-place += 1 does not seem to play well with requires_grad = True
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
pass
# Patch RMS Layernorm
for name, module in model.named_modules():
if isinstance(module, Gemma2RMSNorm):
# Must be in float32
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
# module = module.to(torch.float32)
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
pass
# Clear deleted GPU items
import gc
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import gc
from typing import Optional, Tuple, List, Union
from ._utils import *
from ._utils import __version__
from torch.nn.functional import scaled_dot_product_attention
from transformers.models.llama.modeling_llama import (
logger,
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ..kernels import *
from ..tokenizer_utils import *
if HAS_FLASH_ATTENTION:
from flash_attn import flash_attn_func
# Final patching code
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaModel,
LlamaForCausalLM,
)
# For Pytorch 2.1.1
try:
from transformers.models.llama.modeling_llama import (
LlamaSdpaAttention,
LlamaFlashAttention2,
)
except:
LlamaSdpaAttention = LlamaAttention
LlamaFlashAttention2 = LlamaAttention
pass
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
from transformers import set_seed as transformers_set_seed
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
from peft import PeftModelForCausalLM
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
from ..save import patch_saving_functions
import re, os, inspect, math, sys
def original_apply_qkv(self, X):
Q = self.q_proj(X)
K = self.k_proj(X)
V = self.v_proj(X)
return Q, K, V
pass
def original_apply_o(self, X):
O = self.o_proj(X)
return O
pass
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
# Fix new HF's inference code
def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,):
if "past_key_values" in kwargs:
input_ids = input_ids[:,[-1]]
kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]]
if "cache_position" in kwargs:
kwargs["position_ids"] = kwargs["cache_position"]
return { "input_ids" : input_ids, **kwargs, }
pass
def fix_prepare_inputs_for_generation(module):
# Fix prepare_inputs_for_generation
if hasattr(module, "prepare_inputs_for_generation"):
module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
pass
pass
torch_matmul = torch.matmul
def LlamaAttention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
):
"""
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
Fast inference using KV cache.
QK^T can be computed in 4 chunks
[Q, q] @ [K, k].T where q, k are the new tokens.
[QK^T, Qk^T]
[qK^T, qk^T]
Since the attention mask wipes Qk^T, we just get
[QK^T, 0]
[qK^T, qk^T]
Since softmax is row-wise, we get
softmax([QK^T, 0])
softmax([qK^T, qk^T])
We then multiply by [V]
[v]
softmax([QK^T, 0]) [softmax(QK^T)V] *
softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
But notice * [softmax(QK^T)V] is just the last attention.
We just need to compute the last final row.
This means we can pass in a row of Q, but we need to
remember K and V, which are called the KV cache.
"""
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Mistral Nemo 12b has weird dimensions
if attention_size != self.hidden_size:
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
pass
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
elif kv_seq_len >= self.paged_attention.shape[0]:
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
h = self.half_head_dim
RH_Q = self.RH_Q
RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:,:,:,:h] = Kn[:,:,:,h:]
RH_K[:,:,:,h:] = Kn[:,:,:,:h]
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
# Handle sliding windows
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is not None and kv_seq_len > sliding_window:
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
slicing_tokens = 1 - sliding_window
Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
pass
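# Knn / Vnn now have shape (bsz, n_heads, cached_len, head_dim): each KV head is repeated
# n_groups times so it lines up with its group of query heads.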
# else:
# Knn, Vnn = Knn, Vnn
# pass
# Attention
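# bsz == 1 is the common generate() path: compute (Q * scalar) @ K^T into the pre-allocated
# attention buffer, softmax in float32, then A @ V reusing Qn as the output buffer.
# Batched decoding falls back to SDPA so attention_mask is honored.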
if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
else:
A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
pass
torch_nn_functional_silu = torch.nn.functional.silu
def fast_swiglu_inference(self, X):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
bsz, _, hd = X.shape
# mlp_size = self.config.intermediate_size
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
up = fast_linear_forward(self. up_proj, X)#, out = temp[1])
gate = torch_nn_functional_silu(gate, inplace = True)
gate *= up
# X = self.down_proj(gate)
down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
return down
pass
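# Inference RMSNorm: y = weight * X / sqrt(mean(X^2) + eps). The reduction is done in
# float32, and the result is cast back to the input dtype before the weight multiply so
# the residual stream keeps its original dtype.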
def fast_rms_layernorm_inference(self, X):
old_dtype = X.dtype
XX = X.to(torch.float32)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
X = XX.to(old_dtype) # Must preserve due to residual
X *= self.weight
return X
pass
def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
XX = X.to(torch.float32)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
if out_weight is None:
out_weight = self.weight + 1.0
else:
out_weight[:] = self.weight
out_weight += 1.0
pass
XX *= out_weight
return XX.to(X.dtype)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
def LlamaAttention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# Extend RoPE dynamically to fit in VRAM
self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
# Xformers memory efficient attention
# Also has Flash Attention v2 dispatching
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Group query attention
if n_groups != 1:
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
if hidden_states.requires_grad:
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
else:
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
pass
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
A = flash_attn_func(Q, K, V, causal = True)
else:
# Grouped query attention
if n_groups != 1:
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
pass
# Must be contiguous or else the results are wrong!
# https://github.com/pytorch/pytorch/issues/112577
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass
attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def LlamaDecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if use_cache and hasattr(self, "_flag_for_generation"):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
def LlamaModel_fast_forward(
self,
input_ids: torch.LongTensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
assert(output_attentions is False)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
# Fix out of bounds tokenization
if hasattr(self, "max_seq_length"):
if seq_length > self.max_seq_length:
logger.warning_once(
f"Unsloth: Input IDs of length {seq_length} > the model's max sequence length of {self.max_seq_length}.\n"\
"We shall truncate it ourselves. It's imperative if you correct this issue first."
)
if input_ids is not None:
input_ids = input_ids[:,:self.max_seq_length]
elif inputs_embeds is not None:
inputs_embeds = inputs_embeds[:,:self.max_seq_length,:]
pass
pass
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
pass
# We already handle KV cache position_ids ourselves.
if False:#(past_key_values_length != 0):
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length,
dtype = torch.int32,
device = "cuda:0",
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
elif position_ids is not None:
position_ids = position_ids.view(-1, seq_length).to(torch.int32)#.long()
else:
position_ids = None
pass
if position_ids is not None:
if position_ids.shape[0] != batch_size:
position_ids = position_ids.repeat((batch_size, 1))
pass
# Embed positions
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
# Normalized from Gemma
IS_GEMMA = self.config.model_type.startswith("gemma")
IS_GEMMA2 = self.config.model_type.startswith("gemma2")
train_embed_tokens = self.embed_tokens.weight.requires_grad
if IS_GEMMA:
# Match Gemma exactly by casting to bfloat16 / float16
# inputs_embeds *= math_sqrt(self.config.hidden_size)
# i.e. 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# & 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
normalizer = torch.tensor(math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype)
if train_embed_tokens:
# Careful we must not do an inplace op!
inputs_embeds = inputs_embeds * normalizer
else:
inputs_requires_grad = inputs_embeds.requires_grad
if not inputs_embeds.is_leaf:
inputs_embeds = inputs_embeds.detach()
inputs_requires_grad = True
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
pass
inputs_embeds *= normalizer
# inputs_embeds *= math_sqrt(self.config.hidden_size)
if inputs_requires_grad: inputs_embeds.requires_grad_(True)
pass
pass
# Fix up attention mask by setting elements to 0
# Specifically for DPO
if self._has_no_labels and (attention_mask is not None) and (past_key_values is None) and \
(not train_embed_tokens):
# Careful for inference the attention_mask is size (1, kv_seq_len)
# Whilst the input_embeds is size (1, 1, 4096)
inputs_requires_grad = inputs_embeds.requires_grad
if not inputs_embeds.is_leaf:
inputs_embeds = inputs_embeds.detach()
inputs_requires_grad = True
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
pass
inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
if inputs_requires_grad: inputs_embeds.requires_grad_(True)
pass
# Ignore attention_mask
if attention_mask is None:
padding_mask = None
elif self.training:
attention_mask = None
padding_mask = None
else:
# if 0 in attention_mask:
# padding_mask = attention_mask
# else:
padding_mask = None
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window = getattr(self.config, "sliding_window", None),
)
pass
hidden_states = inputs_embeds
if past_key_values is None and self.training:
use_cache = False
# if use_cache:
# logger.warning_once(
# "Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
# )
# use_cache = False
pass
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# Gradient checkpointing methods (ie sqrt)
if hasattr(self, "_gradient_checkpointing_boundaries"):
boundaries = self._gradient_checkpointing_boundaries
else:
boundaries = None
pass
# Check checkpointing method
gradient_checkpointing = False
offloaded_gradient_checkpointing = False
if (self.gradient_checkpointing and self.training and not use_cache):
gradient_checkpointing = True
if output_attentions is False and hasattr(self, "_offloaded_gradient_checkpointing"):
offloaded_gradient_checkpointing = True
pass
# Check for Flex Attention
# if IS_GEMMA2 and HAS_FLEX_ATTENTION:
# if not (seq_length % FLEX_ATTENTION_PADDING == 0):
# USE_FLEX_ATTENTION = True
# Gemma2 has alternating SWA and global attn
if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
n = self.config.max_position_embeddings
# masked_fill is making stuff slower!
# self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
# self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
self.SWA_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = self.config.sliding_window,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
self.GA_mask = AttentionMaskConverter(
is_causal = True,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
pass
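# SWA_mask / GA_mask are dense (n, n) float masks built once and cached on the model.
# Even-indexed Gemma-2 layers use the sliding-window mask, odd-indexed layers use the
# global causal mask (selected per layer in the loop below).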
# Go through every layer!
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states: all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
mask = causal_mask
if IS_GEMMA2: mask = self.SWA_mask if (idx % 2 == 0) else self.GA_mask
if offloaded_gradient_checkpointing:
hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
hidden_states,
mask,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)[0]
elif gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask)
return custom_forward
pass
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
mask,
attention_mask,
position_ids,
use_reentrant = True,
preserve_rng_state = False,
)
hidden_states = layer_outputs[0]
else:
layer_outputs = decoder_layer(
hidden_states,
causal_mask=mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = layer_outputs[0]
pass
if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions: all_self_attns += (layer_outputs[1],)
pass
# Final layernorm
if use_cache:
hidden_states = (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
(self.norm, hidden_states)
else:
hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
pass
if output_hidden_states: all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
def LlamaModel_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = getattr(self.config, "sliding_window", None),
)
else:
attention_mask = None
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
def CausalLM_fast_forward(fast_forward_inference):
def _CausalLM_fast_forward(
self,
input_ids: torch.LongTensor = None,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if past_key_values is not None:
outputs = fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids = position_ids,
attention_mask = attention_mask,
)
else:
causal_mask = xformers.attn_bias.LowerTriangularMask()
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pass
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
lm_head = self.lm_head.weight
if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
logits = logits.unsqueeze(0).unsqueeze(0)
else:
logits = self.lm_head(hidden_states.to(lm_head.dtype))
pass
logits = logits.to(self.config.torch_dtype)
loss = None
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
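# Gemma-2 style final logit softcapping: logits = cap * tanh(logits / cap). It is applied
# inside the fused cross entropy loss when labels are given, otherwise directly on the
# logits below.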
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
pass
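# Shift labels left by one and pad with -100 (the ignored index), so the logits at
# position t are scored against token t+1 without slicing the logits tensor itself.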
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
logit_softcapping = logit_softcapping,
)
elif logit_softcapping != 0:
if logits.requires_grad:
logits = (1.0 / logit_softcapping) * logits
logits = torch.tanh(logits)
logits = logit_softcapping * logits
else:
logits *= (1.0 / logit_softcapping)
torch.tanh(logits, out = logits)
logits *= logit_softcapping
pass
pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
pass
return _CausalLM_fast_forward
pass
def PeftModelForCausalLM_fast_forward(
self,
input_ids=None,
causal_mask=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
task_ids=None,
**kwargs,
):
return self.base_model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
pass
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA-graphed, but we retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
class LlamaRotaryEmbedding(torch.nn.Module):
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
# [TODO] Hack to pass in config - need to remove later
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
max_position_embeddings = config.max_position_embeddings
pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# For dynamic RoPE, start the cache at a maximum of 4 * 8192 tokens, then grow it iteratively as needed
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
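# emb has shape (seq_len, dim) with every frequency duplicated across both halves, matching
# the rotate-half layout so cos_cached / sin_cached can be multiplied elementwise onto Q and K.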
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype = x.dtype),
self.sin_cached[:seq_len].to(dtype = x.dtype),
)
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = math.ceil(seq_len / 8192) * 8192 # Round up so the cache always covers seq_len
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.current_rope_size = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
t = t / self.scaling_factor
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
pass
pass
# See https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py#L736
# For Llama 3.1
class LlamaExtendedRotaryEmbedding(torch.nn.Module):
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
# [TODO] Hack to pass in config - need to remove later
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
max_position_embeddings = config.max_position_embeddings
pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# For dynamic RoPE, start the cache at a maximum of 4 * 8192 tokens, then grow it iteratively as needed
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
# Normal Llama-3 RoPE
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
inv_freq = self.apply_scaling(inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent = False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
pass
# From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41
def apply_scaling(self, freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
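# Band-split the rotary frequencies: wavelengths shorter than high_freq_wavelen are kept
# unchanged, wavelengths longer than low_freq_wavelen are divided by scale_factor, and the
# band in between is linearly interpolated via `smooth` below.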
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype = x.dtype),
self.sin_cached[:seq_len].to(dtype = x.dtype),
)
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = math.ceil(seq_len / 8192) * 8192 # Round up so the cache always covers seq_len
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
def _wrap_fast_inference(generate, device_type, dtype, model):
# Wraps inference with bfloat16 / float16
@torch.inference_mode
def _fast_generate(*args, **kwargs):
# Set a flag for generation!
internal_model = model
while hasattr(internal_model, "model"):
internal_model._flag_for_generation = True
internal_model = internal_model.model
pass
internal_model._flag_for_generation = True
# For newer HF
kwargs["cache_implementation"] = "dynamic"
# Remove token_type_ids
kwargs.pop("token_type_ids", None)
# Check pad_token
model_eos_token_id = getattr(model.config, "eos_token_id", None)
if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"):
model_eos_token_id = model_eos_token_id[0]
kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
# Set pad token
# old_pad_token_id = getattr(model.config, "pad_token_id", None)
# old_eos_token_id = getattr(model.config, "eos_token_id", None)
# model.config.pad_token_id = old_eos_token_id
# Autocasted
with torch.autocast(device_type = device_type, dtype = dtype):
output = generate(*args, **kwargs)
pass
# Revert
# model.config.pad_token_id = old_pad_token_id
# Unset a flag for generation!
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
internal_model = internal_model.model
pass
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
return output
pass
return _fast_generate
pass
class FastLlamaModel:
@staticmethod
def pre_patch():
init_name, function = patch_llama_rope_scaling(
model_name = "llama",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
extended_rope_module = LlamaExtendedRotaryEmbedding,
attention_module = LlamaAttention,
)
if init_name is not None:
exec(function, globals())
LlamaAttention.__init__ = eval(init_name)
pass
LlamaAttention .forward = LlamaAttention_fast_forward
LlamaSdpaAttention .forward = LlamaAttention_fast_forward
LlamaFlashAttention2.forward = LlamaAttention_fast_forward
LlamaDecoderLayer .forward = LlamaDecoderLayer_fast_forward
LlamaModel .forward = LlamaModel_fast_forward
LlamaForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(LlamaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA-graphed, but we retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding
return
pass
@staticmethod
def from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
if trust_remote_code:
print(
"Unsloth: WARNING `trust_remote_code` is True.\n"\
"Are you certain you want to do remote code execution?"
)
pass
if token is None and "HF_TOKEN" in os.environ:
token = os.environ["HF_TOKEN"]
if token is None and "HUGGINGFACE_TOKEN" in os.environ:
token = os.environ["HUGGINGFACE_TOKEN"]
if model_patcher is None: model_patcher = FastLlamaModel
SUPPORTS_BFLOAT16 = is_bfloat16_supported()
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
statistics = \
f"==((====))== Unsloth: Fast {model_patcher.__name__[4:-5]} patching release {__version__}\n"\
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth'
print(statistics)
# Warn about fast transfers
old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0")
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1":
print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
pass
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
model_patcher.pre_patch()
get_statistics() # For debugging - we use a download counter to see if environments are not breaking
if dtype is None:
dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
logger.warning_once("Device does not support bfloat16. Will change to float16.")
dtype = torch.float16
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
# RoPE Scaling
model_config = AutoConfig.from_pretrained(model_name, token = token)
model_max_seq_length = model_config.max_position_embeddings
# Check if RoPE Scaling is even allowed
model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]
has_rope_scaling = False
try:
with open(inspect.getfile(model_function), "r") as file:
has_rope_scaling = "self.config.rope_scaling" in file.read()
except: pass
has_rope_scaling = True
# If max_seq_length is not specified, use the maximum from the config
if max_seq_length is None:
max_seq_length = model_max_seq_length
pass
if (rope_scaling is None) and (max_seq_length > model_max_seq_length):
rope_scaling = max_seq_length / model_max_seq_length
logger.warning_once(
f"Unsloth: {model_name} can only handle sequence lengths of at most "\
f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\
f"{round(rope_scaling, 3)}, it can be magically be extended to "\
f"{max_seq_length}!"
)
# Warn RoPE scaling isn't allowed
if not has_rope_scaling:
raise RuntimeError(
"However, {model_name} doesn't support RoPE Scaling!\n"\
"Please file a feature request at https://github.com/unslothai/unsloth."
)
pass
rope_scaling = {"type": "linear", "factor": rope_scaling,}
# Add to kwargs
kwargs["rope_scaling"] = rope_scaling
pass
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
pre_check = check_nvidia()
bnb_config = None
if load_in_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit = True,
bnb_4bit_use_double_quant = True,
bnb_4bit_quant_type = "nf4",
bnb_4bit_compute_dtype = dtype,
)
pass
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
# RoPE Scaling's max_position_embeddings must be updated
max_position_embeddings = max(max_seq_length, model_max_seq_length)
kwargs.pop("attn_implementation", None); # No need since we auto call it
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map = device_map,
torch_dtype = dtype,
quantization_config = bnb_config,
token = token,
max_position_embeddings = max_position_embeddings,
trust_remote_code = trust_remote_code,
attn_implementation = "eager",
**kwargs,
)
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
post_check = check_nvidia()
# Counteract saved tokenizers
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
tokenizer = load_correct_tokenizer(
tokenizer_name = tokenizer_name,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
trust_remote_code = trust_remote_code,
)
model, tokenizer = patch_tokenizer(model, tokenizer)
model = model_patcher.post_patch(model)
# Patch up QKV / O and MLP
for idx, layer in enumerate(model.model.layers):
layer.self_attn.apply_qkv = original_apply_qkv
layer.self_attn.apply_o = original_apply_o
pass
# Patch Trainer
from transformers.trainer import Trainer
try:
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
inner_training_loop = inspect.getsource(Trainer._inner_training_loop)
Trainer._original_training_loop = inner_training_loop
else:
inner_training_loop = Trainer._original_training_loop
except:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
pass
if ((post_check - pre_check) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
import transformers.trainer
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
# TODO: Support Deepspeed
if item.startswith(("deepspeed", "xm", "met", "smp")): continue
if item in inner_training_loop: good_items.append(item)
pass
exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
start = re.search('logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0]
end = inner_training_loop.find("\n\n", start)
original_debug = inner_training_loop[start:end]
spaces = re.search('\n([\s\t]{1,})', original_debug).group(0)[1:]
front_spaces = re.match('([\s\t]{1,})', inner_training_loop).group(0)
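# The patching below edits the *source text* of Trainer._inner_training_loop: the
# "Running training" banner is replaced with Unsloth's banner plus a GPU-count check,
# batch size / gradient accumulation are forced to single-device values, TPU / DeepSpeed /
# SageMaker code paths are stubbed out, and the edited source is exec'd back in as
# _fast_inner_training_loop.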
debug_info = """debug_info = \\
f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\
f" \\\\\\ /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\
f"O^O/ \\_/ \\ Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\
f"\\ / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\
f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}'
logger.warning(debug_info)
import subprocess, re, gc, numpy as np
a = np.array([0,])
try:
a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)
a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)
a = np.array([int(x.decode('utf-8'))/1024 for x in a])
except:
if not torch.cuda.is_available():
raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')
if ((a - PRE_CHECK) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()"""
debug_info = debug_info.split('\n')
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace(original_debug, debug_info)
debug_info = """n_total_devices = total_train_batch_size // \\
args.gradient_accumulation_steps // self._train_batch_size
if n_total_devices > 1:
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
debug_info ="""
debug_info = debug_info.split('\n')
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace("debug_info =", debug_info, 1)
front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0)
inner_training_loop = re.sub(r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE)
inner_training_loop = inner_training_loop.replace(
"train_dataloader = tpu_spmd_dataloader(train_dataloader)",
"raise RuntimeError('Unsloth: TPUs are not yet supported!')"
)
inner_training_loop = inner_training_loop.replace(
"self.accelerator.free_memory()",
"self.accelerator.free_memory()\n" + \
front_spaces + "if self.is_deepspeed_enabled:"\
"raise RuntimeError('Unsloth: Deepspeed is not yet supported!')\n", 1,
)
check_batches = """train_dataloader = self.get_train_dataloader()
ga = args.gradient_accumulation_steps
bsz = self._train_batch_size
total_batches = bsz * ga * args.world_size
n_total_devices = total_batches // ga // bsz
if n_total_devices > 1:
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
divisor = n_total_devices / 1
bsz = self._train_batch_size = max(int(bsz / divisor), 1)
if total_batches // ga // bsz > 1:
divisor = n_total_devices / 1
ga = args.gradient_accumulation_steps = max(int(ga / divisor), 1)"""
check_batches = check_batches.split('\n')
check_batches = "\n".join([check_batches[0]] + [front_spaces + x[8:] for x in check_batches[1:]])
inner_training_loop = inner_training_loop.replace(
"train_dataloader = self.get_train_dataloader()",
check_batches, 1,
)
inner_training_loop = inner_training_loop.replace(
"_inner_training_loop",
"_fast_inner_training_loop", 1,
)
exec(inner_training_loop, globals())
Trainer._inner_training_loop = _fast_inner_training_loop
inner_training_loop = inner_training_loop.replace(
"is_torch_tpu_available()",
"False",
)
if "n_total_devices >" not in inner_training_loop:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
pass
inner_training_loop = inner_training_loop.replace(
"is_sagemaker_mp_enabled()",
"False",
)
exec(inner_training_loop, globals())
Trainer._inner_training_loop = _fast_inner_training_loop
# Save max_seq_length
model.max_seq_length = max_position_embeddings
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_position_embeddings
internal_model = internal_model.model
pass
internal_model.max_seq_length = max_position_embeddings
# We check the tokenizer first for errors
if fix_tokenizer:
tokenizer = check_tokenizer(
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
)
pass
patch_saving_functions(tokenizer)
# Fix up config for transformers uploading PEFT
# Not necessary anymore since we require transformers>=4.37!
if False:
name = model.config._name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
name = name[:len(name) - len("-bnb-4bit")]
model.config.update({"_name_or_path" : name})
pass
pass
# Log Unsloth version for future fastpaths for inference
model.config.update({"unsloth_version" : __version__})
# Add save modules
patch_saving_functions(model)
Trainer._inner_training_loop = _fast_inner_training_loop
# Save tokenizer for inference purposes
tokenizer.padding_side = "left" # Force inference
internal_model = model
while hasattr(internal_model, "model"):
internal_model._saved_temp_tokenizer = tokenizer
internal_model = internal_model.model
pass
internal_model._saved_temp_tokenizer = tokenizer
return model, tokenizer
pass
@staticmethod
def post_patch(model):
# Patch model
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# This workaround randomly fixes it for torch versions < 2.
model.set_input_embeddings(torch.nn.Embedding.from_pretrained(model.get_input_embeddings().weight))
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.get_output_embeddings().weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
and (module.cos_cached.dtype != correct_dtype):
module.cos_cached = module.cos_cached.to(correct_dtype)
module.sin_cached = module.sin_cached.to(correct_dtype)
pass
pass
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
@staticmethod
def get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
layers_to_transform = None,
layers_pattern = None,
use_gradient_checkpointing = True,
random_state = 3407,
max_seq_length = 2048, # not used anymore
use_rslora = False,
modules_to_save = None,
init_lora_weights = True,
loftq_config = {},
temporary_location = "_unsloth_temporary_saved_buffers",
**kwargs,
):
transformers_set_seed(random_state)
if isinstance(model, PeftModelForCausalLM):
# Check if exactly the same and then pass through!
assert(hasattr(model, "peft_config"))
peft_config = model.peft_config["default"].to_dict()
check_parameters = [
"r", "lora_alpha", "lora_dropout",
"bias", "layers_to_transform", "layers_pattern",
"use_rslora", "init_lora_weights",
]
check_all = True
for param in check_parameters:
check_all = check_all and (peft_config[param] == eval(param))
pass
# Check save_modules
old_target_modules = list(peft_config["target_modules"])
modules_to_save = peft_config["modules_to_save"]
if modules_to_save is None: modules_to_save = {}
modules_to_save = list(modules_to_save)
old_target_modules += modules_to_save
# Combine all
new_target_modules = list(target_modules) + \
list(modules_to_save if modules_to_save is not None else [])
# Now check!
new_target_modules = set(new_target_modules)
check_all = check_all and (
len(set(old_target_modules) ^ new_target_modules) == 0
)
check_all = check_all and (
(loftq_config == {} or loftq_config is None) and \
(peft_config["loftq_config"] == {} or peft_config["loftq_config"] is None)
)
if check_all:
# Simply pass through!
logger.warning(
"Unsloth: Already have LoRA adapters! We shall skip this step."
)
# Offload!
# [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!)
if "embed_tokens" in new_target_modules:
print("Unsloth: Casting embed_tokens to float32")
model.model.model.embed_tokens.modules_to_save.default\
.to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
# [TODO] Move old embed_tokens to CPU - should be disk!
model.model.model.embed_tokens.original_module\
.to(device = "cpu", non_blocking = True)
model.model.model.embed_tokens.original_module.requires_grad_(False)
pass
if "lm_head" in new_target_modules:
print("Unsloth: Casting lm_head to float32")
model.model.lm_head.modules_to_save.default\
.to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)
# [TODO] Move old lm_head to CPU - should be disk!
model.model.lm_head.original_module\
.to(device = "cpu", non_blocking = True)
model.model.lm_head.original_module.requires_grad_(False)
pass
return model
else:
raise TypeError(
"Unsloth: Your model already has LoRA adapters. Your new parameters are different."
)
pass
pass
if loftq_config is None: loftq_config = {}
signature = str(inspect.signature(LoraConfig))
SUPPORTS_LOFTQ = "loftq_config" in signature
SUPPORTS_RSLORA = "use_rslora" in signature
assert(max_seq_length <= model.max_seq_length)
if lora_dropout != 0:
logger.warning_once(
f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"\
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
pass
if bias != "none":
logger.warning_once(
f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"\
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
pass
if not (type(init_lora_weights) is bool or \
init_lora_weights == "gaussian" or init_lora_weights == "loftq"):
raise ValueError(
'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq"].'
)
pass
if init_lora_weights == "loftq":
if not SUPPORTS_LOFTQ:
import peft
raise RuntimeError(
f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"\
"Please install PEFT 0.7.2 or higher.\n"\
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
)
pass
if loftq_config == {}:
from peft import LoftQConfig
logger.warning_once(
f"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
f"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
)
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
pass
if hasattr(model.config, "quantization_config"):
raise ValueError(
"Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\
"Reload your model without any quantization by setting `load_in_4bit = False`."
)
pass
pass
assert(type(use_rslora) is bool)
if use_rslora:
if not SUPPORTS_RSLORA:
# We manually check for PEFT
import peft
raise RuntimeError(
f"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\n"\
"Please install PEFT 0.7.2 or higher.\n"\
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
)
pass
pass
accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",),)
model.config.update({"unsloth_version" : __version__})
if type(modules_to_save) is tuple:
modules_to_save = list(modules_to_save)
pass
train_lm_head = False
train_embed_tokens = False
final_modules = []
for module in target_modules:
if module == "lm_head":
# logger.warning_once(
# "Unsloth: `lm_head` should be placed in `modules_to_save` and not `target_modules`. "\
# "Luckily, we shall do it for you!"
# )
train_lm_head = True
if modules_to_save is None: modules_to_save = ["lm_head"]
else: modules_to_save.append("lm_head")
elif module == "embed_tokens":
# logger.warning_once(
# "Unsloth: `embed_tokens` should be placed in `modules_to_save` and not `target_modules`. "\
# "Luckily, we shall do it for you!"
# )
train_embed_tokens = True
if modules_to_save is None: modules_to_save = ["embed_tokens"]
else: modules_to_save.append("embed_tokens")
else:
assert(module in accepted_modules)
final_modules.append(module)
pass
# Check if we added new tokens!
if hasattr(model, "_need_to_train_embeddings"):
if not train_lm_head or not train_embed_tokens:
print(
"Unsloth: You added new tokens but did not specify if you wanted to "\
"train the lm_head and embed_tokens.\nWe must turn it on for you."
)
train_lm_head = True
train_embed_tokens = True
if modules_to_save is None: modules_to_save = ["embed_tokens"]
else: modules_to_save.append("embed_tokens")
if modules_to_save is None: modules_to_save = ["lm_head"]
else: modules_to_save.append("lm_head")
pass
pass
# Check for Llama-3
# if hasattr(model._saved_temp_tokenizer, "_using_llama3_template"):
# if not train_embed_tokens and not train_lm_head:
# raise RuntimeError("")
# First fix untrained tokens
# Wrong - can cause reserved tokens to pop out!!
# if train_embed_tokens or train_lm_head:
# fix_untrained_tokens(model, eps = 1e-16)
# pass
# Check modules_to_save
if modules_to_save is not None:
for module in modules_to_save:
if module == "lm_head":
train_lm_head = True
elif module == "embed_tokens":
train_embed_tokens = True
else:
raise TypeError(
f"Unsloth: Module = {module} is not allowed. Only 'lm_head' and 'embed_tokens' is allowed."
)
pass
pass
if isinstance(modules_to_save, (tuple, list)):
modules_to_save = list(set(modules_to_save))
pass
# Get LoRA
arguments = dict(
r = r,
lora_alpha = lora_alpha,
target_modules = final_modules,
lora_dropout = lora_dropout,
bias = bias,
task_type = TaskType.CAUSAL_LM,
layers_to_transform = layers_to_transform,
init_lora_weights = init_lora_weights,
loftq_config = loftq_config,
use_rslora = use_rslora,
modules_to_save = modules_to_save,
**kwargs,
)
if not SUPPORTS_LOFTQ: del arguments["loftq_config"]
if not SUPPORTS_RSLORA: del arguments["use_rslora"]
_saved_temp_tokenizer = model._saved_temp_tokenizer
lora_config = LoraConfig(**arguments)
# First offload lm_head and embed_tokens to disk
input_embeddings_device = model. get_input_embeddings().weight.device
output_embeddings_device = model.get_output_embeddings().weight.device
if use_gradient_checkpointing == "unsloth":
if train_embed_tokens:
print("Unsloth: Offloading input_embeddings to disk to save VRAM")
offload_input_embeddings(model, temporary_location)
pass
# Remove old items to save VRAM
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
if train_lm_head:
print("Unsloth: Offloading output_embeddings to disk to save VRAM")
offload_output_embeddings(model, temporary_location)
pass
# Remove old items to save VRAM
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
pass
model = _get_peft_model(model, lora_config)
model._saved_temp_tokenizer = _saved_temp_tokenizer
model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)
# Now patch lm_head and embed_tokens
if train_embed_tokens:
print("Unsloth: Casting embed_tokens to float32")
assert(hasattr(model.model.model.embed_tokens, "modules_to_save"))
model.model.model.embed_tokens.modules_to_save.default\
.to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
pass
if train_lm_head:
print("Unsloth: Casting lm_head to float32")
assert(hasattr(model.model.lm_head, "modules_to_save"))
model.model.lm_head.modules_to_save.default\
.to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)
pass
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
return model
pass
@staticmethod
def patch_peft_model(
model,
use_gradient_checkpointing = True,
):
if not isinstance(model, PeftModelForCausalLM):
raise TypeError(
"Unsloth: Your model needs to call `.get_peft_model` first!"
)
pass
# Get activation function
model_type = model.config.model_type
if model_type == "llama": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx
elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx
else:
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
pass
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing = use_gradient_checkpointing,
use_reentrant = True,
)
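# The call above prepares the quantized base model for training and wires up
# (reentrant) gradient checkpointing when it is enabled.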
# Fix up config for transformers uploading PEFT
for active_adapter in model.peft_config.keys():
# Not necessary since we require transformers >= 4.37
if False:
name = model.peft_config[active_adapter].base_model_name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
name = name[:len(name) - len("-bnb-4bit")]
model.peft_config[active_adapter].base_model_name_or_path = name
pass
# Add revision to enable future fast inference paths
model.peft_config[active_adapter].revision = "unsloth"
pass
from transformers.trainer import Trainer
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
raise RuntimeError(
'Unsloth currently does not support multi-GPU setups. Sadly, we are a two-brother team, so '\
'enabling it will require much more work, and we have to prioritize. Please understand!\n'\
'We do have a separate beta version, which you can contact us about!\n'\
'Thank you for your understanding and we appreciate it immensely!'
)
pass
# Fix loftq issues
# loftq_config must not be None; it must be an empty dict {}
all_configs = model.peft_config
for key, current_config in all_configs.items():
if hasattr(current_config, "loftq_config") and current_config.loftq_config is None:
new_args = current_config.__dict__
new_args["loftq_config"] = {}
current_config = current_config.__class__(**new_args)
all_configs[key] = current_config
pass
pass
# Do patching
n_mlp = 0
n_qkv = 0
n_o = 0
import types
active_adapter = model.active_adapters[0] if \
hasattr(model, "active_adapters") else model.active_adapter
# Get dropout and bias
lora_dropout = model.peft_config[active_adapter].lora_dropout
bias = model.peft_config[active_adapter].bias
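# The fast LoRA autograd paths below assume no LoRA dropout and no bias term,
# so layers are only patched when both conditions hold.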
if lora_dropout == 0 and bias == "none":
for idx, layer in enumerate(model.model.model.layers):
# MLP patching
gate_proj = layer.mlp.gate_proj
up_proj = layer.mlp.up_proj
down_proj = layer.mlp.down_proj
if hasattr(gate_proj, "lora_A") and \
hasattr( up_proj, "lora_A") and \
hasattr(down_proj, "lora_A") and \
(getattr(gate_proj, "base_layer", gate_proj).bias is None) and \
(getattr( up_proj, "base_layer", up_proj).bias is None) and \
(getattr(down_proj, "base_layer", down_proj).bias is None) and \
(len(getattr(gate_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr( up_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0):
# https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
n_mlp += 1
else:
logger.warning_once(
"Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
pass
# QKV attention patching
q_proj = layer.self_attn.q_proj
k_proj = layer.self_attn.k_proj
v_proj = layer.self_attn.v_proj
if hasattr(q_proj, "lora_A") and \
hasattr(k_proj, "lora_A") and \
hasattr(v_proj, "lora_A") and \
(getattr(q_proj, "base_layer", q_proj).bias is None) and \
(getattr(k_proj, "base_layer", k_proj).bias is None) and \
(getattr(v_proj, "base_layer", v_proj).bias is None) and \
(len(getattr(q_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0):
layer.self_attn.apply_qkv = apply_lora_qkv
n_qkv += 1
else:
if model_type != "qwen2":
logger.warning_once(
"Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
pass
pass
# O attention patching
o_proj = layer.self_attn.o_proj
if hasattr(o_proj, "lora_A") and \
(getattr(o_proj, "base_layer", o_proj).bias is None) and \
(len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0):
layer.self_attn.apply_o = apply_lora_o
n_o += 1
else:
logger.warning_once(
"Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
pass
pass
pass
logger.warning_once(
f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "\
f"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.",
)
patch_saving_functions(model)
# Patch cross entropy loss labels
# Fixes https://github.com/unslothai/unsloth/issues/10
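# extra_ignored_labels is a column of -100s appended to the shifted labels in the
# fast cross entropy path so the final token position is ignored by the loss.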
max_seq_length = model.max_seq_length
extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
model.model.extra_ignored_labels = extra_ignored_labels
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_seq_length
internal_model = internal_model.model
pass
internal_model.max_seq_length = max_seq_length
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
return model
pass
@staticmethod
def for_inference(model):
# if model.config.model_type == "qwen2":
# FastLlamaModel.for_training(model)
# return
# pass
internal_model = model
internal_model.gradient_checkpointing = False
internal_model.training = False
while hasattr(internal_model, "model"):
internal_model = internal_model.model
internal_model.gradient_checkpointing = False
internal_model.training = False
pass
# Also check if lm_head / embeddings are trained
internal_model = model
while not hasattr(internal_model, "lm_head"):
internal_model = internal_model.model
pass
lm_head = internal_model.lm_head.weight
device_type = lm_head.device.type
dtype = model.config.torch_dtype
if type(dtype) is str:
if dtype == "float16": dtype = torch.float16
elif dtype == "bfloat16": dtype = torch.bfloat16
pass
# Wrap model.generate
if model.generate.__name__ != "_fast_generate":
model._unwrapped_old_generate = model.generate
model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)
pass
# Patch tokenizer to pad to the left
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "left"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "left"
pass
pass
@staticmethod
def for_training(model, use_gradient_checkpointing = True):
internal_model = model
internal_model.gradient_checkpointing = use_gradient_checkpointing
internal_model.training = True
# Delete all fast inference loras
for param in model.parameters():
if hasattr(param, "_fast_lora"):
del param._fast_lora
pass
while hasattr(internal_model, "model"):
internal_model = internal_model.model
internal_model.gradient_checkpointing = use_gradient_checkpointing
internal_model.training = True
pass
# Also revert model.generate
if hasattr(model, "_unwrapped_old_generate"):
model.generate = model._unwrapped_old_generate
del model._unwrapped_old_generate
pass
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
pass
pass
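# Minimal usage sketch for the inference/training toggles (illustrative only;
# `model` is a PEFT model produced by `get_peft_model`, `inputs` is a tokenized batch):
#   FastLlamaModel.for_inference(model)   # pad left, disable gradient checkpointing, wrap generate
#   outputs = model.generate(**inputs, max_new_tokens = 64)
#   FastLlamaModel.for_training(model)    # pad right, restore gradient checkpointing, unwrap generate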
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER
import os
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from packaging.version import Version
transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
SUPPORTS_GEMMA = transformers_version >= Version("4.38")
SUPPORTS_GEMMA2 = transformers_version >= Version("4.42")
SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.1")
if SUPPORTS_GEMMA:
from .gemma import FastGemmaModel
if SUPPORTS_GEMMA2:
from .gemma2 import FastGemma2Model
pass
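# Map a model name between its pre-quantized 4bit repo and its float16 counterpart,
# depending on transformers support and the requested `load_in_4bit` setting.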
def _get_model_name(model_name, load_in_4bit = True):
if not SUPPORTS_FOURBIT and model_name.lower() in INT_TO_FLOAT_MAPPER:
model_name = INT_TO_FLOAT_MAPPER[model_name.lower()]
logger.warning_once(
f"Unsloth: Your transformers version of {transformers_version} does not support native "\
f"4bit loading.\nThe minimum required version is 4.37.\n"\
f'Try `pip install --upgrade "transformers>=4.37"`\n'\
f"to obtain the latest transformers build, then restart this session.\n"\
f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)."
)
elif not load_in_4bit and model_name.lower() in INT_TO_FLOAT_MAPPER:
new_model_name = INT_TO_FLOAT_MAPPER[model_name.lower()]
# logger.warning_once(
# f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\
# f"`load_in_4bit = False`. We shall load `{new_model_name}` instead."
# )
model_name = new_model_name
elif load_in_4bit and SUPPORTS_FOURBIT and model_name.lower() in FLOAT_TO_INT_MAPPER:
new_model_name = FLOAT_TO_INT_MAPPER[model_name.lower()]
# logger.warning_once(
# f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\
# f"We shall load `{new_model_name}` for 4x faster loading."
# )
model_name = new_model_name
pass
return model_name
pass
class FastLanguageModel(FastLlamaModel):
@staticmethod
def from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
trust_remote_code = False,
use_gradient_checkpointing = "unsloth",
resize_model_vocab = None,
revision = None,
*args, **kwargs,
):
if token is None and "HF_TOKEN" in os.environ:
token = os.environ["HF_TOKEN"]
if token is None and "HUGGINGFACE_TOKEN" in os.environ:
token = os.environ["HUGGINGFACE_TOKEN"]
old_model_name = model_name
model_name = _get_model_name(model_name, load_in_4bit)
# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
disable_progress_bars()
try:
model_config = AutoConfig.from_pretrained(model_name, token = token, revision = revision)
is_model = True
except:
is_model = False
try:
peft_config = PeftConfig.from_pretrained(model_name, token = token, revision = revision)
is_peft = True
except:
is_peft = False
# Cannot be both!
if is_model and is_peft:
raise RuntimeError(
"Unsloth: Your repo has a LoRA adapter and a base model.\n"\
"You have 2 files `config.json` and `adapter_config.json`.\n"\
"We must only allow one config file.\n"\
"Please separate the LoRA and base models to 2 repos."
)
elif not is_model and not is_peft:
raise RuntimeError(
f"Unsloth: `{model_name}` is not a base model or a PEFT model.\n"\
"We could not locate a `config.json` or `adapter_config.json` file.\n"\
"Are you certain the model name is correct? Does it actually exist?"
)
pass
# Get base model for PEFT:
if is_peft:
# Check base model again for PEFT
model_name = _get_model_name(peft_config.base_model_name_or_path, load_in_4bit)
model_config = AutoConfig.from_pretrained(model_name, token = token, revision = revision)
pass
if not was_disabled: enable_progress_bars()
model_type = model_config.model_type
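# Dispatch to the architecture-specific Fast* implementation based on the config's model_type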
if model_type == "llama":
scaling_type = None
if getattr(model_config, "rope_scaling", None) is not None:
scaling_type1 = model_config.rope_scaling.get("type", None)
scaling_type2 = model_config.rope_scaling.get("rope_type", None)
scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
pass
if scaling_type == "llama3" and not SUPPORTS_LLAMA31:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Llama 3.1.\n"\
f"The minimum required version is 4.43.1\n"\
f'Try `pip install --upgrade "transformers>=4.43.1"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastLlamaModel
elif model_type == "mistral": dispatch_model = FastMistralModel
elif model_type == "gemma":
if not SUPPORTS_GEMMA:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
f"The minimum required version is 4.38.\n"\
f'Try `pip install --upgrade "transformers>=4.38"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemmaModel
elif model_type == "gemma2":
if not SUPPORTS_GEMMA2:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
f"The minimum required version is 4.42.3.\n"\
f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemma2Model
elif model_type == "qwen2":
dispatch_model = FastQwen2Model
else:
raise NotImplementedError(
f"Unsloth: {model_name} not supported yet!\n"\
"Make an issue to https://github.com/unslothai/unsloth!",
)
pass
# Check if this is local model since the tokenizer gets overwritten
if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \
os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \
os.path.exists(os.path.join(old_model_name, "special_tokens_map.json")):
tokenizer_name = old_model_name
else:
tokenizer_name = None
pass
model, tokenizer = dispatch_model.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = dispatch_model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
revision = revision if not is_peft else None,
*args, **kwargs,
)
if resize_model_vocab is not None:
model.resize_token_embeddings(resize_model_vocab)
pass
# In case the model supports tagging, add the unsloth tag.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["unsloth",])
pass
if hasattr(tokenizer, "add_model_tags"):
tokenizer.add_model_tags(["unsloth",])
pass
if load_in_4bit:
# Fix up bitsandbytes config
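# Record the 4bit quantization settings in the model config so they are preserved
# when the config is saved or pushed to the Hub.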
quantization_config = \
{
# Sometimes torch_dtype is not a string!!
"bnb_4bit_compute_dtype" : model.config.to_dict()["torch_dtype"],
"bnb_4bit_quant_type" : "nf4",
"bnb_4bit_use_double_quant" : True,
"llm_int8_enable_fp32_cpu_offload" : False,
"llm_int8_has_fp16_weight" : False,
"llm_int8_skip_modules" : None,
"llm_int8_threshold" : 6.0,
"load_in_4bit" : True,
"load_in_8bit" : False,
"quant_method" : "bitsandbytes",
}
model.config.update({"quantization_config" : quantization_config})
pass
if is_peft:
# From https://github.com/huggingface/peft/issues/184
# Now add PEFT adapters
model.enable_input_require_grads()
model = PeftModel.from_pretrained(
model,
old_model_name,
token = token,
revision = revision,
is_trainable = True,
)
# Patch it as well!
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
pass
return model, tokenizer
pass
pass
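# Minimal end-to-end sketch (illustrative only; the hyperparameters and target modules
# below are examples, not library defaults):
#   from unsloth import FastLanguageModel
#   model, tokenizer = FastLanguageModel.from_pretrained(
#       model_name = "unsloth/llama-3-8b-bnb-4bit",
#       max_seq_length = 2048,
#       load_in_4bit = True,
#   )
#   model = FastLanguageModel.get_peft_model(
#       model,
#       r = 16,
#       lora_alpha = 16,
#       target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
#                         "gate_proj", "up_proj", "down_proj"],
#       use_gradient_checkpointing = "unsloth",
#   )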