# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import importlib
import sys
from packaging.version import Version
# # Define a list of modules to check
# MODULES_TO_CHECK = ["bitsandbytes"]
# # Check if any of the modules in the list have been imported
# for module in MODULES_TO_CHECK:
# if module in sys.modules:
# raise ImportError(f"Unsloth: Please import Unsloth before {module}.")
# pass
# pass
# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
# enabling it will require much more work, so we have to prioritize. Please understand!
# We do have a beta version, which you can contact us about!
# Thank you for your understanding and we appreciate it immensely!
if "CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
devices = os.environ["CUDA_VISIBLE_DEVICES"]
# Check if there are multiple cuda devices set in env
if not devices.isdigit():
first_id = devices.split(",")[0]
warnings.warn(
f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {devices} \n"\
"Unsloth currently does not support multi GPU setups - but we are working on it!\n"\
"Multiple CUDA devices detected but we require a single device.\n"\
f"We will override CUDA_VISIBLE_DEVICES to first device: {first_id}."
)
os.environ["CUDA_VISIBLE_DEVICES"] = str(first_id)
else:
# warnings.warn("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
pass
# Reduce VRAM usage by reducing fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
try:
    import torch
except:
    raise ImportError("Pytorch is not installed. Go to https://pytorch.org/.\n"\
                      "We have some installation instructions on our Github page.")
pass

# Hugging Face Hub faster downloads (only enable during Colab and Kaggle sessions)
keynames = "\n" + "\n".join(os.environ.keys())
if "\nCOLAB_" in keynames or "\nKAGGLE_" in keynames:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
pass

# We support Pytorch 2
# Fixes https://github.com/unslothai/unsloth/issues/38
torch_version = torch.__version__.split(".")
major_torch, minor_torch = torch_version[0], torch_version[1]
major_torch, minor_torch = int(major_torch), int(minor_torch)
if (major_torch < 2):
    raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\
                      "We have some installation instructions on our Github page.")
elif (major_torch == 2) and (minor_torch < 2):
    # Disable expandable_segments
    del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
pass

# Torch 2.5 has including_emulation
major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = (major_version >= 8)

if (major_torch == 2) and (minor_torch >= 5):
    old_is_bf16_supported = torch.cuda.is_bf16_supported
    def is_bf16_supported(including_emulation = False):
        return old_is_bf16_supported(including_emulation)
    torch.cuda.is_bf16_supported = is_bf16_supported
else:
    def is_bf16_supported(): return SUPPORTS_BFLOAT16
    torch.cuda.is_bf16_supported = is_bf16_supported
pass
# Try loading bitsandbytes and triton
import bitsandbytes as bnb
import triton
libcuda_dirs = lambda: None
if Version(triton.__version__) >= Version("3.0.0"):
    try: from triton.backends.nvidia.driver import libcuda_dirs
    except: pass
else: from triton.common.build import libcuda_dirs
import os
import re
import numpy as np
import subprocess
try:
    cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
    libcuda_dirs()
except:
    warnings.warn(
        "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
    )
    if os.path.exists("/usr/lib64-nvidia"):
        os.system("ldconfig /usr/lib64-nvidia")
    elif os.path.exists("/usr/local"):
        # Sometimes bitsandbytes cannot be linked properly in Runpod for example
        possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
        find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
        possible_cudas = [find_cuda.search(x) for x in possible_cudas]
        possible_cudas = [x.group(1) for x in possible_cudas if x is not None]

        # Try linking cuda folder, or everything in local
        if len(possible_cudas) == 0:
            os.system(f"ldconfig /usr/local/")
        else:
            find_number = re.compile(r"([\d\.]{2,})")
            latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
            latest_cuda = possible_cudas[latest_cuda]
            os.system(f"ldconfig /usr/local/{latest_cuda}")
    pass

    importlib.reload(bnb)
    importlib.reload(triton)
    try:
        libcuda_dirs = lambda: None
        if Version(triton.__version__) >= Version("3.0.0"):
            try: from triton.backends.nvidia.driver import libcuda_dirs
            except: pass
        else: from triton.common.build import libcuda_dirs
        cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
        libcuda_dirs()
    except:
        warnings.warn(
            "Unsloth: CUDA is not linked properly.\n"\
            "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
            "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
            "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
            "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
            "Unsloth will still run for now, but maybe it might crash - let's hope it works!"
        )
pass
from .models import *
from .save import *
from .chat_templates import *
from .tokenizer_utils import *
from .trainer import *
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
"get_chat_template",
"test_chat_templates",
"test_hf_gguf_equivalence",
"remove_special_tokens",
"to_sharegpt",
"standardize_sharegpt",
"apply_chat_template",
"train_on_responses_only",
"test_construct_chat_template",
]
from transformers import StoppingCriteria, StoppingCriteriaList
import torch
from torch import LongTensor, FloatTensor
from transformers.models.llama.modeling_llama import logger
from .save import patch_saving_functions
import os
import shutil
from .tokenizer_utils import *
from .models._utils import patch_tokenizer
import re
CHAT_TEMPLATES = {}
# =========================================== Unsloth
# Unsloth's efficient template is adapted from Zephyr
unsloth_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'You are a helpful assistant to the user\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '>>> User: ' + message['content'] + '\n' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '>>> Assistant: ' }}"\
"{% endif %}"
pass
unsloth_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
{{ end }}>>> Assistant: {{ .Response }}{__EOS_TOKEN__}
"""
PARAMETER stop "{__EOS_TOKEN__}"
SYSTEM """You are a helpful assistant to the user"""
'''
unsloth_eos_token = "eos_token"
CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
pass
# =========================================== Zephyr
# Zephyr has no BOS!
zephyr_template = \
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"\
"{% else %}"\
"{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|assistant|>\n' }}"\
"{% endif %}"
pass
zephyr_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|system|>
{{ .System }}{__EOS_TOKEN__}
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}{__EOS_TOKEN__}
{{ end }}<|assistant|>
{{ .Response }}{__EOS_TOKEN__}
"""
PARAMETER stop "{__EOS_TOKEN__}"
'''
zephyr_eos_token = "eos_token"
CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
pass
# =========================================== ChatML
# ChatML has no BOS and no EOS! Rather, <|im_start|> and <|im_end|> act as BOS / EOS.
chatml_template = \
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}"\
"{% elif message['role'] == 'assistant' %}"\
"{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}"\
"{% else %}"\
"{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|im_start|>assistant\n' }}"\
"{% endif %}"
pass
chatml_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
"""
PARAMETER stop "<|im_start|>"
PARAMETER stop "<|im_end|>"
'''
chatml_eos_token = "<|im_end|>"
CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
pass
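# For reference, the ChatML template above renders a short conversation like this
# (illustrative sketch; the messages are made up and the whitespace follows the
# template literally):
#
#     <|im_start|>system
#     You are a friendly chatbot.<|im_end|>
#     <|im_start|>user
#     What is 2+2?<|im_end|>
#     <|im_start|>assistant
#
# The trailing assistant header comes from add_generation_prompt = True.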
# =========================================== Mistral-1
# Mistral Instruct doesn't allow system prompts, so we prepend the system message to the first user message.
mistral_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{% if messages[1]['role'] == 'user' %}"\
"{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[2:] %}"\
"{% else %}"\
"{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% endif %}"\
"{% else %}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ message['content'] + eos_token }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"
pass
# Ollama from https://www.ollama.com/library/mistral
mistral_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
PARAMETER stop "{__EOS_TOKEN__}"
'''
mistral_eos_token = "eos_token"
CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
pass
# =========================================== Llama-2
# Adds BOS to every convo! And weird <<SYS>> system messages.
llama_template = \
"{% if messages[0]['role'] == 'system' %}"\
"{% if messages[1]['role'] == 'user' %}"\
"{{ bos_token + '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[2:] %}"\
"{% else %}"\
"{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% endif %}"\
"{% else %}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ ' ' + message['content'].strip() + ' ' + eos_token }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"
pass
# Ollama from https://www.ollama.com/library/llama2
llama_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """[INST] <<SYS>>{{ .System }}<</SYS>>
{{ .Prompt }} [/INST]"""
PARAMETER stop "{__EOS_TOKEN__}"
'''
llama_eos_token = "eos_token"
CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
pass
# =========================================== Vicuna
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
vicuna_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{ messages[0]['content'] + ' ' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ 'USER: ' + message['content'] + ' ' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ 'ASSISTANT: ' + message['content'] + eos_token }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ 'ASSISTANT:' }}"\
"{% endif %}"
pass
# Ollama from https://www.ollama.com/library/vicuna
vicuna_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
PARAMETER stop "{__EOS_TOKEN__}"
'''
vicuna_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
pass
# =========================================== Vicuna Old
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
vicuna_old_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '### Human: ' + message['content'] + '\n' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '### Assistant: ' + message['content'] + eos_token + '\n' }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '### Assistant:' }}"\
"{% endif %}"
pass
vicuna_old_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
{{ end }}### Assistant: {{ .Response }}{__EOS_TOKEN__}
"""
PARAMETER stop "{__EOS_TOKEN__}"
SYSTEM """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."""
'''
vicuna_old_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
pass
# =========================================== Alpaca multi turn
# https://github.com/tatsu-lab/stanford_alpaca, adapted for multi-turn conversations
alpaca_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{ messages[0]['content'] + '\n\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\n\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '### Instruction:\n' + message['content'] + '\n\n' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '### Response:\n' + message['content'] + eos_token + '\n\n' }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '### Response:\n' }}"\
"{% endif %}"
pass
alpaca_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
{{ .Prompt }}{{ end }}
### Response:
{{ .Response }}{__EOS_TOKEN__}
"""
PARAMETER stop "{__EOS_TOKEN__}"
SYSTEM """Below are some instructions that describe some tasks. Write responses that appropriately complete each request."""
'''
alpaca_eos_token = "eos_token"
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
pass
# =========================================== Gemma
# https://huggingface.co/google/gemma-7b-it
# Notice we must use |trim for lstrip and rstrip. <start_of_turn> maps to 106.
# <end_of_turn> maps to 107. user and model are normal 1 word tokens.
gemma_template = \
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{'<start_of_turn>user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\n'}}"\
"{% set loop_messages = messages[2:] %}"\
"{% endif %}"\
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{'<start_of_turn>user\n' + message['content'] | trim + '<end_of_turn>\n'}}"\
"{% elif message['role'] == 'assistant' %}"\
"{{'<start_of_turn>model\n' + message['content'] | trim + '<end_of_turn>\n' }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<start_of_turn>model\n' }}"\
"{% endif %}"
pass
# Ollama from https://www.ollama.com/library/gemma
gemma_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """<start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
"""
PARAMETER repeat_penalty 1
PARAMETER stop "<start_of_turn>"
PARAMETER stop "<end_of_turn>"
PARAMETER penalize_newline false
'''
gemma_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
pass
# =========================================== Gemma with ChatML instead
# We find using <eos> is still more appropriate!
gemma_chatml_template = "{{ bos_token }}" + chatml_template
pass
gemma_chatml_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
"""
PARAMETER repeat_penalty 1
PARAMETER stop "<|im_start|>"
PARAMETER stop "<|im_end|>"
PARAMETER penalize_newline false
'''
gemma_chatml_eos_token = (
{"<start_of_turn>" : "<|im_start|>", "<eos>" : "<|im_end|>"},
"<|im_end|>",
)
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
pass
# =========================================== Gemma 2
# Same as Gemma 1, but with sliding window attention!
# https://ollama.com/library/gemma2/blobs/6522ca797f47
gemma2_template = gemma_template
gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
gemma2_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
# =========================================== Gemma 2 with ChatML instead
gemma2_chatml_template = gemma_chatml_template
gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
gemma2_chatml_eos_token = gemma_chatml_eos_token
CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
pass
# =========================================== Llama-3
# Note the '\n\n' after each header is required by the Llama-3 format.
llama3_template = \
"{{ bos_token }}"\
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
"{% else %}"\
"{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"\
"{% endif %}"
pass
# Ollama from https://www.ollama.com/library/llama3
llama3_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>"""
PARAMETER stop "<|start_header_id|>"
PARAMETER stop "<|end_header_id|>"
PARAMETER stop "<|eot_id|>"
'''
llama3_template_eos_token = "eos_token"
CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
pass
# =========================================== Phi-3
phi3_template = \
"{{ bos_token }}"\
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{'<|user|>\n' + message['content'] + '<|end|>\n'}}"\
"{% elif message['role'] == 'assistant' %}"\
"{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}"\
"{% else %}"\
"{{'<|' + message['role'] + '|>\n' + message['content'] + '<|end|>\n'}}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|assistant|>\n' }}"\
"{% endif %}"
pass
# Ollama from https://www.ollama.com/library/phi3
phi3_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|>
{{ end }}<|assistant|>
{{ .Response }}<|end|>
"""
PARAMETER stop "<|end|>"
PARAMETER stop "<|user|>"
PARAMETER stop "<|assistant|>"
'''
phi3_template_eos_token = "<|end|>"
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
pass
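# Every CHAT_TEMPLATES entry registered above is a 4-tuple of
# (jinja_template, eos_token_or_mapping, needs_eos_mapping, ollama_modelfile).
# A minimal lookup sketch (comment only, mirroring how get_chat_template below
# unpacks an entry):
#
#     chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES["chatml"]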
def get_chat_template(
tokenizer,
chat_template = "chatml",
mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
map_eos_token = True,
system_message = None,
):
assert(type(map_eos_token) is bool)
old_tokenizer = tokenizer
IS_GEMMA = False
if tokenizer.__class__.__name__.startswith("Gemma"):
if chat_template == "chatml": chat_template = "gemma_chatml"
IS_GEMMA = True
pass
# We add a check for Llama-3
# if chat_template == "llama-3":
# tokenizer._using_llama3_template = True
# else:
# llama3_tokens = set(["<|end_header_id|>", "<|eot_id|>", "<|start_header_id|>"])
# check_llama3_tokens = llama3_tokens & set(str(x) for x in tokenizer.added_tokens_decoder.values())
# if len(check_llama3_tokens) == len(llama3_tokens):
# tokenizer._using_llama3_template = True
# pass
# pass
# We first check if the tokenizer is a fast one. If not, we cannot convert this!
is_fast_tokenizer = getattr(tokenizer, "is_fast", False)
old_padding_side = tokenizer.padding_side
same_padding_token = False
if type(chat_template) in (list, tuple,):
chat_template, stop_word = chat_template
assert(type(chat_template) is str)
assert(type(stop_word) is str)
ollama_modelfile = None
elif type(chat_template) is str:
chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
# Check mapping to eos_token
if not map_eos_token and yes_map_eos_token: map_eos_token = True
if not yes_map_eos_token and map_eos_token: map_eos_token = False
if type(stop_word) in (list, tuple,):
token_mapping, stop_word = stop_word
assert(type(token_mapping) is dict)
else:
token_mapping = None
assert(type(stop_word) is str)
# Check fast tokenizer
if not is_fast_tokenizer:
print(
f"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
"Please log a Github issue if you want this as a new feature!\n"\
"Your chat template will still work, but it won't add or edit tokens."
)
elif token_mapping is not None:
# token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
# For Gemma :)
string_vocab = tokenizer._tokenizer.to_str()
skipped = 0
for old_token, new_token in token_mapping.items():
old_count = string_vocab.count(f'"{old_token}"')
new_count = string_vocab.count(f'"{new_token}"')
if new_count != 0:
print(f"{new_token} is already a token. Skipping.")
skipped += 1
elif old_count == 0:
raise RuntimeError(f"{old_token} was not part of the tokenizer!")
else:
string_vocab = string_vocab.replace(f'"{old_token}"', f'"{new_token}"')
pass
pass
if map_eos_token and (not stop_word in token_mapping.values()):
# Do not map 107 = <|im_end|> and 1 = <|im_end|>. This will reduce the vocab size by 1
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
pass
if skipped != len(token_mapping):
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
# Careful on pad_token
old_pad_token = tokenizer.pad_token
if old_pad_token == tokenizer.eos_token:
old_pad_token = stop_word
same_padding_token = True
pass
if map_eos_token:
new_tokenizer = tokenizer.__class__(
tokenizer_object = new_tokenizer,
eos_token = stop_word,
pad_token = old_pad_token,
)
else:
new_tokenizer = tokenizer.__class__(
tokenizer_object = new_tokenizer,
pad_token = old_pad_token,
)
pass
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
else:
pass
elif map_eos_token and (stop_word != "eos_token"):
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
# Replaces the old EOS token with a new one.
# Useful for ChatML <|im_end|> for example.
# Usually we train 2 more tokens <|im_start|> and <|im_end|>
# But training the lm_head and embeddings are slow!
# This is a HACK!
# Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
old_bos_token = getattr(tokenizer, "bos_token", None)
old_eos_token = getattr(tokenizer, "eos_token", None)
old_pad_token = getattr(tokenizer, "pad_token", None)
old_unk_token = getattr(tokenizer, "unk_token", None)
string_vocab = tokenizer._tokenizer.to_str()
# First check if new stop_word is in the tokenizer
if stop_word in string_vocab:
# We shall swap them around
temporary_stop_token = "<|:__TEMP//STOP//TOKEN__:|>"
string_vocab = string_vocab.replace(old_eos_token, temporary_stop_token)
string_vocab = string_vocab.replace(stop_word, old_eos_token)
string_vocab = string_vocab.replace(temporary_stop_token, stop_word)
else:
string_vocab = string_vocab.replace(old_eos_token, stop_word)
pass
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
# Careful on pad_token
if old_pad_token == old_eos_token:
old_pad_token = stop_word
same_padding_token = True
pass
new_tokenizer = tokenizer.__class__(
tokenizer_object = new_tokenizer,
bos_token = old_bos_token,
eos_token = stop_word,
unk_token = old_unk_token,
pad_token = old_pad_token,
)
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
token_mapping = { old_eos_token : stop_word, }
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
pass
else:
raise TypeError(
f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
f"{CHAT_TEMPLATES.keys()}"
)
pass
# For ShareGPT role -> from and content -> value
chat_template = chat_template\
.replace("'role'", "'" + mapping["role"] + "'")\
.replace("'content'", "'" + mapping["content"] + "'")\
.replace("'user'", "'" + mapping["user"] + "'")\
.replace("'assistant'", "'" + mapping["assistant"] + "'")
# Careful on Gemma
# bos_token is a must or else losses become too high
if IS_GEMMA and not chat_template.startswith("{{ bos_token }}"):
chat_template = "{{ bos_token }}" + chat_template
pass
_, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
tokenizer.padding_side = old_padding_side
tokenizer.chat_template = chat_template
# Also fix up other tokens
old_pad_token = getattr(old_tokenizer, "pad_token", None)
old_bos_token = getattr(old_tokenizer, "bos_token", None)
old_unk_token = getattr(old_tokenizer, "unk_token", None)
new_pad_token = getattr(tokenizer, "pad_token", None)
new_bos_token = getattr(tokenizer, "bos_token", None)
new_unk_token = getattr(tokenizer, "unk_token", None)
if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token
if not same_padding_token:
if old_pad_token != new_pad_token: tokenizer.pad_token = old_pad_token
pass
# stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
# Patch saving functions
tokenizer = patch_saving_functions(tokenizer)
# Add Ollama
tokenizer._ollama_modelfile = ollama_modelfile
tokenizer._system_message = system_message
return tokenizer#, stopping_criteria
pass
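# Illustrative usage of get_chat_template (kept as a comment so nothing runs on
# import; the model name is an assumption, not something this module requires):
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
#     tokenizer = get_chat_template(tokenizer, chat_template = "chatml")
#     text = tokenizer.apply_chat_template(
#         [{"role": "user", "content": "What is 2+2?"}],
#         tokenize = False, add_generation_prompt = True,
#     )
#
# Use the `mapping` argument to keep ShareGPT style keys such as "from"/"value".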
def remove_special_tokens(tokenizer, prompt):
    # Removes double BOS token
    if prompt.startswith(tokenizer.bos_token):
        prompt = prompt[len(tokenizer.bos_token):]
    pass
    return prompt
pass
def _parse_combined_prompt(combined_prompt, dataset):
# Find {...}
possible_columns = re.findall(r"\{(.+?)\}", combined_prompt)
dataset_columns = set(dataset.column_names)
for column in possible_columns:
if column not in dataset_columns:
raise KeyError(
f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. "\
f"Only allowed columns are {list(dataset_columns)}"
)
pass
pass
# Find [[...]]
optional_prompts = list(re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE))
optional_prompts = [(x.span(), x.group(0)) for x in optional_prompts]
final_optional_prompts = []
if len(optional_prompts) != 0:
# Add left
left = optional_prompts[0]
l = left[0][0]
if l != 0: final_optional_prompts.append(combined_prompt[:l])
# Add in between
for left, right in zip(optional_prompts[:-1], optional_prompts[1:]):
l, r = left[0][-1], right[0][0]
final_optional_prompts.append(left)
if l != r: final_optional_prompts.append(combined_prompt[l : r])
pass
final_optional_prompts.append(optional_prompts[-1])
# Add right
right = optional_prompts[-1]
r = right[0][1]
if r != len(combined_prompt): final_optional_prompts.append(combined_prompt[r:])
else:
# Just add in the entire string
final_optional_prompts.append(combined_prompt)
pass
check_combined = "".join(x if type(x) is str else x[1] for x in final_optional_prompts)
assert(combined_prompt == check_combined)
return possible_columns, final_optional_prompts
pass
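# The merged prompt syntax parsed above: {column} substitutes a dataset column,
# and a [[...]] section is kept only when its first {column} is non-empty.
# A hedged example (the column names are hypothetical):
#
#     merged_prompt = "{instruction}[[\nYour input is:\n{input}]]"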
def _create_formatter(possible_columns, final_optional_prompts, user_column_name):
# Start final prompt!
function = ["def __combined_prompt_processor__(examples):"]
columns = list(set(possible_columns))
for column in columns:
function.append(f"{' '*4}{column}__ = examples['{column}']")
function.append(f"{' '*4}texts = []")
function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):")
# Add optional tags as well!
final_prompt = ""
formatter = []
for j, optional_prompt in enumerate(final_optional_prompts):
if type(optional_prompt) is str:
columns = re.findall(r"\{(.+?)\}", optional_prompt)
formatter += columns
# Must escape \n \r
final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
else:
where, prompt = optional_prompt
# Strip [[...]]
# Must escape \n \r
prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
columns = re.findall(r"\{(.+?)\}", prompt)
x = f"__optional_{j}__"
prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
function.append(prompt)
formatter.append(x)
final_prompt += "{" + x + "}"
pass
pass
function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'")
function.append(f"{' '*8}texts.append("\
f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))")
function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
return "\n".join(function)
pass
def to_sharegpt(
dataset,
merged_prompt = "",
merged_column_name = "instruction",
output_column_name = "output",
remove_unused_columns = True,
conversation_extension = 1,
random_state = 3407,
):
"""
Converts a dataset to ShareGPT style.
ShareGPT requires only 1 input and 1 output field,
so multiple columns have to be merged into a single input field.
Use `conversation_extension` to increase the length of each conversation by randomly
selecting a few conversations and packing them into one.
merged_prompt = "", Prompt used to merge columns into 1 input
merged_column_name = "instruction", Final column name for the input field
output_column_name = "output", Final column name for the output field
remove_unused_columns = True,
conversation_extension = 1, Automatically combines `conversation_extension` convos into 1
random_state = 3407,
"""
if "conversations" in dataset.column_names:
convo = dataset[0]["conversations"]
if type(convo) is list:
raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!")
pass
pass
possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)
function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
exec(function, globals())
dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns")
def __convert_to_sharegpt__(examples):
users = examples[merged_column_name]
assistants = examples[output_column_name]
texts = [
[
{"from" : "user", "content" : str(user) },
{"from" : "assistant", "content" : str(assistant)},
] \
for user, assistant in zip(users, assistants)
]
return { "conversations" : texts, }
pass
dataset = dataset.map(
__convert_to_sharegpt__,
batched = True,
desc = "Converting to ShareGPT",
# Remove unused columns!
remove_columns = dataset.column_names if remove_unused_columns else None,
)
# Randomly concat conversations to create a long stream!
from datasets import concatenate_datasets
n_extensions = max(conversation_extension-1, 0)
if n_extensions == 0: return dataset
dataset = dataset.rename_columns({"conversations" : f"conversations0"})
all_shuffled = [dataset]
for j in range(1, n_extensions+1):
shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
all_shuffled.append(shuffled)
pass
dataset = concatenate_datasets(all_shuffled, axis = 1)
# Combine them into 1
function = "def __combine_conversations__(examples):\n"
n_extensions += 1
for j in range(n_extensions):
function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n"
function += f"{' '*4}convos = []\n"
function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\
f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
function += f"{' '*8}convos.append("\
f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
function += f"{' '*4}return " + "{ " + f"'conversations' : convos" + " }"
# Map function
exec(function, globals())
dataset = dataset.map(
__combine_conversations__,
batched = True,
desc = "Extending conversations",
# Remove unused columns!
remove_columns = dataset.column_names if remove_unused_columns else None,
)
return dataset
pass
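# Illustrative to_sharegpt sketch (comment only; the dataset name and its
# columns are assumptions used for demonstration):
#
#     from datasets import load_dataset
#     dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
#     dataset = to_sharegpt(
#         dataset,
#         merged_prompt = "{instruction}[[\nYour input is:\n{input}]]",
#         output_column_name = "output",
#         conversation_extension = 3,
#     )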
def standardize_sharegpt(
dataset,
aliases_for_system = ["system",],
aliases_for_user = ["user", "human", "input",],
aliases_for_assistant = ["gpt", "assistant", "output",],
):
"""
Standardizes ShareGPT and other formats into the Hugging Face user/assistant format.
Provide aliases for the system, user and assistant roles;
these map to "system", "user" and "assistant" respectively.
aliases_for_system = ["system",],
aliases_for_user = ["user", "human", "input",],
aliases_for_assistant = ["gpt", "assistant", "output",],
"""
import collections
import itertools
convos = dataset[:10]["conversations"]
uniques = collections.defaultdict(list)
for convo in convos:
for message in convo:
for key, value in message.items():
uniques[key].append(value)
pass
# Must be only 2 entries
assert(len(uniques.keys()) == 2)
keys = list(uniques.keys())
length_first = len(set(uniques[keys[0]]))
length_second = len(set(uniques[keys[1]]))
if length_first < length_second:
# Role is assigned to the first element
role_key = keys[0]
content_key = keys[1]
else:
role_key = keys[1]
content_key = keys[0]
pass
# Check roles are in aliases
all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant)
roles = set(uniques[role_key])
leftover_aliases = (all_aliases | roles) - all_aliases
if len(leftover_aliases) != 0:
raise TypeError(
f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases."
)
pass
# Mapping for aliases
aliases_mapping = {}
for x in aliases_for_system: aliases_mapping[x] = "system"
for x in aliases_for_user: aliases_mapping[x] = "user"
for x in aliases_for_assistant: aliases_mapping[x] = "assistant"
def _standardize_dataset(examples):
convos = examples["conversations"]
all_convos = []
for convo in convos:
new_convo = [
{ "role" : aliases_mapping[message[role_key]], "content" : message[content_key], }
for message in convo
]
all_convos.append(new_convo)
pass
return { "conversations" : all_convos, }
pass
return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format")
pass
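# Illustrative standardize_sharegpt sketch (comment only): ShareGPT style
# messages such as {"from": "human", "value": "Hi"} become
# {"role": "user", "content": "Hi"} using the default aliases above.
#
#     dataset = standardize_sharegpt(dataset)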
def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
added_tokens_decoder = tokenizer.added_tokens_decoder.values()
added_tokens_decoder = [str(x) for x in added_tokens_decoder]
# Remove added_tokens_decoder duplicates
added_tokens_decoder = list(set(added_tokens_decoder) - set(extra_eos_tokens))
# Remove BOS
if getattr(tokenizer, "bos_token", None) is not None:
added_tokens_decoder = [x for x in added_tokens_decoder if x != tokenizer.bos_token]
pass
repeatted_tokens = []
# Join all vocab
joined_text = "\x01\x00".join(added_tokens_decoder)
for token in added_tokens_decoder:
n = len(token)
repeatted_counts = joined_text.count(token[:n//2])
# Try finding longer than 1/2 of the token in the rest
# For eg <|reserved_special_token_0|>, <|reserved_special_token_1|>
if repeatted_counts > 2:
for j in range(n//2+1, n):
if joined_text.count(token[:j]) < repeatted_counts:
j -= 1
# Remove repeated tokens to reduce search space
joined_text = joined_text.replace(token[:j], "")
repeatted_tokens.append(token[:j])
break
pass
pass
pass
# Remove duplicates
splitted = joined_text.split("\x01\x00")
final_eos_tokens = []
for old, new in zip(added_tokens_decoder, splitted):
if old == new: final_eos_tokens.append(old)
pass
final_eos_tokens += extra_eos_tokens
final_eos_tokens += repeatted_tokens
# Remove new lines, spaces and HTML tags
filtered_eos_tokens = []
for token in final_eos_tokens:
if token.count("\n") == len(token): continue
elif token.count("▁") == len(token): continue
elif token.startswith("<") and len(token) <= 2: continue
elif token.startswith("</") and len(token) == 3: continue
filtered_eos_tokens.append(token)
pass
return filtered_eos_tokens
pass
def construct_chat_template( \
tokenizer = None,
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|>""",
default_system_message = \
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
extra_eos_tokens = None,
):
"""
Creates an Ollama modelfile and an HF Jinja template from a custom
template. You must provide two examples of an input & output.
There is an optional system message as well.
You must use {INPUT} and {OUTPUT} twice, and {SYSTEM} is optional.
"""
# Strip only the left
chat_template = chat_template.lstrip()
assert(tokenizer is not None)
if extra_eos_tokens is None: extra_eos_tokens = []
elif type(extra_eos_tokens) is str: extra_eos_tokens = [extra_eos_tokens,]
vocab = tokenizer.get_vocab()
for extra_eos in extra_eos_tokens:
assert(type(extra_eos) is str)
if extra_eos not in vocab:
raise ValueError(f"Unsloth: `{extra_eos}` is not a singular token in the tokenizer.")
pass
pass
error_msg = \
"Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
"and the assistant output {OUTPUT}\n\n"\
"For example what is not allowed is just:\n"\
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"\
"What is required is 2x of this:\n"\
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
# Check for EOS after {OUTPUT}
if tokenizer.eos_token is not None:
extra_eos_tokens.insert(0, tokenizer.eos_token)
if len(extra_eos_tokens) == 0:
raise RuntimeError(
"Unsloth: Your tokenizer does not have an EOS token? Please provide one via extra_eos_tokens!"
)
pass
# Check tokenizer types
tokenizer_name = tokenizer.name_or_path.lower()
if tokenizer_name.startswith(("unsloth/llama-3-8b-instruct", "unsloth/llama-3-70b-instruct")):
# Add <|eot_id|>
extra_eos_tokens.append("<|eot_id|>")
elif ("<|eot_id|>" in extra_eos_tokens or "<|eot_id|>" in chat_template) and \
tokenizer_name.startswith(("unsloth/llama-3-8b", "unsloth/llama-3-70b")):
# Warn
logger.warning(
"Unsloth: Base llama-3 models did not train <|eot_id|>.\n"\
"Please use the instruct version or use <|end_of_text|>"
)
pass
extra_eos_tokens = list(set(extra_eos_tokens))
count_eos = 0
for eos in extra_eos_tokens:
count_eos += len(re.findall(r"{OUTPUT}" + re.escape(eos), chat_template))
pass
# This forces you to provide 2 input and outputs
final_combined_check = False
try:
# O(N^2) search finding 2 repeated pieces of text
j = len(chat_template)-1
at_least_one = False
while j > 0:
found = chat_template.rfind(chat_template[j:], 0, j)
if found == -1: break
j -= 1
at_least_one = True
pass
if j > 0: j += 1
else: raise RuntimeError(error_msg)
if not at_least_one: raise RuntimeError(error_msg)
# Must be equivalent to left
final_combined_check = True
# Repeated text
instruction_response = chat_template[j:]
if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
raise RuntimeError(error_msg)
pass
# 1st System, Instruction, Output pair
left = chat_template[:j]
# 2nd Instruction, Output pair
right = chat_template[j:]
final_combined_check = left if final_combined_check else chat_template
# Isolate input
extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens)
if len(extra_eos_tokens_regex) != 0:
find_end = f"(?:{extra_eos_tokens_regex})?"
else:
find_end = ""
find_end = r"\{INPUT\}[\s\n]{0,}" + find_end
input_end = list(re.finditer(find_end, right))
assert(len(input_end) == 1)
input_end = input_end[0]
input_end = input_end.span(0)[1]
input_part = right[:input_end]
# Isolate output
output_part = right[input_end:]
# Isolate system
where_system = left.find(input_part)
system_part = left[:where_system if where_system != -1 else len(left)]
# Check if the user provided a correct prompt
combined = system_part + input_part + output_part
if combined != final_combined_check:
combined_changed = combined .replace('\n', '\\n')
left_changed = final_combined_check.replace('\n', '\\n')
raise RuntimeError(
"Unsloth: The prompt template you provided isn't correct. You gave:\n"\
f"{combined_changed}\n\n"\
"But we require the following:\n"\
f"{left_changed}"
)
pass
except:
ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}"):]
ending = re.escape(ending)
find_text = "{INPUT}" + ending + "(.+?{OUTPUT}" + ending + ")"
response_part = re.findall(find_text, chat_template, flags = re.DOTALL | re.MULTILINE)
response_part = response_part[0]
for j in range(1, len(response_part)):
try_find = re.escape(response_part[:j])
try: found = next(re.finditer("(" + try_find + ").+?\{INPUT\}", chat_template, flags = re.DOTALL | re.MULTILINE))
except: break
pass
separator = found.group(1)
response_start = chat_template.find(response_part)
start_instruction = chat_template[:response_start].rfind(separator)
if start_instruction == -1: start_instruction = 0
instruction_part = chat_template[start_instruction:response_start]
combined = instruction_part + response_part
where = chat_template.find(combined)
system_part = chat_template[:where]
system_part, input_part, output_part = system_part, instruction_part, response_part
pass
if count_eos == 0:
logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
eos = extra_eos_tokens[0]
output_part = output_part + eos
pass
# Ollama modelfile parts
# Check bos_token is in system prompt
ollama_system = system_part
has_bos_token = False
always_bos_token = False
if tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None):
always_bos_token = True
if ollama_system.startswith(tokenizer.bos_token):
has_bos_token = True
ollama_system = ollama_system[len(tokenizer.bos_token):]
pass
pass
# Check system
if "{SYSTEM}" in ollama_system:
system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}"
else:
system_modelfile = ollama_system
pass
input_modelfile = "{{ if .Prompt }}" + input_part .replace("{INPUT}", "{{ .Prompt }}") + "{{ end }}"
output_modelfile = output_part.replace("{OUTPUT}", "{{ .Response }}")
# Ollama EOS
ollama_eos = get_ollama_eos_tokens(tokenizer, extra_eos_tokens)
ollama_eos = '\n'.join(f'PARAMETER stop "{eos}"' for eos in ollama_eos)
# Ollama modelfile
modelfile = 'FROM {__FILE_LOCATION__}\n\n'\
'TEMPLATE """' + system_modelfile + input_modelfile + output_modelfile + \
'"""\n\n' + ollama_eos
# HF Jinja Chat template
def process(part, which, content = "message['content']"):
if part.endswith(which):
part = "'" + part[:part.find(which)] + f"' + {content}"
elif part.startswith(which):
part = f"{content} + '" + part[part.find(which):] + "'"
else:
part = "'" + part.replace(which, f"' + {content} + '") + "'"
if part.startswith("'' + "): part = part[5:]
return part
pass
input_jinja = process(input_part, "{INPUT}")
output_jinja = process(output_part, "{OUTPUT}")
pass
jinja_template = \
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ " + input_jinja + " }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ " + output_jinja + " }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '" + output_part[:output_part.find("{OUTPUT}")] + "' }}"\
"{% endif %}"
pass
# Now add system prompt to jinja
if len(system_part) != 0:
partial_system = process(system_part, "{SYSTEM}", "messages[0]['content']")
partial_system = partial_system.replace("{SYSTEM}", "")
if "{SYSTEM}" in partial_system:
if default_system_message is None:
raise RuntimeError("Unsloth: Please specify a default system message!")
pass
# Separate the BOS
if has_bos_token:
partial_system = partial_system.replace(tokenizer.bos_token, "", 1)
system_part = system_part .replace(tokenizer.bos_token, "", 1)
pass
partial_system = \
"{% if messages[0]['role'] == 'system' %}"\
"{{ " + partial_system + " }}"\
"{% set loop_messages = messages[1:] %}"
if default_system_message is not None:
full_system = system_part.replace("{SYSTEM}", default_system_message)
if "{SYSTEM}" in system_part:
modelfile += '\nSYSTEM "' + default_system_message + '"'
pass
partial_system += "{% else %}"\
"{{ '" + full_system + "' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"
else:
partial_system += "{% endif %}"
pass
jinja_template = partial_system + jinja_template
if has_bos_token:
jinja_template = "{{ bos_token }}" + jinja_template
pass
# Fix missing loop_messages
if "{% set loop_messages = messages %}" not in jinja_template:
jinja_template = jinja_template.replace(
"{% for message in loop_messages %}",
"{% for message in messages %}",
1, # Only replace the first one
)
pass
# Check if system part is the same!
jinja_template = re.sub(
r"\{\% if messages\[0\]\['role'\] \=\= 'system' \%\}\{\{ '(.+?)' \}\}"\
r"\{\% set loop\_messages \= messages\[1\:\] \%\}"\
r"\{\% else \%\}\{\{ '\1' \}\}\{\% set loop\_messages \= messages \%\}\{\% endif \%\}"\
r"\{\% for message in loop\_messages \%\}",
r"{{ '\1' }}{% for message in messages %}",
jinja_template, flags = re.MULTILINE | re.DOTALL,
)
# Check the Jinja template for bos
if always_bos_token:
if not jinja_template.startswith("{{ bos_token }}"):
jinja_template = "{{ bos_token }}" + jinja_template
pass
# Get instruction and output parts for train_on_inputs = False
input_part = input_part [:input_part .find("{INPUT}")]
output_part = output_part[:output_part.find("{OUTPUT}")]
return modelfile, jinja_template, input_part, output_part
pass
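# Illustrative construct_chat_template sketch (comment only; the tokenizer is an
# assumption). The default llama-3 style template in the signature above already
# satisfies the 2x {INPUT}/{OUTPUT} requirement:
#
#     modelfile, jinja_template, input_part, output_part = construct_chat_template(
#         tokenizer = tokenizer,
#         default_system_message = "You are a helpful assistant",
#     )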
def test_construct_chat_template():
token = "hf_"
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token = token)
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|>"""
default_system_message = \
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
extra_eos_tokens = None
modelfile, jinja_template, _, _ = construct_chat_template(
tokenizer = tokenizer,
chat_template = chat_template,
extra_eos_tokens = extra_eos_tokens,
)
messages = [
{"role": "system", "content": "You are an assistant"},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "It's 4."},
{"role": "user", "content": "Ok!"},
{"role": "assistant", "content": "Anything else?"},
{"role": "user", "content": "What's 2x2?"},
]
correct_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
tokenizer.chat_template = jinja_template
new_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_output == new_output)
pass
pass
def apply_chat_template( \
dataset,
tokenizer = None,
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|>""",
default_system_message = \
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
extra_eos_tokens = None,
):
"""
Creates an Ollama modelfile and an HF Jinja template from a custom
template. You must provide two examples of an input & output.
There is an optional system message as well.
You must use {INPUT} and {OUTPUT} twice, and {SYSTEM} is optional.
"""
modelfile, jinja_template, input_part, output_part = construct_chat_template(
tokenizer = tokenizer,
chat_template = chat_template,
default_system_message = default_system_message,
extra_eos_tokens = extra_eos_tokens,
)
def formatting_prompts_func(examples):
convos = examples["conversations"]
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
return { "text" : texts, }
pass
tokenizer.chat_template = jinja_template
tokenizer._ollama_modelfile = modelfile
tokenizer._unsloth_input_part = input_part
tokenizer._unsloth_output_part = output_part
return dataset.map(formatting_prompts_func, batched = True,)
pass
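# Illustrative apply_chat_template sketch (comment only): expects a
# "conversations" column of {"role", "content"} messages (for example from
# standardize_sharegpt above) and adds a rendered "text" column:
#
#     dataset = apply_chat_template(
#         dataset,
#         tokenizer = tokenizer,
#         chat_template = chat_template,  # a template with {INPUT}/{OUTPUT} twice
#     )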
def train_on_responses_only(
trainer,
instruction_part = None,
response_part = None,
):
"""
Trains only on responses and not on the instruction by masking out
the labels with -100 for the instruction part.
"""
tokenizer = trainer.tokenizer
if not hasattr(tokenizer, "_unsloth_input_part") or \
not hasattr(tokenizer, "_unsloth_output_part"):
if instruction_part is None or response_part is None:
raise ValueError("Unsloth: instruction_part and response_part must be given!")
pass
elif (instruction_part is not None or response_part is not None) and \
(hasattr(tokenizer, "_unsloth_input_part") or hasattr(tokenizer, "_unsloth_output_part")):
raise ValueError("Unsloth: Your tokenizer already has instruction and response parts set - do not give custom ones!")
else:
instruction_part = tokenizer._unsloth_input_part
response_part = tokenizer._unsloth_output_part
pass
instruction_ids = tokenizer(instruction_part, add_special_tokens = False).input_ids
response_ids = tokenizer(response_part, add_special_tokens = False).input_ids
instruction_length = len(instruction_ids)
response_length = len(response_ids)
max_length = max(instruction_length, response_length)
def _train_on_responses_only(examples):
input_ids_ = examples["input_ids"]
all_labels = []
for input_ids in input_ids_:
labels = [-100] * len(input_ids)
m = len(input_ids) - max_length
first_response = response_ids[0]
first_instruction = instruction_ids[0]
j = 0
while j < m:
if input_ids[j] == first_response:
if input_ids[j : j+response_length] == response_ids:
j = j + response_length
start = j
while j < m:
if input_ids[j] == first_instruction and input_ids[j : j+instruction_length] == instruction_ids:
j = j + instruction_length
labels[start : j] = input_ids[start : j]
break
elif j == (m-1):
j = m
labels[start:] = input_ids[start:]
break
pass
j += 1
pass
pass
pass
j += 1
pass
all_labels.append(labels)
pass
return { "labels" : all_labels }
pass
trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True)
return trainer
pass
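# Illustrative train_on_responses_only sketch (comment only; the instruction and
# response markers below are the llama-3 headers and are assumptions for any
# other template):
#
#     trainer = train_on_responses_only(
#         trainer,
#         instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
#         response_part    = "<|start_header_id|>assistant<|end_header_id|>\n\n",
#     )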
def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
    class StoppingCriteriaSub(StoppingCriteria):
        __slots__ = "stop_token", "single_match", "length",

        def __init__(self, stops = "eos_token", device = "cuda", encounters = 1):
            super().__init__()
            if stops == "eos_token":
                self.stop_token = torch.tensor(tokenizer.eos_token_id, device = "cuda")
                self.length = 1
            else:
                self.stop_token = tokenizer(["\n" + stops], add_special_tokens = False, return_tensors = "pt")
                self.stop_token = self.stop_token.input_ids.ravel()[1:].to("cuda")
                self.length = self.stop_token.shape[0]
            pass
            self.single_match = self.length == 1
        pass

        def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
            input_ids = input_ids.ravel()
            last_token = input_ids[-1]
            if self.single_match and (last_token == self.stop_token): return True
            if input_ids.shape[0] >= self.length and \
                (input_ids[-self.length:] == self.stop_token).all(): return True
            return False
        pass
    pass
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])
    return stopping_criteria
pass
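# Illustrative create_stopping_criteria sketch (comment only; the generation
# call and its arguments are assumptions):
#
#     stopping_criteria = create_stopping_criteria(tokenizer, stop_word = "<|im_end|>")
#     outputs = model.generate(**inputs, stopping_criteria = stopping_criteria)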
def test_chat_templates():
messages = [
{"role": "system","content": " You are a friendly chatbot.",},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "It's 4."},
{"role": "user", "content": " But 2+2 is equal to 5. "},
{"role": "assistant", "content": "No I'm sure its 4."},
{"role": "user", "content": " No it's 100% 5! "},
]
# Zephyr
from transformers import AutoTokenizer
template = zephyr_template
correct_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Chatml
template = chatml_template
correct_tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Mistral
template = mistral_template
correct_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Llama
template = llama_template
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Vicuna
try:
from fastchat.conversation import get_conv_template
except:
os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
from fastchat.conversation import get_conv_template
correct_prompt = get_conv_template("vicuna_v1.1")
for j in range(len(messages)-1):
correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
correct_prompt.append_message(correct_prompt.roles[1], "")
correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()
template = vicuna_template
correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
try:
from fastchat.conversation import get_conv_template
except:
os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
from fastchat.conversation import get_conv_template
correct_prompt = get_conv_template("zero_shot")
for j in range(len(messages)-1):
correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
correct_prompt.append_message(correct_prompt.roles[1], "")
correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()
template = vicuna_old_template
correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
# We add </s> ourselves
assert(correct_prompt == our_prompt.replace("</s>", ""))
# Gemma
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-7b-it")
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = gemma_template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(our_prompt == correct_prompt)
# Llama-3
template = llama3_template
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
# Phi-3
template = phi3_template
correct_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
pass
def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf"):
"""
Carefully compares GGUF's tokenization output against Hugging Face's.
Useful for catching tokenization bugs.
"""
import subprocess
import re
messages = [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "It's 4."},
{"role": "user", "content": " But 2+2 is equal to 5. "},
{"role": "assistant", "content": "No I'm sure its 4."},
{"role": "user", "content": " No it's 100% 5! "},
]
prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}""".format(
"Describe the city given eloquently.", # instruction
"The lost city of Atlantis.", # input
"", # output - leave this blank for generation!
)
prompts = [ prompt, ]
if tokenizer.chat_template is not None:
prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
prompt = prompt.replace("'", "") # Subprocess does not like ''
prompt = remove_special_tokens(tokenizer, prompt)
prompts.append(prompt)
pass
for prompt in prompts:
command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
f"--check-tensors -p '{prompt}'"
datas = []
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
for line in sp.stdout:
datas.append(line.decode("utf-8", errors = "replace"))
pass
gguf_tokens = "".join(datas)
# Now extract GGUF tokenization attempt
gguf_tokenized = re.findall(r"([\d]{1,}) \-\> \'([^\']{1,})\'", gguf_tokens, flags = re.MULTILINE)
gguf_tokenized = [(int(x[0]), x[1],) for x in gguf_tokenized]
input_ids = tokenizer(prompt).input_ids
tokens = tokenizer.batch_decode(input_ids)
hf_tokenized = list(zip(input_ids, tokens))
# Compare to Huggingface
for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):
if (hf_token[0] != gguf_token[0]):
print("Failed GGUF != HF at", j)
print("HF =", hf_token)
print("GGUF =", gguf_token)
print(hf_tokenized)
print()
print(gguf_tokenized)
print()
raise RuntimeError("Failed comparing GGUF to HF.")
pass
pass
return True
pass
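# Minimal usage sketch (the path is the default shown above): run this after exporting
# a GGUF file with llama.cpp so both tokenizers see identical prompts.
# >>> test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
# True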
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cross_entropy_loss import fast_cross_entropy_loss
from .rms_layernorm import fast_rms_layernorm
from .rope_embedding import fast_rope_embedding, inplace_rope_embedding
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
from .geglu import (
geglu_exact_forward_kernel,
geglu_exact_backward_kernel,
geglu_approx_forward_kernel,
geglu_approx_backward_kernel,
)
from .fast_lora import (
get_lora_parameters,
get_lora_parameters_bias,
apply_lora_mlp_swiglu,
apply_lora_mlp_geglu_exact,
apply_lora_mlp_geglu_approx,
apply_lora_qkv,
apply_lora_o,
)
from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora
from .flex_attention import HAS_FLEX_ATTENTION, slow_attention_softcapping
if HAS_FLEX_ATTENTION:
from .flex_attention import (
FLEX_ATTENTION_PADDING,
)
pass
try:
print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.")
except:
print("Unsloth: Will patch your computer to enable 2x faster free finetuning.")
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh
from transformers.models.llama.modeling_llama import logger
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
Pi = exp(xi) / sum(exp(xi))
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
= -y [ x - log[sum(exp(x))] ]
= y * (log[sum(exp(x))] - x)
If y == 0: CE_i = 0
If y == 1: CE_i = logsumexp - x
logsumexp is also stable
Take y = log[sum(exp(x))]
exp(y) = sum(exp(x))
exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
exp(y) = exp(c)*sum(exp(x - c))
y = log(exp(c)*sum(exp(x - c)))
y = c + log[sum(exp(x - c))]
This means we can set c = max(x) to make sure
exp(x - c) always is exp(x - max(x)).
This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
"""
row_idx = tl.program_id(0)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
loss_ptr += row_idx
logsumexp_ptr += row_idx
labels_ptr += row_idx
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
if label_idx != -100:
x = tl.load(logits_ptr + label_idx)
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = logsumexp - x.to(tl.float32)
else:
loss = 0.0
tl.store(logsumexp_ptr, logsumexp)
tl.store(loss_ptr, loss)
pass
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
256K vocab divided in 4 chunks
|-65536-| |-65536-| |-65536-| |-65536-|
|-------| |-------| |-------| |-------|
|-------| |-------| |-------| |-------|
If y == 0: CE_i = 0
If y == 1: CE_i = logsumexp - x
Notice we can compute logsumexp per chunk, then take a final logsumexp
over the chunked results and recover the full logsumexp:
logsumexp(x) = log[sum(exp(x))]
             = log[sum(exp(a)) + ... + sum(exp(z))]    (x split into chunks a ... z)
             = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
             = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
             = logsumexp([logsumexp(a), ..., logsumexp(z)])
This means we can perform a logsumexp for each chunk, then do a
final logsumexp reduction!
Ie do: logsumexp(chunked_logsumexp) - x
"""
row_idx = tl.program_id(0)
chunk_idx = tl.program_id(1)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
loss_ptr += row_idx
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
labels_ptr += row_idx
col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
if chunk_idx == 0:
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = -1.0 * x.to(tl.float32)
else:
loss = 0.0
tl.store(loss_ptr, loss)
pass
tl.store(logsumexp_ptr, logsumexp)
pass
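# Quick numerical check of the chunking identity used above (pure PyTorch,
# illustrative only - the kernel never materialises these intermediate tensors):
# >>> x = torch.randn(262144)              # one "256K vocab" row, 4 chunks of 65536
# >>> per_chunk = torch.logsumexp(x.view(4, 65536), dim = 1)
# >>> torch.allclose(torch.logsumexp(x, dim = 0), torch.logsumexp(per_chunk, dim = 0))
# True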
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_backward(
logits_ptr, logits_row_stride,
dloss_ptr, dloss_row_stride,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
From https://en.wikipedia.org/wiki/LogSumExp
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
If y == 0: dC/dx = 0
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
"""
row_idx = tl.program_id(0)
block_idx = tl.program_id(1)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
dloss_ptr += row_idx * dloss_row_stride
col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
if label_idx != -100:
dloss = tl.load(dloss_ptr)
else:
dloss = 0.0
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
partial = triton_tanh(x / SOFTCAP)
x = SOFTCAP * partial
pass
logsumexp = tl.load(logsumexp_ptr + row_idx)
y = tl.exp(x.to(tl.float32) - logsumexp)
y = tl.where(
col_offsets == label_idx,
y - 1.0, # exp(x - logsumexp) - 1
y, # exp(x - logsumexp)
)
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
y = y * (1.0 - partial*partial)
pass
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
pass
MAX_FUSED_SIZE = 65536 # 2**16
class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, logit_softcapping = 0):
n_rows, vocab_size = logits.shape
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks = div + (mod != 0)
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
DO_SOFTCAPPING = (logit_softcapping != 0)
if n_chunks == 1:
# For small vocabs <= 65536 like Llama, Mistral
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
_cross_entropy_forward[(n_rows,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = num_warps,
)
else:
# For large vocabs > 65536 like Gemma's 256K vocab
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")
_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = 32,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
losses += logsumexp
losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
pass
ctx.save_for_backward(logits, logsumexp, labels)
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
ctx.logit_softcapping = logit_softcapping
return losses
pass
@staticmethod
def backward(ctx, dlosses):
logits, logsumexp, labels = ctx.saved_tensors
n_rows, vocab_size = logits.shape
BLOCK_SIZE = 4096
div, mod = divmod(vocab_size, BLOCK_SIZE)
n_blocks = div + (mod != 0)
_cross_entropy_backward[(n_rows, n_blocks,)](
logits, logits.stride(0),
dlosses, dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
num_warps = 8,
)
return logits, None, None,
pass
pass
def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
labels: (batch, seq_len,)
Returns:
losses: float
"""
batch, seq_len, d = logits.shape
assert(labels.shape == (batch, seq_len))
loss = Fast_CrossEntropyLoss.apply(
logits.view(batch*seq_len, d),
labels.view(-1),
logit_softcapping,
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
pass
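# Minimal usage sketch (shapes are illustrative). With softcapping disabled this
# should match torch's mean cross entropy over the non-masked (label != -100) tokens:
# >>> logits = torch.randn(2, 8, 32000, device = "cuda", dtype = torch.float32)
# >>> labels = torch.randint(0, 32000, (2, 8), device = "cuda")
# >>> labels[:, :4] = -100                 # mask out the prompt portion
# >>> loss = fast_cross_entropy_loss(logits, labels)
# >>> ref  = torch.nn.functional.cross_entropy(logits.view(-1, 32000), labels.view(-1), ignore_index = -100)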
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .utils import (
fast_dequantize,
QUANT_STATE,
get_lora_parameters,
get_lora_parameters_bias,
matmul_lora,
torch_amp_custom_fwd,
torch_amp_custom_bwd,
)
class LoRA_MLP(torch.autograd.Function):
"""
### LoRA weights
G = G + Ag @ Bg
U = U + Au @ Bu
W = W + Aw @ Bw
### SwiGLU(X)
e = X @ G
f = e * sigmoid(e)
g = X @ U
h = f * g
i = h @ W
### Backpropagation chain rule
See our blog post for more details
df = sigmoid(e) * (1 - f) + f
dC/dW = h.T @ dY
dC/dU = X.T @ (D @ W.T * f)
dC/dG = X.T @ (D @ W.T * df * g)
### Down projection LoRA weights
dC/dAw = dC/dW @ B.T
dC/dBw = A.T @ dC/dW
dC/dAw = h.T @ dY @ B.T
dC/dBw = A.T @ h.T @ dY
### Up projection LoRA weights
dC/dAu = X.T @ (D @ W.T * f) @ B.T
dC/dBu = A.T @ X.T @ (D @ W.T * f)
### Gate projection LoRA weights
dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
Don't forget to see our blog post for more details!
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, X : torch.Tensor,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
_forward_function, _backward_function,):
dtype = X.dtype
e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
h = _forward_function(e, g)
i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
ctx.custom_saved_tensors = (
gateW, gateW_quant, gateS,
upW, upW_quant, upS,
downW, downW_quant, downS,
_backward_function,
)
ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
X, e, g)
return i
pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dY : torch.Tensor):
gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
_backward_function = ctx.custom_saved_tensors
gateA, gateB, upA, upB, downA, downB, \
X, e, g = ctx.saved_tensors
gateA, gateB, upA, upB, downA, downB = \
gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
batch, seq_len, hd = X.shape
dY = dY.view(-1, dY.shape[-1])
X = X .view(-1, X .shape[-1])
e = e .view(-1, e .shape[-1])
g = g .view(-1, g .shape[-1])
dtype = X.dtype
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
DW, e, g = _backward_function(DW, e, g)
h, df, de = DW, e, g
# Down projection LoRA weights
d_downA = h.t() @ (dY @ downB.t())
d_downB = (downA.t() @ h.t()) @ dY
d_downA *= downS
d_downB *= downS
# Up projection LoRA weights
d_upA = X.t() @ (df @ upB.t())
d_upB = (upA.t() @ X.t()) @ df
d_upA *= upS
d_upB *= upS
# Gate projection LoRA weights
d_gateA = X.t() @ (de @ gateB.t())
d_gateB = (gateA.t() @ X.t()) @ de
d_gateA *= gateS
d_gateB *= gateS
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X)
del upW
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
gateW = fast_dequantize(gateW.t(), gateW_quant)
dX += de @ gateW.t()
del gateW
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
# gateW, gateW_quant, gateA, gateB, gateS,
# upW, upW_quant, upA, upB, upS,
# downW, downW_quant, downA, downB, downS,
return dX.view(batch, seq_len, hd), \
None, None, d_gateA.t(), d_gateB.t(), None, \
None, None, d_upA.t(), d_upB.t(), None, \
None, None, d_downA.t(), d_downB.t(), None, \
None, None, # _backward and _forward
pass
pass
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
def apply_lora_mlp_swiglu(self, X):
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,)
return out
pass
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
def apply_lora_mlp_geglu_exact(self, X):
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
geglu_exact_forward_kernel, geglu_exact_backward_kernel,)
return out
pass
from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
def apply_lora_mlp_geglu_approx(self, X):
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
gateW, gateW_quant, gateA, gateB, gateS,
upW, upW_quant, upA, upB, upS,
downW, downW_quant, downA, downB, downS,
geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
return out
pass
class LoRA_QKV(torch.autograd.Function):
"""
### LoRA weights
Wq = Wq + Aq @ Bq
Wk = Wk + Ak @ Bk
Wv = Wv + Av @ Bv
Q = X @ Wq = X @ Wq + X @ Aq @ Bq
K = X @ Wk = X @ Wk + X @ Ak @ Bk
V = X @ Wv = X @ Wv + X @ Av @ Bv
### Backpropagation chain rule
See our blogpost for more details.
dC/dWq = X.T @ D(Wq)
dC/dWk = X.T @ D(Wk)
dC/dWv = X.T @ D(Wv)
We then sum them all to find dC/dX.
### Q projection LoRA weights
dC/dAq = X.T @ D(Wq) @ B.T
dC/dBq = A.T @ X.T @ D(Wq)
### K projection LoRA weights
dC/dAk = X.T @ D(Wk) @ B.T
dC/dBk = A.T @ X.T @ D(Wk)
### V projection LoRA weights
dC/dAv = X.T @ D(Wv) @ B.T
dC/dBv = A.T @ X.T @ D(Wv)
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, X : torch.Tensor,
QW, QW_quant, QA, QB, QS,
KW, KW_quant, KA, KB, KS,
VW, VW_quant, VA, VB, VS,):
dtype = X.dtype
Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
ctx.custom_saved_tensors = (
QW, QW_quant, QS,
KW, KW_quant, KS,
VW, VW_quant, VS,
)
ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
return Q, K, V
pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dQ, dK, dV):
QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
ctx.custom_saved_tensors
X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
QA, QB, KA, KB, VA, VB = \
QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
batch, seq_len, hd = X.shape
dQ = dQ.view(-1, dQ.shape[-1])
dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
dV = dV.view(-1, dV.shape[-1])
X = X .view(-1, X .shape[-1])
dtype = X.dtype
### Weight projection LoRA weights
# See our blogpost for more details.
# Q Projection
d_QA = X.t() @ (dQ @ QB.t())
d_QB = (QA.t() @ X.t()) @ dQ
d_QA *= QS
d_QB *= QS
# K Projection
d_KA = X.t() @ (dK @ KB.t())
d_KB = (KA.t() @ X.t()) @ dK
d_KA *= KS
d_KB *= KS
# V Projection
d_VA = X.t() @ (dV @ VB.t())
d_VB = (VA.t() @ X.t()) @ dV
d_VA *= VS
d_VB *= VS
# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X)
del QW
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
# dK
KW = fast_dequantize(KW.t(), KW_quant)
dX += dK @ KW.t()
del KW
dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
# dV
VW = fast_dequantize(VW.t(), VW_quant)
dX += dV @ VW.t()
del VW
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
# QW, QW_quant, QA, QB, QS,
# KW, KW_quant, KA, KB, KS,
# VW, VW_quant, VA, VB, VS,
return dX.view(batch, seq_len, hd), \
None, None, d_QA.t(), d_QB.t(), None, \
None, None, d_KA.t(), d_KB.t(), None, \
None, None, d_VA.t(), d_VB.t(), None
pass
pass
def apply_lora_qkv(self, X):
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(X,
QW, QW_quant, QA, QB, QS,
KW, KW_quant, KA, KB, KS,
VW, VW_quant, VA, VB, VS,
)
return Q, K, V
pass
class LoRA_W(torch.autograd.Function):
"""
### LoRA weights (single projection)
W = W + A @ B
Y = X @ W = X @ W + X @ A @ B
### Backpropagation chain rule
dC/dW = X.T @ dY
### Projection LoRA weights
dC/dA = X.T @ dY @ B.T
dC/dB = A.T @ X.T @ dY
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, X : torch.Tensor,
W, W_quant, A, B, S):
dtype = X.dtype
XW = matmul_lora(X, W, W_quant, A, B, S)
ctx.custom_saved_tensors = (W, W_quant, S,)
ctx.save_for_backward(A, B, X)
return XW
pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dY : torch.Tensor):
W, W_quant, S = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
A, B = A.t(), B.t()
batch, seq_len, hd = X.shape
dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
X = X .reshape(-1, X .shape[-1]) # Must be reshape
dtype = X.dtype
### Weight projection LoRA weights
# Weight projection
d_A = X.t() @ (dY @ B.t())
d_B = (A.t() @ X.t()) @ dY
d_A *= S
d_B *= S
# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), \
None, None, d_A.t(), d_B.t(), None
pass
pass
def apply_lora_o(self, X):
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
pass
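# Minimal usage sketch (module paths are placeholders; attribute names follow the
# PEFT-style layers these helpers read above: q_proj / k_proj / v_proj / o_proj and
# gate_proj / up_proj / down_proj):
# >>> attn = model.model.layers[0].self_attn
# >>> mlp  = model.model.layers[0].mlp
# >>> Q, K, V = apply_lora_qkv(attn, X)          # fused QKV projection with LoRA
# >>> O       = apply_lora_o(attn, attn_output)  # output projection with LoRA
# >>> H       = apply_lora_mlp_swiglu(mlp, X)    # SwiGLU MLP with LoRA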
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from functools import lru_cache
from transformers.models.llama.modeling_llama import logger
torch_compile_options = {
"epilogue_fusion" : True,
"max_autotune" : True,
"shape_padding" : True,
"trace.enabled" : False, # Output Triton kernel outputs!
"triton.cudagraphs" : False,
}
# Flex Attention supported from torch 2.5 onwards only
import torch.nn
if hasattr(torch.nn, "attention"):
import torch.nn.attention
if hasattr(torch.nn.attention, "flex_attention"):
import torch.nn.attention.flex_attention
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask
FLEX_ATTENTION_PADDING = getattr(
torch.nn.attention.flex_attention,
"_DEFAULT_SPARSE_BLOCK_SIZE",
1,
)
flex_attention = torch.compile(flex_attention, dynamic = False)
HAS_FLEX_ATTENTION = True
else:
HAS_FLEX_ATTENTION = False
pass
else:
HAS_FLEX_ATTENTION = False
pass
# Logit softcapping
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
head_dim = self.head_dim
n_kv_heads = self.num_key_value_heads
n_groups = self.num_key_value_groups
# Grouped query attention
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
K = K.reshape(bsz, n_heads, q_len, head_dim)
V = V.reshape(bsz, n_heads, q_len, head_dim)
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
# Gemma 9b should use 256 and not 224 (= hidden_size / num_attention_heads). 27b uses the below.
# We default to using the config file itself
# s = self.config.hidden_size // self.config.num_attention_heads
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch.matmul(Q, K.transpose(2, 3))
A = t * torch.tanh(A / t) # Logit softcapping
A += causal_mask[:q_len, :q_len]
# Much slower in torch compile!
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch.matmul(A, V)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
pass
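# Small illustration of logit softcapping (values are made up): t * tanh(A / t)
# squashes logits smoothly into (-t, t) while staying roughly linear near zero.
# >>> t = 50.0
# >>> A = torch.tensor([-500.0, -5.0, 0.0, 5.0, 500.0])
# >>> t * torch.tanh(A / t)    # ~[-50.0, -4.98, 0.0, 4.98, 50.0]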
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings, triton_tanh
@triton.jit
def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
# h = f * up
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
f_row = f_row.to(g_row.dtype) # Exact copy from HF
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
pass
def geglu_exact_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
pass
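# Reference check sketch (illustrative shapes): the exact GEGLU forward should match
# PyTorch's exact (erf-based) GELU on the gate, multiplied by the up projection.
# >>> gate = torch.randn(1, 4, 8, device = "cuda", dtype = torch.float16)
# >>> up   = torch.randn(1, 4, 8, device = "cuda", dtype = torch.float16)
# >>> out  = geglu_exact_forward_kernel(gate, up)
# >>> ref  = torch.nn.functional.gelu(gate, approximate = "none") * up
# >>> torch.allclose(out, ref, atol = 1e-3)
# True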
@triton.jit
def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
h = f * up
df/de (with help of Wolfram :)
df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
Reuse via
f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
"""
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# Break e_row away for re-use
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
f_row = f_partial_row * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
# df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
t = 0.3989422804014327 # 1/sqrt(2*pi)
df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
de_row = dg_row.to(tl.float32) * df_de
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
pass
def geglu_exact_backward_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
@triton.jit
def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
# h = f * up
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
f_row = 0.5 * e_row * (
triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
+ 1.0
)
f_row = f_row.to(g_row.dtype) # Exact copy from HF
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
pass
def geglu_approx_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
pass
@triton.jit
def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
h = f * up
df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )
Notice sech^2(x) = 1 - tanh^2(x)
So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )
See https://www.desmos.com/calculator/nqprfoni6x
"""
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# See https://www.desmos.com/calculator/nqprfoni6x
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
a = s * e_row # a = sqrt(2 / pi) * x
b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
T = 1.0 + triton_tanh(a + b)
T2 = 0.5 * T
# Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
df_de = T2 + Q2 # 1/2 * (T + Q)
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
f_row = T2 * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
de_row = dg_row.to(tl.float32) * df_de
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
pass
def geglu_approx_backward_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
@triton.jit
def _rms_layernorm_forward(
Y, Y_row_stride,
X, X_row_stride,
W, W_row_stride,
r, r_row_stride,
n_cols, eps,
BLOCK_SIZE : tl.constexpr
):
"""
Fast RMS Layernorm kernel
Inspiration from a Triton tutorial:
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
"""
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
Y += row_idx * Y_row_stride
X += row_idx * X_row_stride
r += row_idx * r_row_stride
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
normed = normed.to(W_row.dtype) # Exact copy from HF
output = normed * W_row
tl.store(Y + col_offsets, output, mask = mask)
pass
@triton.heuristics({"GEMMA": lambda args: args["GEMMA"],})
@triton.jit
def _rms_layernorm_backward(
dY, dY_row_stride,
X, X_row_stride,
W, W_row_stride,
r, r_row_stride,
dW, dW_row_stride,
n_cols, eps,
GEMMA : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
):
"""
Fast RMS Layernorm kernel for the backward pass
Inspiration from a Triton tutorial:
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
"""
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dY += row_idx * dY_row_stride
X += row_idx * X_row_stride
r += row_idx * r_row_stride
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
# Get saved row variance
inv_var = tl.load(r).to(tl.float32)
normed = X_row * inv_var
if GEMMA: dY_W = dY_row * (W_row + 1.0)
else: dY_W = dY_row * W_row
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
tl.store(dY + col_offsets, output, mask = mask)
pass
@triton.jit
def _gemma_rms_layernorm_forward(
Y, Y_row_stride,
X, X_row_stride,
W, W_row_stride,
r, r_row_stride,
n_cols, eps,
BLOCK_SIZE : tl.constexpr,
):
# Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
# and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
# exactly. Essentially all in float32!
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
Y += row_idx * Y_row_stride
X += row_idx * X_row_stride
r += row_idx * r_row_stride
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
output = normed * (W_row + 1.0)
tl.store(Y + col_offsets, output, mask = mask)
pass
class Fast_RMS_Layernorm(torch.autograd.Function):
@staticmethod
def forward(ctx, X, W, eps, gemma = False):
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
fx[(n_rows,)](
Y, Y.stride(0),
X, X.stride(0),
W, W.stride(0),
r, r.stride(0),
n_cols, eps,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.GEMMA = gemma
ctx.save_for_backward(X, W, r)
return Y.view(*shape)
pass
@staticmethod
def backward(ctx, dY):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
X, W, r = ctx.saved_tensors
n_rows, n_cols = dY.shape
dW = X
_rms_layernorm_backward[(n_rows,)](
dY, dY.stride(0),
X, X .stride(0),
W, W .stride(0),
r, r .stride(0),
dW, dW.stride(0),
n_cols, ctx.eps,
GEMMA = ctx.GEMMA,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
dX = dY.view(*shape)
return dX, None, None, None
pass
pass
def fast_rms_layernorm(layernorm, X, gemma = False):
W = layernorm.weight
eps = layernorm.variance_epsilon
out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
return out
pass
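# Minimal usage sketch (the module is a placeholder; it only needs .weight and
# .variance_epsilon, exactly what fast_rms_layernorm reads above):
# >>> norm   = model.model.layers[0].input_layernorm
# >>> hidden = torch.randn(2, 16, norm.weight.shape[0], device = "cuda", dtype = torch.float16)
# >>> out    = fast_rms_layernorm(norm, hidden)               # same shape as hidden
# >>> out    = fast_rms_layernorm(norm, hidden, gemma = True) # Gemma variant: (W + 1) scaling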
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
ROPE_GROUP_SIZE = 4
@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
@triton.jit
def _rope_embedding(
Q, Q_row_stride,
cos, cos_row_stride,
sin, sin_row_stride,
seqlen,
head_dim : tl.constexpr,
n_heads : tl.constexpr,
BACKWARD_PASS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
):
"""
Calculates the RoPE Embedding quickly
RoPE is Q * cos + rotate_half(Q) * sin
See our blog post for more info
"""
ROPE_GROUP_SIZE = 4
row_position = tl.program_id(0)
group_head_position = tl.program_id(1)
col_offsets = tl.arange(0, BLOCK_SIZE)
half_head_dim = head_dim // 2
mask = col_offsets < half_head_dim
sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
half_head_dim*0 + col_offsets, mask = mask, other = 0)
cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
half_head_dim*0 + col_offsets, mask = mask, other = 0)
if BACKWARD_PASS:
# See our blog post for more info.
sin1 = -sin1
pass
# [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
head_start = group_head_position * ROPE_GROUP_SIZE
head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)
# 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
for k in range(head_start, head_end):
offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
pass
pass
class Fast_RoPE_Embedding(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, cos, sin):
cos, sin = cos.squeeze(), sin.squeeze()
batch, seq_len, n_heads, head_dim = Q.shape
Q = Q.view(batch*seq_len, n_heads*head_dim)
n_rows, n_cols = Q.shape
assert(seq_len <= cos.shape[0])
# [TODO] Changing blocksize to head_dim//2 seems to have
# some concurrency / un-deterministic issues.
BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
# group_size = 4 # 4 or 8, too large group_size can hurt performance.
div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
n_groups = div + (mod != 0)
_rope_embedding[(n_rows, n_groups, )](
Q, Q.stride(0),
cos, cos.stride(0),
sin, sin.stride(0),
seq_len,
head_dim, n_heads,
BACKWARD_PASS = False,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.n_groups = n_groups
ctx.cos = cos
ctx.sin = sin
return Q.view(batch, seq_len, n_heads, head_dim)
pass
@staticmethod
def backward(ctx, dY):
batch, seq_len, n_heads, head_dim = dY.shape
dY = dY.reshape(batch*seq_len, n_heads*head_dim)
# Must be reshape not view
n_rows, n_cols = dY.shape
cos = ctx.cos
sin = ctx.sin
_rope_embedding[(n_rows, ctx.n_groups, )](
dY, dY .stride(0),
cos, cos.stride(0),
sin, sin.stride(0),
seq_len, head_dim, n_heads,
BACKWARD_PASS = True,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
dY = dY.view(batch, seq_len, n_heads, head_dim)
return dY, None, None,
pass
pass
def fast_rope_embedding(Q, K, cos, sin):
Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
return Q, K
pass
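# Shape sketch (illustrative sizes): Q and K arrive as (batch, n_heads, seq_len, head_dim);
# cos / sin are rotary caches covering at least seq_len positions. The rotation happens
# in place on the transposed views and the original layout is returned.
# >>> Q = torch.randn(1, 32, 128, 64, device = "cuda", dtype = torch.float16)
# >>> K = torch.randn(1, 32, 128, 64, device = "cuda", dtype = torch.float16)
# >>> cos, sin = rotary_emb(K, seq_len = 128)     # placeholder rotary embedding module
# >>> Q, K = fast_rope_embedding(Q, K, cos, sin)  # shapes unchanged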
class Slow_RoPE_Embedding(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, cos, sin, position_ids):
if position_ids is not None:
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
# Q * cos + rotate_half(Q) * sin
half = Q.shape[-1]//2
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
Q *= cos
Q.addcmul_(RH_Q, sin)
# RH_Q *= sin
# Q += RH_Q
ctx.save_for_backward(cos, sin)
return Q
pass
@staticmethod
def backward(ctx, dY):
cos, sin = ctx.saved_tensors
# Q * cos + rotate_half.T(Q) * sin
half = dY.shape[-1]//2
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
dY *= cos
dY.addcmul_(RH_dY, sin)
# RH_dY *= sin
# dY += RH_dY
return dY, None, None, None
pass
pass
def inplace_rope_embedding(Q, K, cos, sin, position_ids):
Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
return Q, K
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
@triton.jit
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# f = e * sigmoid(e)
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
f_row = f_row.to(g_row.dtype) # Exact copy from HF
# h = f * g
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
pass
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
return h
pass
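# Reference check sketch (illustrative shapes): the fused forward equals SiLU(gate) * up.
# >>> e = torch.randn(1, 4, 8, device = "cuda", dtype = torch.float16)  # gate projection output
# >>> g = torch.randn(1, 4, 8, device = "cuda", dtype = torch.float16)  # up projection output
# >>> h = swiglu_fg_kernel(e, g)
# >>> torch.allclose(h, torch.nn.functional.silu(e) * g, atol = 1e-3)
# True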
@triton.jit
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
e = e.float()
se = 1.0 / (1.0 + torch.exp(-e))
f = (se * e).to(dtype)
h = f * g
df = DW * f
dg = DW * g
de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
"""
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
# f = (se * e).to(dtype)
f_row = se_row * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
# de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
pass
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass
# tl.math.tanh now is libdevice.tanh
from packaging.version import Version
import triton
if Version(triton.__version__) >= Version("3.0.0"):
from triton.language.extra import libdevice
triton_tanh = libdevice.tanh
else:
import triton.language as tl
triton_tanh = tl.math.tanh
pass
def calculate_settings(n):
BLOCK_SIZE = next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
num_warps = 4
if BLOCK_SIZE >= 32768: num_warps = 32
elif BLOCK_SIZE >= 8192: num_warps = 16
elif BLOCK_SIZE >= 2048: num_warps = 8
return BLOCK_SIZE, num_warps
pass
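# Example values, derived directly from the rules above:
#   calculate_settings(1000)  -> (1024, 4)     # round up to the next power of two
#   calculate_settings(4096)  -> (4096, 8)
#   calculate_settings(32000) -> (32768, 32)   # Llama-sized vocab
# Anything requiring a block larger than 65536 raises, since one Triton block cannot cover it.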
import bitsandbytes as bnb
get_ptr = bnb.functional.get_ptr
import ctypes
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
def QUANT_STATE(W):
return getattr(W, "quant_state", None)
pass
def get_lora_parameters(proj):
# For DPO or disabled adapters
base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
return W, QUANT_STATE(W), None, None, None
pass
active_adapter = proj.active_adapters[0] if \
hasattr(proj, "active_adapters") else proj.active_adapter
A = proj.lora_A [active_adapter].weight
B = proj.lora_B [active_adapter].weight
s = proj.scaling[active_adapter]
return W, QUANT_STATE(W), A, B, s
pass
def get_lora_parameters_bias(proj):
# For DPO or disabled adapters
base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
bias = base_layer.bias
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
return W, QUANT_STATE(W), None, None, None, bias
pass
active_adapter = proj.active_adapters[0] if \
hasattr(proj, "active_adapters") else proj.active_adapter
A = proj.lora_A [active_adapter].weight
B = proj.lora_B [active_adapter].weight
s = proj.scaling[active_adapter]
return W, QUANT_STATE(W), A, B, s, bias
pass
def fast_dequantize(W, quant_state = None, out = None):
if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
# Old quant_state as a list of lists
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
# Create weight matrix
if out is None:
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
else:
assert(out.shape == shape)
assert(out.dtype == dtype)
# NF4 dequantization of statistics
n_elements_absmax = absmax.numel()
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax)
)
out_absmax += offset
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
cdequantize_blockwise_bf16_nf4
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
ctypes.c_int(blocksize), ctypes.c_int(out.numel()))
# Careful returning transposed data
is_transposed = (W.shape[0] == 1)
return out.t() if is_transposed else out
pass
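# Minimal usage sketch (module path is a placeholder; assumes a bitsandbytes
# Linear4bit layer, whose packed weight carries a .quant_state):
# >>> layer = model.model.layers[0].mlp.gate_proj
# >>> W = layer.weight                                # packed NF4 weight
# >>> W_dequant = fast_dequantize(W, QUANT_STATE(W))  # full fp16 / bf16 matrix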
def fast_gemv(X, W, quant_state, out = None):
if quant_state is None: return torch.matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
# assert(q_len == 1)
if type(quant_state) is not list:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
stats = quant_state.code
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
# assert(dtype == X.dtype)
bout = shape[0]
if out is None:
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
# else:
# assert(out.shape == (1, 1, bout,))
# pass
n = 1
m = shape[0]
k = shape[1]
lda = shape[0]
ldc = shape[0]
ldb = (hd+1)//2
m = ctypes.c_int32(m)
n = ctypes.c_int32(n)
k = ctypes.c_int32(k)
lda = ctypes.c_int32(lda)
ldb = ctypes.c_int32(ldb)
ldc = ctypes.c_int32(ldc)
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
)
df += offset
absmax = df
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
cgemm_4bit_inference_naive_bf16
blocksize = ctypes.c_int32(blocksize)
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
lda, ldb, ldc, blocksize)
return out
pass
def fast_linear_forward(proj, X, temp_lora = None, out = None):
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
bsz, q_len, in_dim = X.shape
if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
if W_quant is None:
out = torch.matmul(X, W.t(), out = out)
elif bsz == 1 and q_len == 1:
out = fast_gemv(X, W, W_quant, out = out)
else:
W = fast_dequantize(W.t(), W_quant)
out = torch.matmul(X, W, out = out)
pass
# Add in LoRA weights
if lora_A is not None:
out_dim = out.shape[2]
dtype = X.dtype
if not hasattr(lora_A, "_fast_lora"):
lora_A._fast_lora = lora_A.to(dtype)
lora_B._fast_lora = lora_B.to(dtype)
pass
if bsz == 1:
out = out.view(out_dim)
temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
else:
out = out.view(bsz, out_dim)
temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
pass
out = out.view(bsz, 1, out_dim)
pass
if bias is not None: out += bias
return out
pass
def matmul_lora(X, W, W_quant, A, B, s, out = None):
dtype = X.dtype
W = fast_dequantize(W.t(), W_quant)
if X.dim() == 3:
batch, seq_len, d = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
pass
out = torch.matmul(X, W, out = out)
if W_quant is not None: del W
if A is not None:
# LoRA is enabled
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
pass
return out.view(batch, seq_len, -1) if reshape else out
pass
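# For reference, matmul_lora computes the standard LoRA-augmented product
#   out = X @ W.T + s * (X @ A.T) @ B.T
# dequantising W on the fly when it is stored as a bitsandbytes 4-bit tensor.
# A tiny dense sketch (shapes are illustrative assumptions; it assumes fast_dequantize is a
# no-op when W_quant is None):
#
#   X = torch.randn(2, 5, 16); W = torch.randn(32, 16)
#   A = torch.randn(8, 16);    B = torch.randn(32, 8); s = 0.5
#   out = matmul_lora(X, W, None, A, B, s)   # equals X @ W.T + s * (X @ A.T) @ B.T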
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .loader import FastLanguageModel
from .llama import FastLlamaModel
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from .dpo import PatchDPOTrainer
from ._utils import is_bfloat16_supported
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "2024.8"
__all__ = [
"prepare_model_for_kbit_training",
"xformers",
"xformers_attention",
"xformers_version",
"__version__",
"HAS_FLASH_ATTENTION",
"PRE_CHECK",
"platform_system",
"patch_tokenizer",
"get_statistics",
"Unsloth_Offloaded_Gradient_Checkpointer",
"offload_to_disk",
"offload_input_embeddings",
"offload_output_embeddings",
"is_bfloat16_supported",
"unsloth_offloaded_gradient_checkpoint",
"torch_compile_options",
"patch_linear_scaling",
"patch_llama_rope_scaling",
"check_nvidia",
"create_boolean_mask",
"torch_amp_custom_fwd",
"torch_amp_custom_bwd",
]
import torch
from typing import Union, Optional, List, Any, Callable, Tuple
from platform import system as platform_system
platform_system = platform_system()
import numpy as np
import warnings, subprocess, re, inspect, psutil, os, math
from packaging.version import Version
# =============================================
# Disable some warnings which can get annoying
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
# Stop "Special tokens have been added in the vocabulary, ..."
import logging
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
# =============================================
# =============================================
# Edits all Config files to enable RoPE Scaling for all models
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_config(config):
if "head_dim (" not in config:
add_head_dim = "If it is not specified, will default to `8`.\n"\
" head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"\
" The attention head dimension."
config = config.replace("If it is not specified, will default to `8`.", add_head_dim)
add_head_dim = "num_key_value_heads=8,\n head_dim=None,"
config = config.replace("num_key_value_heads=8,", add_head_dim)
add_head_dim = "self.sliding_window = sliding_window\n self.head_dim = head_dim or hidden_size // num_attention_heads\n"
config = config.replace("self.sliding_window = sliding_window", add_head_dim)
pass
return config
pass
from transformers import __version__ as transformers_version
from transformers import PretrainedConfig
model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2",]
for model_name in model_architectures:
config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
config_filename = f"{model_name.title()}Config"
exec(f"from {config_filepath} import {config_filename}", globals())
try:
config = inspect.getsource(eval(config_filename))
except:
continue
if "rope_scaling" in config: continue
config = re.sub(
r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
r"rope_scaling=None,"\
r"\n **kwargs):\n"\
r"\n self.rope_scaling = rope_scaling\n",
config,
)
# Just for Mistral Nemo
if model_name == "mistral":
if Version(transformers_version) <= Version("4.42.4"):
config = patch_mistral_nemo_config(config)
pass
exec(config, globals())
exec(f"import {config_filepath}", globals())
exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
pass
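# For reference, the regex above rewrites each config __init__ signature roughly as follows
# (the signature shown is a schematic example, not an exact transformers source line):
#   def __init__(self, vocab_size=32000, **kwargs,):
# becomes
#   def __init__(self, vocab_size=32000, rope_scaling=None,
#                **kwargs):
#       self.rope_scaling = rope_scaling
# so every architecture listed in model_architectures accepts `rope_scaling`, even on older
# transformers releases that did not define it.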
# =============================================
# =============================================
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass
# =============================================
# =============================================
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
import bitsandbytes as bnb
from transformers.models.llama.modeling_llama import logger
from transformers import AutoTokenizer
major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = False
if major_version >= 8:
SUPPORTS_BFLOAT16 = True
try:
from flash_attn import flash_attn_func
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
try:
from flash_attn.flash_attn_interface import flash_attn_cuda
HAS_FLASH_ATTENTION = True
except:
logger.warning_once(
"Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
"A possible explanation is you have a new CUDA version which isn't\n"\
"yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
"We shall now use Xformers instead, which gets a 0.01% performance hit.\n"\
"We found this negligible impact by benchmarking on 1x A100."
)
HAS_FLASH_ATTENTION = False
except:
HAS_FLASH_ATTENTION = False
else:
# Tri Dao's benchmark shows xformers is faster for now.
HAS_FLASH_ATTENTION = False
pass
import xformers.ops.fmha as xformers
xformers_attention = xformers.memory_efficient_attention
from xformers import __version__ as xformers_version
# Temporarily disable 0.0.27 and higher - inference issues
if Version(xformers_version) >= Version("0.0.27"):
raise ImportError(
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
"then press Disconnect Runtime and then Restart it.\n"\
"\n"\
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes\n'\
'\n'\
f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"\
'Please downgrade xformers via `pip install --force-reinstall "xformers<0.0.27"`'
)
pass
# Check TRL version
from trl import __version__ as trl_version
if Version(trl_version) >= Version("0.9.0"):
raise ImportError(
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
"then press Disconnect Runtime and then Restart it.\n"\
"\n"\
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes\n'\
'\n'\
f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"\
'Please downgrade TRL via `pip install --force-reinstall "trl<0.9.0"`'
)
pass
# =============================================
# =============================================
# Torch compile settings
# Just remove max_autotune_gemm warning
import functools
@functools.lru_cache(None)
def is_big_gpu(index):
sms = torch.cuda.get_device_properties(index).multi_processor_count
if sms < 80: # V100
# log.warning("not enough SMs to use max_autotune_gemm mode")
return False
return True
import torch._inductor.utils
torch._inductor.utils.is_big_gpu = is_big_gpu
# Torch compile arguments
torch_compile_arguments = [
"config.dce = True",
"config.memory_planning = True",
"config.memory_pool = 'combined'",
"config.coordinate_descent_tuning = True",
"config.max_autotune_gemm = False", # GEMM is unnecessary
"config.autotune_multi_device = False",
"config.max_autotune_gemm_backends = 'ATEN'", # Not much faster
"config.aggressive_fusion = False", # Careful changes results!
"config.cuda.enable_cuda_lto = True",
"config.cuda.use_fast_math = True",
"config.cuda.compile_opt_level = '-O2'",
]
# Torch dynamo arguments
torch_dynamo_arguments = [
"config.accumulated_cache_size_limit = 512", # Bump up a bit from 256
"config.suppress_errors = True", # Supress errors for now
"config.do_not_emit_runtime_asserts = True",
]
import torch._inductor.config as config
for _try_compile_argument in torch_compile_arguments:
try: exec(_try_compile_argument)
except: pass
pass
import torch._dynamo.config as config
for _try_dynamo_argument in torch_dynamo_arguments:
try: exec(_try_dynamo_argument)
except: pass
pass
torch_compile_options = {
"epilogue_fusion" : True,
"max_autotune" : True,
"shape_padding" : True,
"trace.enabled" : False, # Output Triton kernel outputs!
"triton.cudagraphs" : False,
}
# =============================================
def prepare_model_for_kbit_training(
model : Any,
use_gradient_checkpointing : Optional = True,
use_reentrant : Optional[bool] = True,
) -> Any:
"""
Prepares a k-bit (QLoRA) model for training by enabling gradient checkpointing.
We also freeze all non-LoRA parameters and upcast the LoRA weights to float32.
Args:
model: Any LlamaModel with layers.
use_gradient_checkpointing (`bool`, *optional*):
Default enabled. Provides memory savings by not saving all activations,
but only some.
use_reentrant (`bool`, *optional*):
https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
Optimal gradient checkpointing algorithm which will be the default in
future Pytorch versions.
"""
# Freeze all parameters except LoRA
import re
with torch.no_grad():
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
param.requires_grad_(True)
# Also must be in float32!
if param.dtype != torch.float32:
name = name.replace("base_model", "model", 1)
layer_number = re.search(r"\.[\d]{1,}\.", name).group(0)
name = name.replace(layer_number, f"[{layer_number[1:-1]}].")
name = name.replace(".weight", "", 1)
exec(f"{name}.to(torch.float32)")
pass
else:
param.requires_grad_(False)
pass
pass
# Gradient checkpointing!
if use_gradient_checkpointing == "unsloth":
# Saves VRAM!
original_model = model
while hasattr(original_model, "model"):
original_model._offloaded_gradient_checkpointing = True
original_model = original_model.model
pass
original_model._offloaded_gradient_checkpointing = True
model.gradient_checkpointing_enable()
elif use_gradient_checkpointing == True:
model.gradient_checkpointing_enable()
pass
# If use_reentrant = True which is the Pytorch default, we just make the input requires_grad.
if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model
pass
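# Illustrative sketch (assumes `model` is a PEFT / LoRA-wrapped model; not a definitive recipe):
#
#   model = prepare_model_for_kbit_training(
#       model,
#       use_gradient_checkpointing = "unsloth",   # or True for vanilla checkpointing
#       use_reentrant = True,
#   )
#   # Base weights are now frozen, LoRA weights are float32, and checkpointing is enabled.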
def patch_tokenizer(model, tokenizer):
"""
Phi3's pad_token isn't set. We set it to <|placeholder...
Llama-3 is <|reserved...
Llama-2 is <unk>
Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
Fixes https://github.com/unslothai/unsloth/issues/5
"""
possible_reserved_tokens = (
"<|reserved", # Llama-3
"<|placeholder", # Phi-3
"[control", # Mistral type models
"<pad>", # Mistral Nemo
"<|finetune_right_pad_id|>", # Llama-3.1
)
if model is not None:
model.config.update({"unsloth_version" : __version__})
bad_pad_token = False
if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None:
# Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
bad_pad_token = tokenizer.eos_token == tokenizer.pad_token
elif hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
bad_pad_token = True
else:
bad_pad_token = False
pass
if bad_pad_token:
# Find a better pad token
added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
possible_pad_token = None
n_possible_pad_tokens = 0
for added_token in added_tokens[::-1]:
if added_token.startswith(possible_reserved_tokens):
if possible_pad_token is None: possible_pad_token = added_token
n_possible_pad_tokens += 1
# We must see at least 3 of the reserved tokens
if n_possible_pad_tokens >= 3: break
pass
pass
if n_possible_pad_tokens < 3: possible_pad_token = None
if possible_pad_token is None:
# Try unk_token
possible_pad_token = tokenizer.unk_token
pass
if possible_pad_token is None:
# Failure to find a good replacement!! We shall manually add one!
new_pad_token = "<|PAD_TOKEN|>"
while new_pad_token in tokenizer.get_vocab():
new_pad_token += "#"
pass
possible_pad_token = new_pad_token
pass
name = model.config._name_or_path if model is not None else "Model"
logger.warning_once(
f"{name} does not have a padding token! Will use pad_token = {possible_pad_token}."
)
# Edit pad_token
tokenizer.add_special_tokens({"pad_token" : possible_pad_token})
tokenizer.pad_token = possible_pad_token
if model is not None:
config = model.config.update({"pad_token_id" : tokenizer.pad_token_id})
else:
if model is not None:
if model.config.pad_token_id is None:
config = model.config.update({"pad_token_id" : tokenizer.pad_token_id})
return model, tokenizer
pass
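# Illustrative sketch (the model / tokenizer objects are assumptions): patch_tokenizer picks a
# pad token that is distinct from the EOS token (a reserved token, the unk token, or a newly
# added <|PAD_TOKEN|>) and records it on both the tokenizer and the model config.
#
#   model, tokenizer = patch_tokenizer(model, tokenizer)
#   assert tokenizer.pad_token != tokenizer.eos_token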
# =============================================
# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
# For mixed precision, we need it to be in float32 not float16.
from peft import __version__ as peft_version
if Version(peft_version) < Version("0.12.0"):
from peft.tuners.lora.layer import LoraLayer
import inspect, re
try:
source = inspect.getsource(LoraLayer.update_layer)
text = "if weight is not None:\n"
start = source.find(text) + len(text)
end = source.find("self.to(weight.device)", start)
spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
source = source.replace(source[start : end], spaces)
spaces = len(re.match(r"[\s]{1,}", source).group(0))
lines = source.split("\n")
source = "\n".join(x[spaces:] for x in lines)
source = re.sub("([^\.])nn\.", r"\1torch.nn.", source)
source = source.replace("def update_layer", "def LoraLayer_update_layer")
exec(source, globals())
# Fix up incorrect downcasting of LoRA weights
from peft.tuners.lora.layer import LoraLayer
LoraLayer.update_layer = LoraLayer_update_layer
from peft.tuners.lora import LoraLayer
LoraLayer.update_layer = LoraLayer_update_layer
except:
logger.warning_once(
"Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"\
"Luckily, your training run will still work in the meantime!"
)
pass
pass
# =============================================
import psutil
def _get_statistics(statistics = None, force_download = True):
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
try:
n_cpus = psutil.cpu_count(logical = False)
keynames = "\n" + "\n".join(os.environ.keys())
if statistics is not None: pass
elif "\nCOLAB_" in keynames and n_cpus == 1: statistics = "colab"
elif "\nCOLAB_" in keynames: statistics = "colabpro"
elif "\nKAGGLE_" in keynames: statistics = "kaggle"
elif "\nRUNPOD_" in keynames: statistics = "runpod"
elif "\nAWS_" in keynames: statistics = "aws"
elif "\nAZURE_" in keynames: statistics = "azure"
elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
elif "\nINVOCATION_ID" in keynames: statistics = "lambda"
else: statistics = "other"
if statistics is not None:
from transformers import AutoModelForCausalLM
stats_model = AutoModelForCausalLM.from_pretrained(
f"unslothai/{statistics}",
force_download = force_download,
)
del stats_model
pass
except:
pass
pass
def get_statistics():
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
disabled = False
if not are_progress_bars_disabled():
disable_progress_bars()
disabled = True
pass
_get_statistics(None)
_get_statistics("repeat", force_download = False)
try:
vram = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024
if vram <= 8 : vram = 8
elif vram <= 16: vram = 16
elif vram <= 20: vram = 20
elif vram <= 24: vram = 24
elif vram <= 40: vram = 40
elif vram <= 48: vram = 48
elif vram <= 80: vram = 80
else: vram = 96
_get_statistics(f"vram-{vram}")
except:
pass
pass
try:
devices = torch.cuda.device_count()
_get_statistics(f"{devices if devices <= 8 else 9}")
except:
pass
if disabled: enable_progress_bars()
pass
def _calculate_n_gradient_checkpoints(
n_layers : int,
method : Optional[Union[str, int]] = "sqrt",
) -> List[int]:
assert(type(n_layers) is int and n_layers > 0)
if method is None: method = "sqrt"
if method == "sqrt":
n_checkpoints = int(n_layers**0.5)
elif type(method) is int and method > 0:
n_checkpoints = int(np.ceil(n_layers / method))
else:
raise ValueError("method must be 'sqrt' or an int >0 and <= n_layers.")
size = n_layers // n_checkpoints
sizes = np.full(n_checkpoints, size, dtype = int)
leftovers = n_layers % n_checkpoints
# We append leftovers from the right
for k in range(leftovers):
sizes[n_checkpoints-1-k] += 1
boundaries = np.hstack((0, np.cumsum(sizes)))
boundaries = boundaries.tolist()
return boundaries
pass
def calculate_n_gradient_checkpoints(
n_layers : int,
layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
) -> List[int]:
assert(type(n_layers) is int and n_layers > 0)
if layers_per_checkpoint is None or layers_per_checkpoint == 1:
return None
boundaries = _calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
assert(boundaries[0] == 0 and boundaries[-1] == n_layers)
assert(min(boundaries) == 0 and max(boundaries) == n_layers)
assert(np.diff(boundaries).min() >= 0)
return boundaries
pass
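# Worked example: with n_layers = 32 and the default "sqrt" method, n_checkpoints = int(32**0.5) = 5,
# the base group size is 32 // 5 = 6, and the 32 % 5 = 2 leftover layers are appended from the right,
# giving sizes [6, 6, 6, 7, 7] and boundaries [0, 6, 12, 18, 25, 32].
#
#   calculate_n_gradient_checkpoints(32)      # [0, 6, 12, 18, 25, 32]
#   calculate_n_gradient_checkpoints(32, 8)   # [0, 8, 16, 24, 32] - 8 layers per checkpoint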
def prepare_n_gradient_checkpoints(
model : Any,
layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
use_reentrant : Optional[bool] = True,
) -> None:
"""
Calculates where to place the gradient checkpoints given n_layers.
Args:
model: Any LlamaModel with layers.
layers_per_checkpoint (`Union[str, int]`, *optional*):
Can either be `sqrt` or an integer for how many layers per checkpoint you want.
The more, the less memory usage, but can be slower. Default is `sqrt`.
Choose 1 for Pytorch gradient checkpointing. 2 to wrap 2 layers in 1 module etc.
use_reentrant (`bool`, *optional*):
https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
The optimal gradient checkpointing algorithm (`use_reentrant=False`), which will
be the default in future Pytorch versions, doesn't seem to work yet, so we force `use_reentrant=True`.
"""
_model = None
if hasattr(model, "layers"):
_model = model
elif hasattr(model, "model"):
if hasattr(model.model, "layers"):
_model = model.model
if _model is None:
raise TypeError("`model` or `model.model` does not have attribute `layers`. Are you sure this is a model?")
pass
if use_reentrant is False:
use_reentrant = True
pass
n_layers = len(_model.layers)
boundaries = calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
_model._gradient_checkpointing_boundaries = boundaries
_model._gradient_checkpointing_use_reentrant = use_reentrant
pass
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
"""
Saves VRAM by smartly offloading to RAM.
Tiny hit to performance, since we mask the movement via non blocking calls.
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda:0", non_blocking = True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (None, hidden_states.grad,) + (None,)*len(ctx.args)
pass
pass
@torch._disable_dynamo
def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs):
return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args)
pass
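# Illustrative sketch (the decoder layer and extra arguments are assumptions): the offloaded
# checkpointer is used like torch.utils.checkpoint.checkpoint - activations are parked on CPU
# with a non-blocking copy during forward and pulled back for recomputation in backward.
#
#   hidden_states = unsloth_offloaded_gradient_checkpoint(
#       decoder_layer.__call__, hidden_states, attention_mask, position_ids,
#   )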
# =============================================
# Fixes Bitsandbytes config to remove spurious "unused kwargs" warnings
from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
from inspect import getsource
from accelerate.utils.dataclasses import DistributedType
import re
BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = re.sub(
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
"",
BitsAndBytesConfig__init__,
flags = re.MULTILINE,
)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
"__init__",
"_BitsAndBytesConfig__init__",
)
def _prepare_backend(
self, cpu: bool = False, sagemaker_dp = False, backend: str = None,
) -> tuple[str, DistributedType]:
return None, DistributedType.NO
pass
import accelerate.state
accelerate.state.PartialState._prepare_backend = _prepare_backend
import accelerate.accelerator
prepare = inspect.getsource(accelerate.accelerator.Accelerator.prepare)
prepare = prepare.split("\n")
spaces = prepare[0].find("def")
prepare = "\n".join(x[spaces:] for x in prepare)
x = "for obj in args:"
s = " "*spaces
prepare = prepare.replace(x, f'self.state.distributed_type = DistributedType.NO\n{s}{x}', 1)
exec(prepare, globals())
accelerate.accelerator.Accelerator.prepare = prepare
exec(BitsAndBytesConfig__init__, globals())
import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
# =============================================
# Offloading to disk for modules (lm_head, embed_tokens)
import pickle
def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_temporary_saved_buffers"):
file_location = os.path.join(temporary_location, model.config._name_or_path)
if not os.path.exists(file_location):
os.makedirs(file_location)
pass
filename = os.path.join(file_location, f"{name}.pt")
W = W.weight if hasattr(W, "weight") else W
torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
offloaded_W = torch.load(filename, map_location = "cpu", mmap = True)
offloaded_W._offloaded_file_location = filename
return offloaded_W
pass
def offload_input_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
offloaded_W = offload_to_disk(model.get_input_embeddings(), model, "input_embeddings", temporary_location)
new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W)
new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
model.set_input_embeddings(new_input_embeddings)
return
pass
def offload_output_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
offloaded_W = offload_to_disk(model.get_output_embeddings(), model, "output_embeddings", temporary_location)
new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
del new_output_embeddings.weight
new_output_embeddings.weight = offloaded_W
new_output_embeddings.in_features = offloaded_W.shape[1]
new_output_embeddings.out_features = offloaded_W.shape[0]
new_output_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
model.set_output_embeddings(new_output_embeddings)
return
pass
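# Illustrative sketch: both helpers serialise the embedding matrices to disk and swap in
# mmap-backed CPU tensors, so the GPU copies can be released. The directory name below is the
# default defined above.
#
#   offload_input_embeddings (model, "_unsloth_temporary_saved_buffers")
#   offload_output_embeddings(model, "_unsloth_temporary_saved_buffers")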
# Fixes a weird Torch 2.3 bug which says T4s have bfloat16
def is_bfloat16_supported():
return SUPPORTS_BFLOAT16
pass
# Patches models to add RoPE Scaling
def patch_linear_scaling(
model_name = "gemma2",
rope_module = None,
scaled_rope_module = None,
attention_module = None,
):
assert(rope_module is not None and scaled_rope_module is not None)
assert(attention_module is not None)
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
exec_code = \
f"import torch.nn as nn\n"\
f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
f"from {model_filepath} import logger, "\
f"{model_name.title()}Attention, {model_name.title()}Config"
try:
function = inspect.getsource(attention_module.__init__)
except:
# Most likely already patched!
return None, None
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
init_name = f"{model_name.title()}Attention__init__"
function = function.replace("def __init__", f"def {init_name}")
function = function.replace(
"super().__init__()",
f"super({model_name.title()}Attention, self).__init__()",
)
fix_rope_function = """
if getattr(self.config, "rope_scaling", None) is None:
self.rotary_emb = {rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = {scaled_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
)
rotary_emb = re.findall(
"self.rotary_emb = .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0: return None, function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
pass
# Patches for Llama-3 LlamaExtendedRotaryEmbedding
def patch_llama_rope_scaling(
model_name = "llama",
rope_module = None,
scaled_rope_module = None,
extended_rope_module = None,
attention_module = None,
):
assert(\
rope_module is not None and \
scaled_rope_module is not None and \
extended_rope_module is not None
)
assert(attention_module is not None)
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
exec_code = \
f"import torch.nn as nn\n"\
f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
f"from {model_filepath} import logger, "\
f"{model_name.title()}Attention, {model_name.title()}Config"
try:
function = inspect.getsource(attention_module.__init__)
except:
# Most likely already patched!
return None, None
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
init_name = f"{model_name.title()}Attention__init__"
function = function.replace("def __init__", f"def {init_name}")
function = function.replace(
"super().__init__()",
f"super({model_name.title()}Attention, self).__init__()",
)
fix_rope_function = """
if getattr(self.config, "rope_scaling", None) is None:
self.rotary_emb = {rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type1 = self.config.rope_scaling.get("type", None)
scaling_type2 = self.config.rope_scaling.get("rope_type", None)
scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
scaling_factor = self.config.rope_scaling.get("factor")
if scaling_type == "linear":
self.rotary_emb = {scaled_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "llama3":
self.rotary_emb = {extended_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
extended_rope_function = extended_rope_module.__name__,
)
rotary_emb = re.findall(
"self.rotary_emb = .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0: return None, function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
pass
def check_nvidia():
# Unsloth doesn't work yet on AMD devices - we're working on it!
output = np.array([0,])
try:
output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True)
output = re.findall(rb'([\d]{1,})[\s]{1,}M', output)
output = np.array([int(x.decode('utf-8'))/1024 for x in output])
except:
if not torch.cuda.is_available():
raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!")
return output
pass
PRE_CHECK = check_nvidia()
def create_boolean_mask(n = 4096, sliding_window = 2048):
# Creates a boolean mask for attention
mask = torch.ones(n, n, dtype = torch.bool)
if sliding_window == 0:
return torch.triu(mask, diagonal = 1, out = mask)
pass
torch.triu(mask, diagonal = 0, out = mask)
torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
mask = mask.T
torch.logical_not(mask, out = mask)
return mask
pass
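# Worked example: True marks positions that are masked out. With sliding_window = 0 the result
# is a plain causal mask, e.g. create_boolean_mask(n = 3, sliding_window = 0) gives
#   [[False,  True,  True],
#    [False, False,  True],
#    [False, False, False]]
# so token i can only attend to tokens 0..i. A non-zero sliding_window additionally masks
# tokens outside the local window, matching the converter checked in test_mask_creation below.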
def test_mask_creation():
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
for n in range(2, 23):
for s in range(1, 23):
correct_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = s,
).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
correct_mask = (correct_mask == correct_mask.min())
our_mask = create_boolean_mask(n = n, sliding_window = s)
assert(torch.all(correct_mask == our_mask))
pass
correct_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = None,
).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
correct_mask = (correct_mask == correct_mask.min())
our_mask = create_boolean_mask(n = n, sliding_window = 0)
assert(torch.all(correct_mask == our_mask))
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
from transformers.utils.notebook import (
IntervalStrategy,
NotebookTrainingTracker,
NotebookProgressCallback,
)
HAS_NOTEBOOK = True
except:
HAS_NOTEBOOK = False
pass
DPOTrainer_metrics = [
"rewards/chosen",
"rewards/rejected",
"rewards/accuracies",
"rewards/margins",
"logps/rejected",
"logps/chosen",
"logits/rejected",
"logits/chosen",
]
set_DPOTrainer_metrics = frozenset(DPOTrainer_metrics)
def NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs):
self.first_column = "Epoch" if args.evaluation_strategy == IntervalStrategy.EPOCH else "Step"
self.training_loss = 0
self.last_log = 0
column_names = [self.first_column] + ["Training Loss"]
if args.evaluation_strategy != IntervalStrategy.NO:
column_names.append("Validation Loss")
column_names += [x.replace("/", " / ") for x in DPOTrainer_metrics]
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
pass
def NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs):
# Only for when there is no evaluation
if args.evaluation_strategy == IntervalStrategy.NO and "loss" in logs:
values = {"Training Loss": logs["loss"]}
for metric in DPOTrainer_metrics:
values[metric.replace("/", " / ")] = logs[metric]
pass
# First column is necessarily Step since we're not in epoch eval strategy
values["Step"] = state.global_step
self.training_tracker.write_line(values)
pass
pass
def NotebookTrainingTracker_write_line(self, values):
"""
Write the values in the inner table.
Args:
values (`Dict[str, float]`): The values to display.
"""
if self.inner_table is None:
self.inner_table = [list(values.keys()), list(values.values())]
else:
columns = self.inner_table[0]
new_values = {}
for key, value in values.items():
lowered = key.lower()
if lowered in set_DPOTrainer_metrics:
new_values[lowered.replace("/", " / ")] = value
else:
new_values[key] = value
pass
values = new_values
self.inner_table[0] = columns
if len(self.inner_table) > 1:
last_values = self.inner_table[-1]
first_column = self.inner_table[0][0]
if last_values[0] != values[first_column]:
# write new line
self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
else:
# update last line
new_values = values
for c in columns:
if c not in new_values.keys():
new_values[c] = last_values[columns.index(c)]
self.inner_table[-1] = [new_values[c] for c in columns]
else:
# Edit for evaluation purposes
self.inner_table.append([values[c] if c in values else 0 for c in columns])
pass
pass
pass
def PatchDPOTrainer():
if HAS_NOTEBOOK:
from transformers.trainer import is_in_notebook
if is_in_notebook():
# Patch DPO notebook printing
NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line
from transformers.trainer import DEFAULT_PROGRESS_CALLBACK
DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin
DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log
pass
pass
pass
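# Illustrative sketch (the trainer arguments are assumptions): call PatchDPOTrainer() once,
# before constructing a TRL DPOTrainer, so the notebook progress table also displays the reward
# and log-probability metrics listed in DPOTrainer_metrics.
#
#   PatchDPOTrainer()
#   from trl import DPOTrainer
#   trainer = DPOTrainer(model, ref_model = None, args = training_args,
#                        train_dataset = dataset, tokenizer = tokenizer)
#   trainer.train()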
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from ._utils import __version__
try:
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaDecoderLayer,
GemmaModel,
GemmaForCausalLM,
GemmaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
except:
from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.38"):
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
f"The minimum required version is 4.38.\n"\
f'Try `pip install --upgrade "transformers>=4.38"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
pass
pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
from transformers.models.gemma.modeling_gemma import (
GemmaSdpaAttention,
GemmaFlashAttention2,
)
except:
GemmaSdpaAttention = GemmaAttention
GemmaFlashAttention2 = GemmaAttention
pass
torch_nn_functional_gelu = torch.nn.functional.gelu
def fast_geglu_inference(self, X):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
bsz, _, hd = X.shape
# mlp_size = self.config.intermediate_size
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
up = fast_linear_forward(self. up_proj, X)#, out = temp[1])
gate = torch_nn_functional_gelu(gate, approximate = "tanh")
gate *= up
# X = self.down_proj(gate)
down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
return down
pass
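# For reference, this evaluates Gemma's GeGLU MLP for a single decoded token:
#   down_proj( gelu_tanh( gate_proj(X) ) * up_proj(X) )
# with the `up` buffer reused for the down projection output to avoid an extra allocation.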
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def GemmaDecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
):
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
from math import sqrt as math_sqrt
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def GemmaModel_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
)
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
# Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
# Formulates cos and sin differently from Llama!
class GemmaFixedRotaryEmbedding(torch.nn.Module):
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None: return # [TODO] Hack to pass in config - need to remove later
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
# The difference is we do division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = int(round(seq_len / 8192)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
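# For reference, the cached tables follow the Gemma formulation
#   cos_cached[p, k] = cos(p / base**(2k / dim)),   sin_cached[p, k] = sin(p / base**(2k / dim))
# computed in float32 and cast to the model dtype only inside forward(). extend_rope_embedding
# regrows the cache (to a multiple of 8192 positions) once generation exceeds the current size.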
class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
# The difference is we do division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
positions = positions / self.scaling_factor
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
pass
pass
class FastGemmaModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "gemma",
rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
attention_module = GemmaAttention,
)
if init_name is not None:
exec(function, globals())
GemmaAttention.__init__ = eval(init_name)
pass
GemmaAttention .forward = LlamaAttention_fast_forward
GemmaSdpaAttention .forward = LlamaAttention_fast_forward
GemmaFlashAttention2.forward = LlamaAttention_fast_forward
GemmaDecoderLayer .forward = GemmaDecoderLayer_fast_forward
GemmaModel .forward = LlamaModel_fast_forward
GemmaForCausalLM .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(GemmaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma.modeling_gemma
transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
return
pass
@staticmethod
def post_patch(model):
# Patch model for Gemma
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.2
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Gemma has tied weights! This means lm_head == embed_tokens
if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
pass
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
# RoPE must be done in float32 for Gemma
# if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
# and (module.cos_cached.dtype != correct_dtype):
# module.cos_cached = module.cos_cached.to(correct_dtype)
# module.sin_cached = module.sin_cached.to(correct_dtype)
# pass
# pass
pass
# Add 1 to weight
# return output * (1 + self.weight)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
# Freeze all parameters except LoRA
# We do this first since an in-place += 1 is not allowed on parameters with requires_grad = True
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
pass
# Patch RMS Layernorm
for name, module in model.named_modules():
if isinstance(module, GemmaRMSNorm):
# Must be in float32
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
# module = module.to(torch.float32)
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
pass
# Clear deleted GPU items
import gc
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from ._utils import __version__
from .gemma import (
GemmaFixedRotaryEmbedding,
GemmaFixedLinearScalingRotaryEmbedding,
fast_geglu_inference,
)
try:
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2DecoderLayer,
Gemma2Model,
Gemma2ForCausalLM,
Gemma2RotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
except:
from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.42"):
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
f"The minimum required version is 4.42.3.\n"\
f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
pass
pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2SdpaAttention,
Gemma2FlashAttention2,
)
except:
Gemma2SdpaAttention = Gemma2Attention
Gemma2FlashAttention2 = Gemma2Attention
pass
# [TODO] We must randomly use torch.compile?
# I checked the gradients and formulas and I'm sure it's correct.
# I'm stumped :(
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
old_dtype = X.dtype
X = X.float()
X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + layernorm.eps) * \
(1.0 + layernorm.weight.float())
return X.to(old_dtype)
pass
# Logit softcapping
def Gemma2Attention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len)
A = self.apply_o(self, A)
return A, None, past_key_value
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def Gemma2DecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
):
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = self.mlp(hidden_states)
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
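# Sketch of the cache strategy used below: the paged KV cache is over-allocated and
# grown in steps of KV_CACHE_INCREMENT tokens, so decoding does not reallocate the
# buffer on every new token; `resize_` is only hit once the pre-allocated slots run out.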
torch_nn_functional_softmax = torch.nn.functional.softmax
def Gemma2Attention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
use_sliding_window = False,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Only for Gemma2
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
# Gemma 9b should use 256 and not 224 (hidden_size / num_attention_heads). 27b uses the below
# We default to using the config file itself
# s = self.config.hidden_size // self.config.num_attention_heads
self.scalar = 1.0 / math_sqrt(self.config.query_pre_attn_scalar)
# self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
self.half_head_dim = head_dim // 2
self. t = self.config.attn_logit_softcapping
self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
elif kv_seq_len >= self.paged_attention.shape[0]:
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
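# In-place RoPE: RH_Q is filled with rotate_half(Q) = [-Q[..., d/2:], Q[..., :d/2]],
# then Q <- Q * cos + rotate_half(Q) * sin (and likewise for K via RH_K, which reuses
# the front of the RH_Q buffer), avoiding new allocations on every decode step.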
h = self.half_head_dim
RH_Q = self.RH_Q
RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:,:,:,:h] = Kn[:,:,:,h:]
RH_K[:,:,:,h:] = Kn[:,:,:,:h]
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
# Handle sliding windows
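# On sliding-window layers only the tail of the KV cache (roughly the last
# `config.sliding_window` positions) is attended to once the cache outgrows the
# window; global-attention layers always see the full cache.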
sliding_window = self.config.sliding_window
if use_sliding_window and kv_seq_len > sliding_window:
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
slicing_tokens = 1 - sliding_window
Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
pass
# Grouped query attention
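# Each of the n_kv_heads K/V heads is shared by n_groups query heads, so K/V are
# expanded (a broadcast view, copied only by the reshape) to n_heads before Q @ K^T.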
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
pass
# else:
# Knn, Vnn = Knn, Vnn
# pass
# Attention
# if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
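# Logit softcapping: A = t * tanh(A / t) with t = config.attn_logit_softcapping,
# which smoothly bounds the attention logits to (-t, t). For example, assuming t = 50,
# a raw score of 120 becomes 50 * tanh(2.4) ~= 49.2, so extreme scores saturate.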
A *= self.reciprocal_t; torch.tanh(A, out = A); A *= self.t; # Logit softcapping
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch.matmul(A, Vnn, out = Qn)
# else:
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
# pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def Gemma2Model_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
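# For batched decoding (bsz != 1) two 4D SDPA masks are prebuilt: SWA with the
# sliding window for the local-attention layers, and GA (global) for the rest.
# For bsz == 1 the raw attention_mask is passed through unchanged.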
if bsz != 1:
SWA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = self.config.sliding_window,
)
GA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
)
else:
SWA = attention_mask
GA = attention_mask
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
use_sliding_window = idx % 2 == 0
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = SWA if use_sliding_window else GA,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
use_sliding_window = use_sliding_window,
)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
class FastGemma2Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "gemma2",
rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
attention_module = Gemma2Attention,
)
if init_name is not None:
exec(function, globals())
Gemma2Attention.__init__ = eval(init_name)
pass
Gemma2Attention .forward = Gemma2Attention_fast_forward
Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward
Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward
Gemma2Model .forward = LlamaModel_fast_forward
Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(Gemma2ForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA-graphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma2.modeling_gemma2
transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding
return
pass
@staticmethod
def post_patch(model):
# Patch model for Gemma
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.2
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Gemma has tied weights! This means lm_head == embed_tokens
if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
pass
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
# RoPE must be done in float32 for Gemma
# if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
# and (module.cos_cached.dtype != correct_dtype):
# module.cos_cached = module.cos_cached.to(correct_dtype)
# module.sin_cached = module.sin_cached.to(correct_dtype)
# pass
# pass
pass
# Add 1 to weight
# return output * (1 + self.weight)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
# Freeze all parameters except LoRA
# We do this first since += 1 seems to not be liked by requires_grad = True
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
pass
# Patch RMS Layernorm
for name, module in model.named_modules():
if isinstance(module, Gemma2RMSNorm):
# Must be in float32
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
# module = module.to(torch.float32)
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
pass
# Clear deleted GPU items
import gc
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import gc
from typing import Optional, Tuple, List, Union
from ._utils import *
from ._utils import __version__
from torch.nn.functional import scaled_dot_product_attention
from transformers.models.llama.modeling_llama import (
logger,
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ..kernels import *
from ..tokenizer_utils import *
if HAS_FLASH_ATTENTION:
from flash_attn import flash_attn_func
# Final patching code
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaModel,
LlamaForCausalLM,
)
# For Pytorch 2.1.1
try:
from transformers.models.llama.modeling_llama import (
LlamaSdpaAttention,
LlamaFlashAttention2,
)
except:
LlamaSdpaAttention = LlamaAttention
LlamaFlashAttention2 = LlamaAttention
pass
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
from transformers import set_seed as transformers_set_seed
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
from peft import PeftModelForCausalLM
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
from ..save import patch_saving_functions
import re, os, inspect, math, sys
def original_apply_qkv(self, X):
Q = self.q_proj(X)
K = self.k_proj(X)
V = self.v_proj(X)
return Q, K, V
pass
def original_apply_o(self, X):
O = self.o_proj(X)
return O
pass
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
# Fix new HF's inference code
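# During generation with a KV cache only the newly sampled token needs to be fed,
# so input_ids and attention_mask are sliced down to their last column, and
# cache_position (newer HF) is reused as position_ids.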
def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,):
if "past_key_values" in kwargs:
input_ids = input_ids[:,[-1]]
kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]]
if "cache_position" in kwargs:
kwargs["position_ids"] = kwargs["cache_position"]
return { "input_ids" : input_ids, **kwargs, }
pass
def fix_prepare_inputs_for_generation(module):
# Fix prepare_inputs_for_generation
if hasattr(module, "prepare_inputs_for_generation"):
module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
pass
pass
torch_matmul = torch.matmul
def LlamaAttention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
):
"""
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
Fast inference using KV cache.
QK^T can be computed in 4 chunks
[Q, q] @ [K, k].T where q, k are the new tokens.
[QK^T, Qk^T]
[qK^T, qk^T]
Since the attention mask wipes Qk^T, we just get
[QK^T, 0]
[qK^T, qk^T]
Since softmax is row-wise, we get
softmax([QK^T, 0])
softmax([qK^T, qk^T])
We then multiply by [V]
[v]
softmax([QK^T, 0]) [softmax(QK^T)V] *
softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
But notice * [softmax(QK^T)V] is just the last attention.
We just need to compute the last final row.
This means we can pass in a row of Q, but we need to
remember K and V, which are called the KV cache.
"""
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Mistral Nemo 12b has weird dimensions
if attention_size != self.hidden_size:
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
pass
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
elif kv_seq_len >= self.paged_attention.shape[0]:
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
h = self.half_head_dim
RH_Q = self.RH_Q
RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:,:,:,:h] = Kn[:,:,:,h:]
RH_K[:,:,:,h:] = Kn[:,:,:,:h]
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
# Handle sliding windows
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is not None and kv_seq_len > sliding_window:
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
slicing_tokens = 1 - sliding_window
Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
pass
# else:
# Knn, Vnn = Knn, Vnn
# pass
# Attention
if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
else:
A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
pass
torch_nn_functional_silu = torch.nn.functional.silu
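# SwiGLU MLP for single-token decoding: down_proj(silu(gate_proj(x)) * up_proj(x)),
# with the SiLU applied in place and the down projection written into the `up`
# buffer to avoid extra allocations.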
def fast_swiglu_inference(self, X):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
bsz, _, hd = X.shape
# mlp_size = self.config.intermediate_size
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
up = fast_linear_forward(self. up_proj, X)#, out = temp[1])
gate = torch_nn_functional_silu(gate, inplace = True)
gate *= up
# X = self.down_proj(gate)
down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
return down
pass
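# RMSNorm for decoding: x / sqrt(mean(x^2) + eps) * weight, with the reduction done
# in float32 and the result cast back to the input dtype before scaling.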
def fast_rms_layernorm_inference(self, X):
old_dtype = X.dtype
XX = X.to(torch.float32)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
X = XX.to(old_dtype) # Must preserve due to residual
X *= self.weight
return X
pass
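# Gemma variant of RMSNorm: scales by (1 + weight) instead of weight, computed in
# float32; `out_weight` is an optional reusable float32 buffer holding (weight + 1).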
def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
XX = X.to(torch.float32)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
if out_weight is None:
out_weight = self.weight + 1.0
else:
out_weight[:] = self.weight
out_weight += 1.0
pass
XX *= out_weight
return XX.to(X.dtype)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
def LlamaAttention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# Extend RoPE dynamically to fit in VRAM
self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
# Attention module
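# Dispatch: xformers memory-efficient attention when Flash-Attention is absent and no
# attention mask is needed; flash_attn_func when Flash-Attention 2 is installed (also
# mask-free); otherwise PyTorch SDPA with the prepared 4D attention mask.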
if (not HAS_FLASH_ATTENTION and attention_mask is None):
# Xformers memory efficient attention
# Also has Flash Attention v2 dispatching
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Group query attention
if n_groups != 1:
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
if hidden_states.requires_grad:
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
else:
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
pass
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
A = flash_attn_func(Q, K, V, causal = True)
else:
# Grouped query attention
if n_groups != 1:
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
pass
# Must be contiguous or else results are False!
# https://github.com/pytorch/pytorch/issues/112577
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass
attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def LlamaDecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if use_cache and hasattr(self, "_flag_for_generation"):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
def LlamaModel_fast_forward(
self,
input_ids: torch.LongTensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
assert(output_attentions is False)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
# Fix out of bounds tokenization
if hasattr(self, "max_seq_length"):
if seq_length > self.max_seq_length:
logger.warning_once(
f"Unsloth: Input IDs of length {seq_length} > the model's max sequence length of {self.max_seq_length}.\n"\
"We shall truncate it ourselves. It's imperative if you correct this issue first."
)
if input_ids is not None:
input_ids = input_ids[:,:self.max_seq_length]
elif inputs_embeds is not None:
inputs_embeds = inputs_embeds[:,:self.max_seq_length,:]
pass
pass
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
pass
# We already handle KV cache position_ids ourselves.
if False:#(past_key_values_length != 0):
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length,
dtype = torch.int32,
device = "cuda:0",
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
elif position_ids is not None:
position_ids = position_ids.view(-1, seq_length).to(torch.int32)#.long()
else:
position_ids = None
pass
if position_ids is not None:
if position_ids.shape[0] != batch_size:
position_ids = position_ids.repeat((batch_size, 1))
pass
# Embed positions
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
# Normalized from Gemma
IS_GEMMA = self.config.model_type.startswith("gemma")
IS_GEMMA2 = self.config.model_type.startswith("gemma2")
train_embed_tokens = self.embed_tokens.weight.requires_grad
if IS_GEMMA:
# Match Gemma exactly by casting to bfloat16 / float16
# inputs_embeds *= math_sqrt(self.config.hidden_size)
# Ie 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# & 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
normalizer = torch.tensor(math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype)
if train_embed_tokens:
# Careful we must not do an inplace op!
inputs_embeds = inputs_embeds * normalizer
else:
inputs_requires_grad = inputs_embeds.requires_grad
if not inputs_embeds.is_leaf:
inputs_embeds = inputs_embeds.detach()
inputs_requires_grad = True
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
pass
inputs_embeds *= normalizer
# inputs_embeds *= math_sqrt(self.config.hidden_size)
if inputs_requires_grad: inputs_embeds.requires_grad_(True)
pass
pass
# Fix up attention mask by setting elements to 0
# Specifically for DPO
if self._has_no_labels and (attention_mask is not None) and (past_key_values is None) and \
(not train_embed_tokens):
# Careful for inference the attention_mask is size (1, kv_seq_len)
# Whilst the input_embeds is size (1, 1, 4096)
inputs_requires_grad = inputs_embeds.requires_grad
if not inputs_embeds.is_leaf:
inputs_embeds = inputs_embeds.detach()
inputs_requires_grad = True
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
pass
inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
if inputs_requires_grad: inputs_embeds.requires_grad_(True)
pass
# Ignore attention_mask
if attention_mask is None:
padding_mask = None
elif self.training:
attention_mask = None
padding_mask = None
else:
# if 0 in attention_mask:
# padding_mask = attention_mask
# else:
padding_mask = None
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window = getattr(self.config, "sliding_window", None),
)
pass
hidden_states = inputs_embeds
if past_key_values is None and self.training:
use_cache = False
# if use_cache:
# logger.warning_once(
# "Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
# )
# use_cache = False
pass
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# Gradient checkpointing methods (ie sqrt)
if hasattr(self, "_gradient_checkpointing_boundaries"):
boundaries = self._gradient_checkpointing_boundaries
else:
boundaries = None
pass
# Check checkpointing method
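# Prefer Unsloth's offloaded gradient checkpointer when it has been enabled on the
# model; otherwise fall back to torch.utils.checkpoint with use_reentrant = True.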
gradient_checkpointing = False
offloaded_gradient_checkpointing = False
if (self.gradient_checkpointing and self.training and not use_cache):
gradient_checkpointing = True
if output_attentions is False and hasattr(self, "_offloaded_gradient_checkpointing"):
offloaded_gradient_checkpointing = True
pass
# Check for Flex Attention
# if IS_GEMMA2 and HAS_FLEX_ATTENTION:
# if not (seq_length % FLEX_ATTENTION_PADDING == 0):
# USE_FLEX_ATTENTION = True
# Gemma2 has alternating SWA and global attn
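# Two full causal masks are prebuilt once and cached on the model: SWA_mask with the
# sliding window and GA_mask without it. Even-indexed layers use the sliding-window
# mask, odd-indexed layers use the global mask (see the per-layer loop below).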
if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
n = self.config.max_position_embeddings
# masked_fill is making stuff slower!
# self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
# self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
self.SWA_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = self.config.sliding_window,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
self.GA_mask = AttentionMaskConverter(
is_causal = True,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
pass
# Go through every layer!
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states: all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
mask = causal_mask
if IS_GEMMA2: mask = self.SWA_mask if (idx % 2 == 0) else self.GA_mask
if offloaded_gradient_checkpointing:
hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
hidden_states,
mask,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)[0]
elif gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask)
return custom_forward
pass
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
mask,
attention_mask,
position_ids,
use_reentrant = True,
preserve_rng_state = False,
)
hidden_states = layer_outputs[0]
else:
layer_outputs = decoder_layer(
hidden_states,
causal_mask=mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = layer_outputs[0]
pass
if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions: all_self_attns += (layer_outputs[1],)
pass
# Final layernorm
if use_cache:
hidden_states = (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
(self.norm, hidden_states)
else:
hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
pass
if output_hidden_states: all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
def LlamaModel_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = getattr(self.config, "sliding_window", None),
)
else:
attention_mask = None
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
def CausalLM_fast_forward(fast_forward_inference):
def _CausalLM_fast_forward(
self,
input_ids: torch.LongTensor = None,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if past_key_values is not None:
outputs = fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids = position_ids,
attention_mask = attention_mask,
)
else:
causal_mask = xformers.attn_bias.LowerTriangularMask()
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pass
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
lm_head = self.lm_head.weight
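# Single-token decode (bsz == 1, q_len == 1): logits for the last position are a
# single matrix-vector product against the lm_head weight instead of a full matmul.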
if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
logits = logits.unsqueeze(0).unsqueeze(0)
else:
logits = self.lm_head(hidden_states.to(lm_head.dtype))
pass
logits = logits.to(self.config.torch_dtype)
loss = None
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
if labels is not None:
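# Causal LM loss: position i predicts labels[i + 1], so labels are shifted left by
# one and right-padded with -100 (the ignored index) rather than slicing the logits.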
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
pass
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
logit_softcapping = logit_softcapping,
)
elif logit_softcapping != 0:
if logits.requires_grad:
logits = (1.0 / logit_softcapping) * logits
logits = torch.tanh(logits)
logits = logit_softcapping * logits
else:
logits *= (1.0 / logit_softcapping)
torch.tanh(logits, out = logits)
logits *= logit_softcapping
pass
pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
pass
return _CausalLM_fast_forward
pass
def PeftModelForCausalLM_fast_forward(
self,
input_ids=None,
causal_mask=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
task_ids=None,
**kwargs,
):
return self.base_model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
pass
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA-graphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
class LlamaRotaryEmbedding(torch.nn.Module):
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
# [TODO] Hack to pass in config - need to remove later
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
max_position_embeddings = config.max_position_embeddings
pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE: first cap the cache at 4 * 8192 tokens, then grow it iteratively as needed
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype = x.dtype),
self.sin_cached[:seq_len].to(dtype = x.dtype),
)
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = math.ceil(seq_len / 8192) * 8192 # Round up so the cache always covers seq_len
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.current_rope_size = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
t = t / self.scaling_factor
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
pass
pass
# See https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py#L736
# For Llama 3.1
class LlamaExtendedRotaryEmbedding(torch.nn.Module):
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
# [TODO] Hack to pass in config - need to remove later
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
max_position_embeddings = config.max_position_embeddings
pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE: first cap the cache at 4 * 8192 tokens, then grow it iteratively as needed
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
# Normal Llama-3 RoPE
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
inv_freq = self.apply_scaling(inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent = False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
pass
# From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41
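# Llama 3.1 long-context RoPE scaling: components with wavelengths longer than the
# original 8192-token context are slowed down by scale_factor = 8, components with
# wavelengths shorter than 8192 / high_freq_factor are kept as-is, and the band in
# between is smoothly interpolated between the two.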
def apply_scaling(self, freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype = x.dtype),
self.sin_cached[:seq_len].to(dtype = x.dtype),
)
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = math.ceil(seq_len / 8192) * 8192 # Round up so the cache always covers seq_len
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
def _wrap_fast_inference(generate, device_type, dtype, model):
# Wraps inference with bfloat16 / float16
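# The wrapper sets `_flag_for_generation` on the model and every nested `.model` so
# decoder layers take their fast single-token paths, runs generate() under
# torch.inference_mode and autocast, then removes the flags again afterwards.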
@torch.inference_mode
def _fast_generate(*args, **kwargs):
# Set a flag for generation!
internal_model = model
while hasattr(internal_model, "model"):
internal_model._flag_for_generation = True
internal_model = internal_model.model
pass
internal_model._flag_for_generation = True
# For newer HF
kwargs["cache_implementation"] = "dynamic"
# Remove token_type_ids
kwargs.pop("token_type_ids", None)
# Check pad_token
model_eos_token_id = getattr(model.config, "eos_token_id", None)
if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"):
model_eos_token_id = model_eos_token_id[0]
kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
# Set pad token
# old_pad_token_id = getattr(model.config, "pad_token_id", None)
# old_eos_token_id = getattr(model.config, "eos_token_id", None)
# model.config.pad_token_id = old_eos_token_id
# Autocasted
with torch.autocast(device_type = device_type, dtype = dtype):
output = generate(*args, **kwargs)
pass
# Revert
# model.config.pad_token_id = old_pad_token_id
# Unset a flag for generation!
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
internal_model = internal_model.model
pass
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
return output
pass
return _fast_generate
pass
class FastLlamaModel:
@staticmethod
def pre_patch():
init_name, function = patch_llama_rope_scaling(
model_name = "llama",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
extended_rope_module = LlamaExtendedRotaryEmbedding,
attention_module = LlamaAttention,
)
if init_name is not None:
exec(function, globals())
LlamaAttention.__init__ = eval(init_name)
pass
LlamaAttention .forward = LlamaAttention_fast_forward
LlamaSdpaAttention .forward = LlamaAttention_fast_forward
LlamaFlashAttention2.forward = LlamaAttention_fast_forward
LlamaDecoderLayer .forward = LlamaDecoderLayer_fast_forward
LlamaModel .forward = LlamaModel_fast_forward
LlamaForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(LlamaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA-graphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding
return
pass
@staticmethod
def from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
if trust_remote_code:
print(
"Unsloth: WARNING `trust_remote_code` is True.\n"\
"Are you certain you want to do remote code execution?"
)
pass
if token is None and "HF_TOKEN" in os.environ:
token = os.environ["HF_TOKEN"]
if token is None and "HUGGINGFACE_TOKEN" in os.environ:
token = os.environ["HUGGINGFACE_TOKEN"]
if model_patcher is None: model_patcher = FastLlamaModel
SUPPORTS_BFLOAT16 = is_bfloat16_supported()
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
statistics = \
f"==((====))== Unsloth: Fast {model_patcher.__name__[4:-5]} patching release {__version__}\n"\
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth'
print(statistics)
# Warn about fast transfers
old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0")
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1":
print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
pass
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
model_patcher.pre_patch()
get_statistics() # For debugging - we use a download counter to see if environments are not breaking
if dtype is None:
dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
logger.warning_once("Device does not support bfloat16. Will change to float16.")
dtype = torch.float16
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
# RoPE Scaling
model_config = AutoConfig.from_pretrained(model_name, token = token)
model_max_seq_length = model_config.max_position_embeddings
# Check if RoPE Scaling is even allowed
model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]
has_rope_scaling = False
try:
with open(inspect.getfile(model_function), "r") as file:
has_rope_scaling = "self.config.rope_scaling" in file.read()
except: pass
has_rope_scaling = True
# If max_seq_length is not specified, use the maximum from the config
if max_seq_length is None:
max_seq_length = model_max_seq_length
pass
if (rope_scaling is None) and (max_seq_length > model_max_seq_length):
rope_scaling = max_seq_length / model_max_seq_length
logger.warning_once(
f"Unsloth: {model_name} can only handle sequence lengths of at most "\
f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\
f"{round(rope_scaling, 3)}, it can be magically be extended to "\
f"{max_seq_length}!"
)
# Warn RoPE scaling isn't allowed
if not has_rope_scaling:
raise RuntimeError(
"However, {model_name} doesn't support RoPE Scaling!\n"\
"Please file a feature request at https://github.com/unslothai/unsloth."
)
pass
rope_scaling = {"type": "linear", "factor": rope_scaling,}
# Add to kwargs
kwargs["rope_scaling"] = rope_scaling
pass
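# Illustrative sketch (not part of the library) of how the linear RoPE scaling factor
# above is derived. Assuming a model whose config reports max_position_embeddings = 4096
# and a user request of max_seq_length = 8192:
#     rope_scaling = 8192 / 4096                               # -> 2.0
#     kwargs["rope_scaling"] = {"type": "linear", "factor": 2.0}
# so positions are compressed by that factor before the rotary embedding is applied.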
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
pre_check = check_nvidia()
bnb_config = None
if load_in_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit = True,
bnb_4bit_use_double_quant = True,
bnb_4bit_quant_type = "nf4",
bnb_4bit_compute_dtype = dtype,
)
pass
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
# RoPE Scaling's max_position_embeddings must be updated
max_position_embeddings = max(max_seq_length, model_max_seq_length)
kwargs.pop("attn_implementation", None); # No need since we auto call it
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map = device_map,
torch_dtype = dtype,
quantization_config = bnb_config,
token = token,
max_position_embeddings = max_position_embeddings,
trust_remote_code = trust_remote_code,
attn_implementation = "eager",
**kwargs,
)
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
post_check = check_nvidia()
# Counteract saved tokenizers
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
tokenizer = load_correct_tokenizer(
tokenizer_name = tokenizer_name,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
trust_remote_code = trust_remote_code,
)
model, tokenizer = patch_tokenizer(model, tokenizer)
model = model_patcher.post_patch(model)
# Patch up QKV / O and MLP
for idx, layer in enumerate(model.model.layers):
layer.self_attn.apply_qkv = original_apply_qkv
layer.self_attn.apply_o = original_apply_o
pass
# Patch Trainer
from transformers.trainer import Trainer
try:
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
inner_training_loop = inspect.getsource(Trainer._inner_training_loop)
Trainer._original_training_loop = inner_training_loop
else:
inner_training_loop = Trainer._original_training_loop
except:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
pass
if ((post_check - pre_check) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
import transformers.trainer
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
# TODO: Support Deepspeed
if item.startswith(("deepspeed", "xm", "met", "smp")): continue
if item in inner_training_loop: good_items.append(item)
pass
exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
start = re.search(r'logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0]
end = inner_training_loop.find("\n\n", start)
original_debug = inner_training_loop[start:end]
spaces = re.search(r'\n([\s\t]{1,})', original_debug).group(0)[1:]
front_spaces = re.match(r'([\s\t]{1,})', inner_training_loop).group(0)
debug_info = """debug_info = \\
f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\
f" \\\\\\ /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\
f"O^O/ \\_/ \\ Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\
f"\\ / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\
f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}'
logger.warning(debug_info)
import subprocess, re, gc, numpy as np
a = np.array([0,])
try:
a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)
a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)
a = np.array([int(x.decode('utf-8'))/1024 for x in a])
except:
if not torch.cuda.is_available():
raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')
if ((a - PRE_CHECK) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()"""
debug_info = debug_info.split('\n')
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace(original_debug, debug_info)
debug_info = """n_total_devices = total_train_batch_size // \\
args.gradient_accumulation_steps // self._train_batch_size
if n_total_devices > 1:
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
debug_info ="""
debug_info = debug_info.split('\n')
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace("debug_info =", debug_info, 1)
front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0)
inner_training_loop = re.sub(r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE)
inner_training_loop = inner_training_loop.replace(
"train_dataloader = tpu_spmd_dataloader(train_dataloader)",
"raise RuntimeError('Unsloth: TPUs are not yet supported!')"
)
inner_training_loop = inner_training_loop.replace(
"self.accelerator.free_memory()",
"self.accelerator.free_memory()\n" + \
front_spaces + "if self.is_deepspeed_enabled:"\
"raise RuntimeError('Unsloth: Deepspeed is not yet supported!')\n", 1,
)
check_batches = """train_dataloader = self.get_train_dataloader()
ga = args.gradient_accumulation_steps
bsz = self._train_batch_size
total_batches = bsz * ga * args.world_size
n_total_devices = total_batches // ga // bsz
if n_total_devices > 1:
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
divisor = n_total_devices / 1
bsz = self._train_batch_size = max(int(bsz / divisor), 1)
if total_batches // ga // bsz > 1:
divisor = n_total_devices / 1
ga = args.gradient_accumulation_steps = max(int(ga / divisor), 1)"""
check_batches = check_batches.split('\n')
check_batches = "\n".join([check_batches[0]] + [front_spaces + x[8:] for x in check_batches[1:]])
inner_training_loop = inner_training_loop.replace(
"train_dataloader = self.get_train_dataloader()",
check_batches, 1,
)
inner_training_loop = inner_training_loop.replace(
"_inner_training_loop",
"_fast_inner_training_loop", 1,
)
exec(inner_training_loop, globals())
Trainer._inner_training_loop = _fast_inner_training_loop
inner_training_loop = inner_training_loop.replace(
"is_torch_tpu_available()",
"False",
)
if "n_total_devices >" not in inner_training_loop:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
pass
inner_training_loop = inner_training_loop.replace(
"is_sagemaker_mp_enabled()",
"False",
)
exec(inner_training_loop, globals())
Trainer._inner_training_loop = _fast_inner_training_loop
# Save max_seq_length
model.max_seq_length = max_position_embeddings
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_position_embeddings
internal_model = internal_model.model
pass
internal_model.max_seq_length = max_position_embeddings
# We check the tokenizer first for errors
if fix_tokenizer:
tokenizer = check_tokenizer(
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
)
pass
patch_saving_functions(tokenizer)
# Fix up config for transformers uploading PEFT
# Not necessary anymore since we require transformers>=4.37!
if False:
name = model.config._name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
name = name[:len(name) - len("-bnb-4bit")]
model.config.update({"_name_or_path" : name})
pass
pass
# Log Unsloth version for future fastpaths for inference
model.config.update({"unsloth_version" : __version__})
# Add save modules
patch_saving_functions(model)
Trainer._inner_training_loop = _fast_inner_training_loop
# Save tokenizer for inference purposes
tokenizer.padding_side = "left" # Force inference
internal_model = model
while hasattr(internal_model, "model"):
internal_model._saved_temp_tokenizer = tokenizer
internal_model = internal_model.model
pass
internal_model._saved_temp_tokenizer = tokenizer
return model, tokenizer
pass
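# Example usage (illustrative sketch; the argument values are assumptions):
#     model, tokenizer = FastLlamaModel.from_pretrained(
#         model_name     = "unsloth/llama-3-8b-bnb-4bit",
#         max_seq_length = 2048,
#         load_in_4bit   = True,
#     )
# dtype is left as None so the code above picks bfloat16 on GPUs that support it
# and float16 otherwise.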
@staticmethod
def post_patch(model):
# Patch model
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.
model.set_input_embeddings(torch.nn.Embedding.from_pretrained(model.get_input_embeddings().weight))
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.get_output_embeddings().weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
and (module.cos_cached.dtype != correct_dtype):
module.cos_cached = module.cos_cached.to(correct_dtype)
module.sin_cached = module.sin_cached.to(correct_dtype)
pass
pass
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
@staticmethod
def get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
layers_to_transform = None,
layers_pattern = None,
use_gradient_checkpointing = True,
random_state = 3407,
max_seq_length = 2048, # not used anymore
use_rslora = False,
modules_to_save = None,
init_lora_weights = True,
loftq_config = {},
temporary_location = "_unsloth_temporary_saved_buffers",
**kwargs,
):
transformers_set_seed(random_state)
if isinstance(model, PeftModelForCausalLM):
# Check if exactly the same and then pass through!
assert(hasattr(model, "peft_config"))
peft_config = model.peft_config["default"].to_dict()
check_parameters = [
"r", "lora_alpha", "lora_dropout",
"bias", "layers_to_transform", "layers_pattern",
"use_rslora", "init_lora_weights",
]
check_all = True
for param in check_parameters:
check_all = check_all and (peft_config[param] == eval(param))
pass
# Check save_modules
old_target_modules = list(peft_config["target_modules"])
modules_to_save = peft_config["modules_to_save"]
if modules_to_save is None: modules_to_save = {}
modules_to_save = list(modules_to_save)
old_target_modules += modules_to_save
# Combine all
new_target_modules = list(target_modules) + \
list(modules_to_save if modules_to_save is not None else [])
# Now check!
new_target_modules = set(new_target_modules)
check_all = check_all and (
len(set(old_target_modules) ^ new_target_modules) == 0
)
check_all = check_all and (
(loftq_config == {} or loftq_config is None) and \
(peft_config["loftq_config"] == {} or peft_config["loftq_config"] is None)
)
if check_all:
# Simply pass through!
logger.warning(
"Unsloth: Already have LoRA adapters! We shall skip this step."
)
# Offload!
# [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!)
if "embed_tokens" in new_target_modules:
print("Unsloth: Casting embed_tokens to float32")
model.model.model.embed_tokens.modules_to_save.default\
.to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
# [TODO] Move old embed_tokens to CPU - should be disk!
model.model.model.embed_tokens.original_module\
.to(device = "cpu", non_blocking = True)
model.model.model.embed_tokens.original_module.requires_grad_(False)
pass
if "lm_head" in new_target_modules:
print("Unsloth: Casting lm_head to float32")
model.model.lm_head.modules_to_save.default\
.to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)
# [TODO] Move old lm_head to CPU - should be disk!
model.model.lm_head.original_module\
.to(device = "cpu", non_blocking = True)
model.model.lm_head.original_module.requires_grad_(False)
pass
return model
else:
raise TypeError(
"Unsloth: Your model already has LoRA adapters. Your new parameters are different."
)
pass
pass
if loftq_config is None: loftq_config = {}
signature = str(inspect.signature(LoraConfig))
SUPPORTS_LOFTQ = "loftq_config" in signature
SUPPORTS_RSLORA = "use_rslora" in signature
assert(max_seq_length <= model.max_seq_length)
if lora_dropout != 0:
logger.warning_once(
f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"\
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
pass
if bias != "none":
logger.warning_once(
f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"\
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
pass
if not (type(init_lora_weights) is bool or \
init_lora_weights == "gaussian" or init_lora_weights == "loftq"):
raise ValueError(
'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq"].'
)
pass
if init_lora_weights == "loftq":
if not SUPPORTS_LOFTQ:
import peft
raise RuntimeError(
f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"\
"Please install PEFT 0.7.2 or higher.\n"\
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
)
pass
if loftq_config == {}:
from peft import LoftQConfig
logger.warning_once(
f"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
f"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
)
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
pass
if hasattr(model.config, "quantization_config"):
raise ValueError(
"Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\
"Reload your model without any quantization by setting `load_in_4bit = False`."
)
pass
pass
assert(type(use_rslora) is bool)
if use_rslora:
if not SUPPORTS_RSLORA:
# We manually check for PEFT
import peft
raise RuntimeError(
f"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\n"\
"Please install PEFT 0.7.2 or higher.\n"\
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
)
pass
pass
accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",),)
model.config.update({"unsloth_version" : __version__})
if type(modules_to_save) is tuple:
modules_to_save = list(modules_to_save)
pass
train_lm_head = False
train_embed_tokens = False
final_modules = []
for module in target_modules:
if module == "lm_head":
# logger.warning_once(
# "Unsloth: `lm_head` should be placed in `modules_to_save` and not `target_modules`. "\
# "Luckily, we shall do it for you!"
# )
train_lm_head = True
if modules_to_save is None: modules_to_save = ["lm_head"]
else: modules_to_save.append("lm_head")
elif module == "embed_tokens":
# logger.warning_once(
# "Unsloth: `embed_tokens` should be placed in `modules_to_save` and not `target_modules`. "\
# "Luckily, we shall do it for you!"
# )
train_embed_tokens = True
if modules_to_save is None: modules_to_save = ["embed_tokens"]
else: modules_to_save.append("embed_tokens")
else:
assert(module in accepted_modules)
final_modules.append(module)
pass
# Check if we added new tokens!
if hasattr(model, "_need_to_train_embeddings"):
if not train_lm_head or not train_embed_tokens:
print(
"Unsloth: You added new tokens but did not specify if you wanted to "\
"train the lm_head and embed_tokens.\nWe must turn it on for you."
)
train_lm_head = True
train_embed_tokens = True
if modules_to_save is None: modules_to_save = ["embed_tokens"]
else: modules_to_save.append("embed_tokens")
if modules_to_save is None: modules_to_save = ["lm_head"]
else: modules_to_save.append("lm_head")
pass
pass
# Check for Llama-3
# if hasattr(model._saved_temp_tokenizer, "_using_llama3_template"):
# if not train_embed_tokens and not train_lm_head:
# raise RuntimeError("")
# First fix untrained tokens
# Wrong - can cause reserved tokens to pop out!!
# if train_embed_tokens or train_lm_head:
# fix_untrained_tokens(model, eps = 1e-16)
# pass
# Check modules_to_save
if modules_to_save is not None:
for module in modules_to_save:
if module == "lm_head":
train_lm_head = True
elif module == "embed_tokens":
train_embed_tokens = True
else:
raise TypeError(
f"Unsloth: Module = {module} is not allowed. Only 'lm_head' and 'embed_tokens' is allowed."
)
pass
pass
if isinstance(modules_to_save, (tuple, list)):
modules_to_save = list(set(modules_to_save))
pass
# Get LoRA
arguments = dict(
r = r,
lora_alpha = lora_alpha,
target_modules = final_modules,
lora_dropout = lora_dropout,
bias = bias,
task_type = TaskType.CAUSAL_LM,
layers_to_transform = layers_to_transform,
init_lora_weights = init_lora_weights,
loftq_config = loftq_config,
use_rslora = use_rslora,
modules_to_save = modules_to_save,
**kwargs,
)
if not SUPPORTS_LOFTQ: del arguments["loftq_config"]
if not SUPPORTS_RSLORA: del arguments["use_rslora"]
_saved_temp_tokenizer = model._saved_temp_tokenizer
lora_config = LoraConfig(**arguments)
# First offload lm_head and embed_tokens to disk
input_embeddings_device = model. get_input_embeddings().weight.device
output_embeddings_device = model.get_output_embeddings().weight.device
if use_gradient_checkpointing == "unsloth":
if train_embed_tokens:
print("Unsloth: Offloading input_embeddings to disk to save VRAM")
offload_input_embeddings(model, temporary_location)
pass
# Remove old items to save VRAM
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
if train_lm_head:
print("Unsloth: Offloading output_embeddings to disk to save VRAM")
offload_output_embeddings(model, temporary_location)
pass
# Remove old items to save VRAM
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
pass
model = _get_peft_model(model, lora_config)
model._saved_temp_tokenizer = _saved_temp_tokenizer
model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)
# Now patch lm_head and embed_tokens
if train_embed_tokens:
print("Unsloth: Casting embed_tokens to float32")
assert(hasattr(model.model.model.embed_tokens, "modules_to_save"))
model.model.model.embed_tokens.modules_to_save.default\
.to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
pass
if train_lm_head:
print("Unsloth: Casting lm_head to float32")
assert(hasattr(model.model.lm_head, "modules_to_save"))
model.model.lm_head.modules_to_save.default\
.to(device = "cuda:0", dtype = torch.float32, non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)
pass
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
return model
pass
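# Example usage (illustrative sketch; hyperparameter values are assumptions):
#     model = FastLlamaModel.get_peft_model(
#         model,
#         r = 16,
#         lora_alpha = 16,
#         lora_dropout = 0,      # 0 keeps the fast autograd patches enabled
#         bias = "none",         # "none" keeps the fast autograd patches enabled
#         target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
#                           "gate_proj", "up_proj", "down_proj"],
#         use_gradient_checkpointing = "unsloth",
#     )
# Passing "lm_head" or "embed_tokens" in target_modules moves them into
# modules_to_save and trains them in float32, as handled above.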
@staticmethod
def patch_peft_model(
model,
use_gradient_checkpointing = True,
):
if not isinstance(model, PeftModelForCausalLM):
raise TypeError(
"Unsloth: Your model needs to call `.get_peft_model` first!"
)
pass
# Get activation function
model_type = model.config.model_type
if model_type == "llama": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx
elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx
else:
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
pass
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing = use_gradient_checkpointing,
use_reentrant = True,
)
# Fix up config for transformers uploading PEFT
for active_adapter in model.peft_config.keys():
# Not necessary since we require transformers >= 4.37
if False:
name = model.peft_config[active_adapter].base_model_name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
name = name[:len(name) - len("-bnb-4bit")]
model.peft_config[active_adapter].base_model_name_or_path = name
pass
# Add revision to enable future fast inference paths
model.peft_config[active_adapter].revision = "unsloth"
pass
from transformers.trainer import Trainer
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
raise RuntimeError(
'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\
'enabling it will require much more work, so we have to prioritize. Please understand!\n'\
'We do have a separate beta version, which you can contact us about!\n'\
'Thank you for your understanding and we appreciate it immensely!'
)
pass
# Fix loftq issues
# loftq_config must not = None, but rather {}
all_configs = model.peft_config
for key, current_config in all_configs.items():
if hasattr(current_config, "loftq_config") and current_config.loftq_config is None:
new_args = current_config.__dict__
new_args["loftq_config"] = {}
current_config = current_config.__class__(**new_args)
all_configs[key] = current_config
pass
pass
# Do patching
n_mlp = 0
n_qkv = 0
n_o = 0
import types
active_adapter = model.active_adapters[0] if \
hasattr(model, "active_adapters") else model.active_adapter
# Get dropout and bias
lora_dropout = model.peft_config[active_adapter].lora_dropout
bias = model.peft_config[active_adapter].bias
if lora_dropout == 0 and bias == "none":
for idx, layer in enumerate(model.model.model.layers):
# MLP patching
gate_proj = layer.mlp.gate_proj
up_proj = layer.mlp. up_proj
down_proj = layer.mlp.down_proj
if hasattr(gate_proj, "lora_A") and \
hasattr( up_proj, "lora_A") and \
hasattr(down_proj, "lora_A") and \
(getattr(gate_proj, "base_layer", gate_proj).bias is None) and \
(getattr( up_proj, "base_layer", up_proj).bias is None) and \
(getattr(down_proj, "base_layer", down_proj).bias is None) and \
(len(getattr(gate_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr( up_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0):
# https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
n_mlp += 1
else:
logger.warning_once(
"Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
pass
# QKV attention patching
q_proj = layer.self_attn.q_proj
k_proj = layer.self_attn.k_proj
v_proj = layer.self_attn.v_proj
if hasattr(q_proj, "lora_A") and \
hasattr(k_proj, "lora_A") and \
hasattr(v_proj, "lora_A") and \
(getattr(q_proj, "base_layer", q_proj).bias is None) and \
(getattr(k_proj, "base_layer", k_proj).bias is None) and \
(getattr(v_proj, "base_layer", v_proj).bias is None) and \
(len(getattr(q_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0):
layer.self_attn.apply_qkv = apply_lora_qkv
n_qkv += 1
else:
if model_type != "qwen2":
logger.warning_once(
"Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
pass
pass
# O attention patching
o_proj = layer.self_attn.o_proj
if hasattr(o_proj, "lora_A") and \
(getattr(o_proj, "base_layer", o_proj).bias is None) and \
(len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0):
layer.self_attn.apply_o = apply_lora_o
n_o += 1
else:
logger.warning_once(
"Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
pass
pass
pass
logger.warning_once(
f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "\
f"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.",
)
patch_saving_functions(model)
# Patch cross entropy loss labels
# Fixes https://github.com/unslothai/unsloth/issues/10
max_seq_length = model.max_seq_length
extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
model.model.extra_ignored_labels = extra_ignored_labels
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_seq_length
internal_model = internal_model.model
pass
internal_model.max_seq_length = max_seq_length
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
return model
pass
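# Minimal sketch of the patching mechanism above (toy names, not the real modules):
# types.MethodType binds a plain function as a bound method of one specific instance,
# which is how layer.mlp.forward is swapped per decoder layer.
#     import types
#     class MLP:
#         def forward(self, x): return x           # original behaviour
#     def fast_forward(self, x): return x * 2      # stand-in for apply_lora_mlp
#     mlp = MLP()
#     mlp.forward = types.MethodType(fast_forward, mlp)   # only this instance changes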
@staticmethod
def for_inference(model):
# if model.config.model_type == "qwen2":
# FastLlamaModel.for_training(model)
# return
# pass
internal_model = model
internal_model.gradient_checkpointing = False
internal_model.training = False
while hasattr(internal_model, "model"):
internal_model = internal_model.model
internal_model.gradient_checkpointing = False
internal_model.training = False
pass
# Also check if lm_head / embeddings are trained
internal_model = model
while not hasattr(internal_model, "lm_head"):
internal_model = internal_model.model
pass
lm_head = internal_model.lm_head.weight
device_type = lm_head.device.type
dtype = model.config.torch_dtype
if type(dtype) is str:
if dtype == "float16": dtype = torch.float16
elif dtype == "bfloat16": dtype = torch.bfloat16
pass
# Wrap model.generate
if model.generate.__name__ != "_fast_generate":
model._unwrapped_old_generate = model.generate
model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)
pass
# Patch tokenizer to pad to the left
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "left"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "left"
pass
pass
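# Example usage (illustrative sketch; `inputs` is assumed to be a tokenized batch):
#     FastLlamaModel.for_inference(model)        # wraps model.generate with the fast path
#     outputs = model.generate(**inputs, max_new_tokens = 64)
#     FastLlamaModel.for_training(model)         # restores gradient checkpointing and generate
# padding_side is forced to "left" above, which is what generation with padded
# batches expects.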
@staticmethod
def for_training(model, use_gradient_checkpointing = True):
internal_model = model
internal_model.gradient_checkpointing = use_gradient_checkpointing
internal_model.training = True
# Delete all fast inference loras
for param in model.parameters():
if hasattr(param, "_fast_lora"):
del param._fast_lora
pass
while hasattr(internal_model, "model"):
internal_model = internal_model.model
internal_model.gradient_checkpointing = use_gradient_checkpointing
internal_model.training = True
pass
# Also revert model.generate
if hasattr(model, "_unwrapped_old_generate"):
model.generate = model._unwrapped_old_generate
del model._unwrapped_old_generate
pass
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER
import os
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from packaging.version import Version
transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
SUPPORTS_GEMMA = transformers_version >= Version("4.38")
SUPPORTS_GEMMA2 = transformers_version >= Version("4.42")
SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.1")
if SUPPORTS_GEMMA:
from .gemma import FastGemmaModel
if SUPPORTS_GEMMA2:
from .gemma2 import FastGemma2Model
pass
def _get_model_name(model_name, load_in_4bit = True):
if not SUPPORTS_FOURBIT and model_name in INT_TO_FLOAT_MAPPER:
model_name = INT_TO_FLOAT_MAPPER[model_name.lower()]
logger.warning_once(
f"Unsloth: Your transformers version of {transformers_version} does not support native "\
f"4bit loading.\nThe minimum required version is 4.37.\n"\
f'Try `pip install --upgrade "transformers>=4.37"`\n'\
f"to obtain the latest transformers build, then restart this session.\n"\
f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)."
)
elif not load_in_4bit and model_name in INT_TO_FLOAT_MAPPER:
new_model_name = INT_TO_FLOAT_MAPPER[model_name.lower()]
# logger.warning_once(
# f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\
# f"`load_in_4bit = False`. We shall load `{new_model_name}` instead."
# )
model_name = new_model_name
elif load_in_4bit and SUPPORTS_FOURBIT and model_name in FLOAT_TO_INT_MAPPER:
new_model_name = FLOAT_TO_INT_MAPPER[model_name.lower()]
# logger.warning_once(
# f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\
# f"We shall load `{new_model_name}` for 4x faster loading."
# )
model_name = new_model_name
pass
return model_name
pass
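# Illustrative example of the remapping above (assuming both mapper entries exist):
# with load_in_4bit = True and a recent transformers,
#     _get_model_name("meta-llama/Meta-Llama-3-8B")
# resolves through FLOAT_TO_INT_MAPPER to "unsloth/llama-3-8b-bnb-4bit", while with
# load_in_4bit = False a 4bit repo name is mapped back to its 16bit counterpart via
# INT_TO_FLOAT_MAPPER.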
class FastLanguageModel(FastLlamaModel):
@staticmethod
def from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
trust_remote_code = False,
use_gradient_checkpointing = "unsloth",
resize_model_vocab = None,
revision = None,
*args, **kwargs,
):
if token is None and "HF_TOKEN" in os.environ:
token = os.environ["HF_TOKEN"]
if token is None and "HUGGINGFACE_TOKEN" in os.environ:
token = os.environ["HUGGINGFACE_TOKEN"]
old_model_name = model_name
model_name = _get_model_name(model_name, load_in_4bit)
# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
disable_progress_bars()
try:
model_config = AutoConfig.from_pretrained(model_name, token = token, revision = revision)
is_model = True
except:
is_model = False
try:
peft_config = PeftConfig .from_pretrained(model_name, token = token, revision = revision)
is_peft = True
except:
is_peft = False
# Cannot be both!
if is_model and is_peft:
raise RuntimeError(
"Unsloth: Your repo has a LoRA adapter and a base model.\n"\
"You have 2 files `config.json` and `adapter_config.json`.\n"\
"We must only allow one config file.\n"\
"Please separate the LoRA and base models to 2 repos."
)
elif not is_model and not is_peft:
raise RuntimeError(
f"Unsloth: `{model_name}` is not a base model or a PEFT model.\n"\
"We could not locate a `config.json` or `adapter_config.json` file.\n"\
"Are you certain the model name is correct? Does it actually exist?"
)
pass
# Get base model for PEFT:
if is_peft:
# Check base model again for PEFT
model_name = _get_model_name(peft_config.base_model_name_or_path, load_in_4bit)
model_config = AutoConfig.from_pretrained(model_name, token = token, revision = revision)
pass
if not was_disabled: enable_progress_bars()
model_type = model_config.model_type
if model_type == "llama":
scaling_type = None
if getattr(model_config, "rope_scaling", None) is not None:
scaling_type1 = model_config.rope_scaling.get("type", None)
scaling_type2 = model_config.rope_scaling.get("rope_type", None)
scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
pass
if scaling_type == "llama3" and not SUPPORTS_LLAMA31:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Llama 3.1.\n"\
f"The minimum required version is 4.43.1\n"\
f'Try `pip install --upgrade "transformers>=4.43.1"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastLlamaModel
elif model_type == "mistral": dispatch_model = FastMistralModel
elif model_type == "gemma":
if not SUPPORTS_GEMMA:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
f"The minimum required version is 4.38.\n"\
f'Try `pip install --upgrade "transformers>=4.38"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemmaModel
elif model_type == "gemma2":
if not SUPPORTS_GEMMA2:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
f"The minimum required version is 4.42.3.\n"\
f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemma2Model
elif model_type == "qwen2":
dispatch_model = FastQwen2Model
else:
raise NotImplementedError(
f"Unsloth: {model_name} not supported yet!\n"\
"Make an issue to https://github.com/unslothai/unsloth!",
)
pass
# Check if this is local model since the tokenizer gets overwritten
if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \
os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \
os.path.exists(os.path.join(old_model_name, "special_tokens_map.json")):
tokenizer_name = old_model_name
else:
tokenizer_name = None
pass
model, tokenizer = dispatch_model.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = dispatch_model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
revision = revision if not is_peft else None,
*args, **kwargs,
)
if resize_model_vocab is not None:
model.resize_token_embeddings(resize_model_vocab)
pass
# In case the model supports tagging, add the unsloth tag.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["unsloth",])
pass
if hasattr(tokenizer, "add_model_tags"):
tokenizer.add_model_tags(["unsloth",])
pass
if load_in_4bit:
# Fix up bitsandbytes config
quantization_config = \
{
# Sometimes torch_dtype is not a string!!
"bnb_4bit_compute_dtype" : model.config.to_dict()["torch_dtype"],
"bnb_4bit_quant_type" : "nf4",
"bnb_4bit_use_double_quant" : True,
"llm_int8_enable_fp32_cpu_offload" : False,
"llm_int8_has_fp16_weight" : False,
"llm_int8_skip_modules" : None,
"llm_int8_threshold" : 6.0,
"load_in_4bit" : True,
"load_in_8bit" : False,
"quant_method" : "bitsandbytes",
}
model.config.update({"quantization_config" : quantization_config})
pass
if is_peft:
# From https://github.com/huggingface/peft/issues/184
# Now add PEFT adapters
model.enable_input_require_grads()
model = PeftModel.from_pretrained(
model,
old_model_name,
token = token,
revision = revision,
is_trainable = True,
)
# Patch it as well!
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
pass
return model, tokenizer
pass
pass
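# End-to-end sketch of the dispatcher above (illustrative; the repo name is an assumption):
# FastLanguageModel.from_pretrained reads the config to pick a dispatch class
# (llama -> FastLlamaModel, mistral -> FastMistralModel, gemma -> FastGemmaModel, ...).
# If the repo only holds an adapter_config.json, the base model is loaded first and the
# LoRA adapter is re-attached and patched:
#     model, tokenizer = FastLanguageModel.from_pretrained("some-user/my-lora-adapter")
#     # -> base model from peft_config.base_model_name_or_path, then PeftModel.from_pretrained(...)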
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
"INT_TO_FLOAT_MAPPER",
"FLOAT_TO_INT_MAPPER",
]
__INT_TO_FLOAT_MAPPER = \
{
"unsloth/mistral-7b-bnb-4bit" : (
"unsloth/mistral-7b",
"mistralai/Mistral-7B-v0.1",
),
"unsloth/llama-2-7b-bnb-4bit" : (
"unsloth/llama-2-7b",
"meta-llama/Llama-2-7b-hf",
),
"unsloth/llama-2-13b-bnb-4bit" : (
"unsloth/llama-2-13b",
"meta-llama/Llama-2-13b-hf",
),
"unsloth/codellama-34b-bnb-4bit" : (
"codellama/CodeLlama-34b-hf",
),
"unsloth/zephyr-sft-bnb-4bit" : (
"unsloth/zephyr-sft",
"HuggingFaceH4/mistral-7b-sft-beta",
),
"unsloth/tinyllama-bnb-4bit" : (
"unsloth/tinyllama",
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
),
"unsloth/tinyllama-chat-bnb-4bit" : (
"unsloth/tinyllama-chat",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
),
"unsloth/mistral-7b-instruct-v0.1-bnb-4bit" : (
"unsloth/mistral-7b-instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.1",
),
"unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : (
"unsloth/mistral-7b-instruct-v0.2",
"mistralai/Mistral-7B-Instruct-v0.2",
),
"unsloth/llama-2-7b-chat-bnb-4bit" : (
"unsloth/llama-2-7b-chat",
"meta-llama/Llama-2-7b-chat-hf",
),
"unsloth/llama-2-7b-chat-bnb-4bit" : (
"unsloth/llama-2-7b-chat",
"meta-llama/Llama-2-7b-chat-hf",
),
"unsloth/codellama-7b-bnb-4bit" : (
"unsloth/codellama-7b",
"codellama/CodeLlama-7b-hf",
),
"unsloth/codellama-13b-bnb-4bit" : (
"codellama/CodeLlama-13b-hf",
),
"unsloth/yi-6b-bnb-4bit" : (
"unsloth/yi-6b",
"01-ai/Yi-6B",
),
"unsloth/solar-10.7b-bnb-4bit" : (
"upstage/SOLAR-10.7B-v1.0",
),
"unsloth/gemma-7b-bnb-4bit" : (
"unsloth/gemma-7b",
"google/gemma-7b",
),
"unsloth/gemma-2b-bnb-4bit" : (
"unsloth/gemma-2b",
"google/gemma-2b",
),
"unsloth/gemma-7b-it-bnb-4bit" : (
"unsloth/gemma-7b-it",
"google/gemma-7b-it",
),
"unsloth/gemma-2b-bnb-4bit" : (
"unsloth/gemma-2b-it",
"google/gemma-2b-it",
),
"unsloth/mistral-7b-v0.2-bnb-4bit" : (
"unsloth/mistral-7b-v0.2",
"alpindale/Mistral-7B-v0.2-hf",
),
"unsloth/gemma-1.1-2b-it-bnb-4bit" : (
"unsloth/gemma-1.1-2b-it",
"google/gemma-1.1-2b-it",
),
"unsloth/gemma-1.1-7b-it-bnb-4bit" : (
"unsloth/gemma-1.1-7b-it",
"google/gemma-1.1-7b-it",
),
"unsloth/Starling-LM-7B-beta-bnb-4bit" : (
"unsloth/Starling-LM-7B-beta",
"Nexusflow/Starling-LM-7B-beta",
),
"unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit" : (
"unsloth/Hermes-2-Pro-Mistral-7B",
"NousResearch/Hermes-2-Pro-Mistral-7B",
),
"unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit" : (
"unsloth/OpenHermes-2.5-Mistral-7B",
"teknium/OpenHermes-2.5-Mistral-7B",
),
"unsloth/codegemma-2b-bnb-4bit" : (
"unsloth/codegemma-2b",
"google/codegemma-2b",
),
"unsloth/codegemma-7b-bnb-4bit" : (
"unsloth/codegemma-7b",
"google/codegemma-7b",
),
"unsloth/codegemma-7b-it-bnb-4bit" : (
"unsloth/codegemma-7b-it",
"google/codegemma-7b-it",
),
"unsloth/llama-3-8b-bnb-4bit" : (
"unsloth/llama-3-8b",
"meta-llama/Meta-Llama-3-8B",
),
"unsloth/llama-3-8b-Instruct-bnb-4bit" : (
"unsloth/llama-3-8b-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
),
"unsloth/llama-3-70b-bnb-4bit" : (
"meta-llama/Meta-Llama-3-70B",
),
"unsloth/llama-3-70b-Instruct-bnb-4bit" : (
"meta-llama/Meta-Llama-3-70B-Instruct",
),
"unsloth/Phi-3-mini-4k-instruct-bnb-4bit" : (
"unsloth/Phi-3-mini-4k-instruct",
"microsoft/Phi-3-mini-4k-instruct",
),
"unsloth/mistral-7b-v0.3-bnb-4bit" : (
"unsloth/mistral-7b-v0.3",
"mistralai/Mistral-7B-v0.3",
),
"unsloth/mistral-7b-instruct-v0.3-bnb-4bit" : (
"unsloth/mistral-7b-instruct-v0.3",
"mistralai/Mistral-7B-Instruct-v0.3",
),
"unsloth/Phi-3-medium-4k-instruct-bnb-4bit" : (
"unsloth/Phi-3-medium-4k-instruct",
"microsoft/Phi-3-medium-4k-instruct",
),
"unsloth/Qwen2-0.5B-bnb-4bit" : (
"unsloth/Qwen2-0.5B",
"Qwen/Qwen2-0.5B",
),
"unsloth/Qwen2-0.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2-0.5B-Instruct",
"Qwen/Qwen2-0.5B-Instruct",
),
"unsloth/Qwen2-1.5B-bnb-4bit" : (
"unsloth/Qwen2-1.5B",
"Qwen/Qwen2-1.5B",
),
"unsloth/Qwen2-1.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2-1.5B-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
),
"unsloth/Qwen2-7B-bnb-4bit" : (
"unsloth/Qwen2-7B",
"Qwen/Qwen2-7B",
),
"unsloth/Qwen2-7B-Instruct-bnb-4bit" : (
"unsloth/Qwen2-7B-Instruct",
"Qwen/Qwen2-7B-Instruct",
),
"unsloth/Qwen2-70B-bnb-4bit" : (
"Qwen/Qwen2-70B",
),
"unsloth/Qwen2-70B-Instruct-bnb-4bit" : (
"Qwen/Qwen2-70B-Instruct",
),
"mistralai/Codestral-22B-v0.1" : (
"mistral-community/Codestral-22B-v0.1",
),
"unsloth/gemma-2-9b-bnb-4bit" : (
"unsloth/gemma-2-9b",
"google/gemma-2-9b",
),
"unsloth/gemma-2-27b-bnb-4bit" : (
"unsloth/gemma-2-27b",
"google/gemma-2-27b",
),
"unsloth/gemma-2-9b-it-bnb-4bit" : (
"unsloth/gemma-2-9b-it",
"google/gemma-2-9b-it",
),
"unsloth/gemma-2-27b-it-bnb-4bit" : (
"unsloth/gemma-2-27b-it",
"google/gemma-2-27b-it",
),
"unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit" : ( # Old Phi pre July
"unsloth/Phi-3-mini-4k-instruct-v0",
),
"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit" : ( # New 12b Mistral models
"unsloth/Mistral-Nemo-Instruct-2407",
"mistralai/Mistral-Nemo-Instruct-2407",
),
"unsloth/Mistral-Nemo-Base-2407-bnb-4bit" : ( # New 12b Mistral models
"unsloth/Mistral-Nemo-Base-2407",
"mistralai/Mistral-Nemo-Base-2407",
),
"unsloth/Meta-Llama-3.1-8B-bnb-4bit" : (
"unsloth/Meta-Llama-3.1-8B",
"meta-llama/Meta-Llama-3.1-8B",
),
"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" : (
"unsloth/Meta-Llama-3.1-8B-Instruct",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
),
"unsloth/Meta-Llama-3.1-70B-bnb-4bit" : (
"meta-llama/Meta-Llama-3.1-70B",
),
"unsloth/Meta-Llama-3.1-405B-bnb-4bit" : (
"meta-llama/Meta-Llama-3.1-405B",
),
"unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit" : (
"meta-llama/Meta-Llama-3.1-405B-Instruct",
),
"unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit" : (
"meta-llama/Meta-Llama-3.1-70B-Instruct",
),
"unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : (
"mistralai/Mistral-Large-Instruct-2407",
),
}
INT_TO_FLOAT_MAPPER = {}
FLOAT_TO_INT_MAPPER = {}
for key, values in __INT_TO_FLOAT_MAPPER.items():
INT_TO_FLOAT_MAPPER[key] = values[0]
for value in values:
FLOAT_TO_INT_MAPPER[value] = key
pass
# Get lowercased
lowered_key = key.lower()
INT_TO_FLOAT_MAPPER[lowered_key] = values[0].lower()
for value in values:
FLOAT_TO_INT_MAPPER[value.lower()] = lowered_key
pass
pass
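# Illustrative lookups into the two generated dicts (values taken from the table above):
#     INT_TO_FLOAT_MAPPER["unsloth/mistral-7b-bnb-4bit"]   # -> "unsloth/mistral-7b"
#     FLOAT_TO_INT_MAPPER["mistralai/Mistral-7B-v0.1"]     # -> "unsloth/mistral-7b-bnb-4bit"
# Lower-cased variants of every key and value are inserted as well, so fully
# lower-cased lookups also resolve.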
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
import os
from ._utils import __version__
from .llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
)
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralDecoderLayer,
MistralModel,
MistralForCausalLM,
)
# For Pytorch 2.1.1
try:
from transformers.models.mistral.modeling_mistral import (
MistralSdpaAttention,
MistralFlashAttention2,
)
except:
MistralSdpaAttention = MistralAttention
MistralFlashAttention2 = MistralAttention
pass
def MistralAttention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# Extend RoPE dynamically to fit in VRAM
self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
# Xformers memory efficient attention
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
K_M = V_M = bsz * kv_seq_len
Q_M = bsz * q_len
has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
# Group query attention
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
if hidden_states.requires_grad:
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
if has_swa:
Q = Q.view(1, Q_M, n_heads, head_dim)
K = K.view(1, K_M, n_heads, head_dim)
V = V.view(1, V_M, n_heads, head_dim)
pass
else:
# Xformers does support the forward pass though
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
if has_swa:
Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
pass
pass
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
sw = getattr(self.config, "sliding_window", None)
sw = kv_seq_len if (sw is None or sw == "null") else sw
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
else:
# Grouped query attention
# if n_groups != 1:
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
# pass
# Must be contiguous or else results are False!
# https://github.com/pytorch/pytorch/issues/112577
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass
attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
pass
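# Worked example of the sliding-window branch above (assumed numbers): with
# config.sliding_window = 4096,
#     kv_seq_len = 3000  ->  window = (-1, -1)      # full causal attention
#     kv_seq_len = 9000  ->  window = (4096, 4096)  # local attention window
# (-1, -1) is flash-attn's convention for "no windowing".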
def MistralForCausalLM_fast_forward(
self,
input_ids: torch.LongTensor = None,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if causal_mask is None and past_key_values is None:
bsz, q_len = input_ids.shape
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is None or sliding_window == "null" or sliding_window <= 0:
causal_mask = xformers.attn_bias.LowerTriangularMask()
elif q_len <= sliding_window:
causal_mask = xformers.attn_bias.LowerTriangularMask()
else:
# Fix from https://github.com/Rypo
causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\
.from_seqlens([q_len]*bsz)\
.make_local_attention(window_size = sliding_window)
pass
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
if past_key_values is not None:
outputs = LlamaModel_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids = position_ids,
attention_mask = attention_mask,
)
else:
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pass
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
lm_head = self.lm_head.weight
if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
logits = logits.unsqueeze(0).unsqueeze(0)
else:
logits = self.lm_head(hidden_states.to(lm_head.dtype))
pass
logits = logits.to(self.config.torch_dtype)
loss = None
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
pass
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
pass
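# Minimal sketch of the label shifting above (toy tensors, assumed shapes):
#     labels               = [[10, 11, 12, 13]]            # (bsz = 1, seq = 4)
#     extra_ignored_labels = torch.full((max_seq_length, 1), -100)
#     shift_labels         = torch.hstack((labels[..., 1:], extra_ignored_labels[:1]))
#     # -> [[11, 12, 13, -100]]: each position is trained to predict the next token,
#     # and -100 is the ignore index used by the cross entropy loss.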
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_attention(function):
function = function.replace(
"(self.head_dim * self.num_heads) != self.hidden_size",
"False",
)
function = function.replace(
"self.head_dim = self.hidden_size // self.num_heads",
"self.head_dim = config.head_dim",
)
function = function.replace(
"self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)",
"self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)",
)
return function
pass
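# Illustrative effect of the source rewrite above (assumed snippet of the original
# MistralAttention.__init__ source):
#     before: self.head_dim = self.hidden_size // self.num_heads
#     after : self.head_dim = config.head_dim
# This lets Mistral Nemo style checkpoints, where hidden_size (5120) is not
# num_heads * head_dim (4096), construct o_proj with the correct input width.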
class FastMistralModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "mistral",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = MistralAttention,
)
# Just for Mistral Nemo models!
if function is not None:
function = patch_mistral_nemo_attention(function)
# if True:#init_name is not None:
exec(function, globals())
MistralAttention.__init__ = eval(init_name)
pass
MistralAttention .forward = MistralAttention_fast_forward
MistralSdpaAttention .forward = MistralAttention_fast_forward
MistralFlashAttention2.forward = MistralAttention_fast_forward
MistralDecoderLayer .forward = LlamaDecoderLayer_fast_forward
MistralModel .forward = LlamaModel_fast_forward
MistralForCausalLM .forward = MistralForCausalLM_fast_forward
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(MistralForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA Graphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.mistral.modeling_mistral
transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = LlamaRotaryEmbedding
return
pass
@staticmethod
def from_pretrained(
model_name = "unsloth/mistral-7b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None, # Mistral does not support RoPE scaling
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastMistralModel,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)
pass
pass