Commit 0b5cd1a0 authored by liangjing's avatar liangjing

update

parent 5352a639
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
@triton.jit
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# f = e * sigmoid(e)
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
f_row = f_row.to(g_row.dtype) # Exact copy from HF
# h = f * g
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
pass
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
return h
pass
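# Illustrative check (not part of the original file): the kernel above computes
# the SwiGLU gate h = silu(e) * g with silu(e) = e * sigmoid(e) done in float32.
# A minimal sketch under that assumption; the helper name _check_swiglu_forward
# and the shapes are made up here, and a CUDA device is assumed.
def _check_swiglu_forward(batch = 2, seq_len = 8, hd = 1024):
    import torch
    e = torch.randn(batch, seq_len, hd, dtype = torch.float16, device = "cuda:0")
    g = torch.randn(batch, seq_len, hd, dtype = torch.float16, device = "cuda:0")
    h_triton = swiglu_fg_kernel(e, g)
    h_eager  = torch.nn.functional.silu(e.float()).to(e.dtype) * g
    torch.testing.assert_close(h_triton, h_eager, rtol = 1e-2, atol = 1e-2)
pass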
@triton.jit
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
e = e.float()
se = 1.0 / (1.0 + torch.exp(-e))
f = (se * e).to(dtype)
h = f * g
df = DW * f
dg = DW * g
de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
"""
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
# f = (se * e).to(dtype)
f_row = se_row * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
# de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
pass
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
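# Illustrative reference (not part of the original file): the backward kernel
# above overwrites its buffers in place - DW receives h = f * g, e receives
# df = DW * f, and g receives de = DW * g * se * (1 + e * (1 - se)). A minimal
# eager sketch of the same math; the name _reference_swiglu_backward is made up.
def _reference_swiglu_backward(DW, e, g):
    import torch
    se = torch.sigmoid(e.float())
    f  = (se * e.float()).to(DW.dtype)
    h  = f * g
    df = DW * f
    dg = DW * g
    de = (dg.float() * se * (1.0 + e.float() * (1.0 - se))).to(DW.dtype)
    return h, df, de
pass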
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
MAX_FUSED_SIZE = 16384 #65536
next_power_of_2 = triton.next_power_of_2
# torch.cuda.amp.custom_fwd is deprecated in torch >= 2.4
import torch
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass
# tl.math.tanh is now libdevice.tanh in Triton >= 3.0
from packaging.version import Version
import triton
if Version(triton.__version__) >= Version("3.0.0"):
from triton.language.extra import libdevice
triton_tanh = libdevice.tanh
else:
import triton.language as tl
triton_tanh = tl.math.tanh
pass
def calculate_settings(n):
BLOCK_SIZE = next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
num_warps = 4
if BLOCK_SIZE >= 32768: num_warps = 8 #32
elif BLOCK_SIZE >= 8192: num_warps = 8 #16
elif BLOCK_SIZE >= 2048: num_warps = 8
return BLOCK_SIZE, num_warps
pass
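# Illustrative usage (not part of the original file): calculate_settings rounds
# the row width up to the next power of two and picks a warp count for the
# launch. A quick sketch of the expected values; _demo_calculate_settings is a
# made-up name and is never called by the library.
def _demo_calculate_settings():
    assert calculate_settings(1000)  == (1024, 4)
    assert calculate_settings(3000)  == (4096, 8)    # 4096 >= 2048 -> 8 warps
    assert calculate_settings(16000) == (16384, 8)   # anything above 16384 raises
pass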
import bitsandbytes as bnb
# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
global CUDA_STREAM
CUDA_STREAM = None
get_ptr = bnb.functional.get_ptr
import ctypes
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
def QUANT_STATE(W):
return getattr(W, "quant_state", None)
pass
def get_lora_parameters(proj):
# For DPO or disabled adapters
base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
return W, QUANT_STATE(W), None, None, None
pass
active_adapter = proj.active_adapters[0] if \
hasattr(proj, "active_adapters") else proj.active_adapter
A = proj.lora_A [active_adapter].weight
B = proj.lora_B [active_adapter].weight
s = proj.scaling[active_adapter]
return W, QUANT_STATE(W), A, B, s
pass
def get_lora_parameters_bias(proj):
# For DPO or disabled adapters
base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
bias = base_layer.bias
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
return W, QUANT_STATE(W), None, None, None, bias
pass
active_adapter = proj.active_adapters[0] if \
hasattr(proj, "active_adapters") else proj.active_adapter
A = proj.lora_A [active_adapter].weight
B = proj.lora_B [active_adapter].weight
s = proj.scaling[active_adapter]
return W, QUANT_STATE(W), A, B, s, bias
pass
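# Illustrative note (not part of the original file): the tuple returned above is
# everything needed for a LoRA forward pass y = X @ W.T + s * (X @ A.T) @ B.T
# (+ bias), with W taken as the dequantized (out_features, in_features) base
# weight, A of shape (r, in_features) and B of shape (out_features, r). A
# minimal eager sketch under that assumption; _lora_forward_reference is a
# made-up name.
def _lora_forward_reference(X, W, A, B, s, bias = None):
    out = X @ W.t()
    if A is not None:
        out = out + s * ((X @ A.t()) @ B.t())
    if bias is not None:
        out = out + bias
    return out
pass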
if HAS_CUDA_STREAM:
def fast_dequantize(W, quant_state = None, out = None):
if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
# Old quant_state as a list of lists
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
global CUDA_STREAM
if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
# Create weight matrix
if out is None:
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
else:
assert(out.shape == shape)
assert(out.dtype == dtype)
# NF4 dequantization of statistics
n_elements_absmax = absmax.numel()
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), CUDA_STREAM,
)
out_absmax += offset
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
cdequantize_blockwise_bf16_nf4
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
ctypes.c_int(blocksize), ctypes.c_int(out.numel()), CUDA_STREAM,)
# Careful returning transposed data
is_transposed = (True if W.shape[0] == 1 else False)
return out.t() if is_transposed else out
pass
else:
def fast_dequantize(W, quant_state = None, out = None):
if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
# Old quant_state as a list of lists
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
# Create weight matrix
if out is None:
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
else:
assert(out.shape == shape)
assert(out.dtype == dtype)
# NF4 dequantization of statistics
n_elements_absmax = absmax.numel()
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax),
)
out_absmax += offset
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
cdequantize_blockwise_bf16_nf4
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
ctypes.c_int(blocksize), ctypes.c_int(out.numel()),)
# Careful returning transposed data
is_transposed = (True if W.shape[0] == 1 else False)
return out.t() if is_transposed else out
pass
pass
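# Illustrative check (not part of the original file): fast_dequantize should
# agree with bitsandbytes' reference dequantizer for NF4 weights. A minimal
# sketch, assuming `proj` is an already-quantized Linear4bit layer; the helper
# name _check_fast_dequantize is made up here.
def _check_fast_dequantize(proj):
    W = proj.weight
    quant_state = QUANT_STATE(W)
    ours      = fast_dequantize(W, quant_state)
    reference = bnb.functional.dequantize_4bit(W, quant_state)
    torch.testing.assert_close(ours, reference, rtol = 1e-3, atol = 1e-3)
pass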
if HAS_CUDA_STREAM:
def fast_gemv(X, W, quant_state, out = None):
if quant_state is None: return torch.matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
# assert(q_len == 1)
if type(quant_state) is not list:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
stats = quant_state.code
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
global CUDA_STREAM
if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
# assert(dtype == X.dtype)
bout = shape[0]
if out is None:
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
# else:
# assert(out.shape == (1, 1, bout,))
# pass
n = 1
m = shape[0]
k = shape[1]
lda = shape[0]
ldc = shape[0]
ldb = (hd+1)//2
m = ctypes.c_int32(m)
n = ctypes.c_int32(n)
k = ctypes.c_int32(k)
lda = ctypes.c_int32(lda)
ldb = ctypes.c_int32(ldb)
ldc = ctypes.c_int32(ldc)
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), CUDA_STREAM,
)
df += offset
absmax = df
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
cgemm_4bit_inference_naive_bf16
blocksize = ctypes.c_int32(blocksize)
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
lda, ldb, ldc, blocksize, CUDA_STREAM,)
return out
pass
else:
def fast_gemv(X, W, quant_state, out = None):
if quant_state is None: return torch.matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
# assert(q_len == 1)
if type(quant_state) is not list:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
stats = quant_state.code
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
# assert(dtype == X.dtype)
bout = shape[0]
if out is None:
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
# else:
# assert(out.shape == (1, 1, bout,))
# pass
n = 1
m = shape[0]
k = shape[1]
lda = shape[0]
ldc = shape[0]
ldb = (hd+1)//2
m = ctypes.c_int32(m)
n = ctypes.c_int32(n)
k = ctypes.c_int32(k)
lda = ctypes.c_int32(lda)
ldb = ctypes.c_int32(ldb)
ldc = ctypes.c_int32(ldc)
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
)
df += offset
absmax = df
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
cgemm_4bit_inference_naive_bf16
blocksize = ctypes.c_int32(blocksize)
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
lda, ldb, ldc, blocksize,)
return out
pass
pass
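# Illustrative check (not part of the original file): fast_gemv covers the
# decode path where bsz == 1 and seq_len == 1, computing X @ dequantize(W).T
# straight from the packed NF4 weight. A minimal sketch, assuming X has shape
# (1, 1, in_features); the helper name _check_fast_gemv is made up, and small
# rounding differences against the dequantize-then-matmul path are expected.
def _check_fast_gemv(X, W, quant_state):
    ours      = fast_gemv(X, W, quant_state)
    reference = torch.matmul(X, fast_dequantize(W, quant_state).t())
    torch.testing.assert_close(ours, reference, rtol = 1e-2, atol = 1e-2)
pass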
def fast_linear_forward(proj, X, temp_lora = None, out = None):
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
bsz, q_len, in_dim = X.shape
if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
if W_quant is None:
out = torch.matmul(X, W.t(), out = out)
elif bsz == 1 and q_len == 1:
out = fast_gemv(X, W, W_quant, out = out)
else:
W = fast_dequantize(W.t(), W_quant)
out = torch.matmul(X, W, out = out)
pass
# Add in LoRA weights
if lora_A is not None:
out_dim = out.shape[2]
dtype = X.dtype
if not hasattr(lora_A, "_fast_lora"):
lora_A._fast_lora = lora_A.to(dtype)
lora_B._fast_lora = lora_B.to(dtype)
pass
if bsz == 1:
out = out.view(out_dim)
temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
else:
out = out.view(bsz, out_dim)
temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
pass
out = out.view(bsz, 1, out_dim)
pass
if bias is not None: out += bias
return out
pass
def matmul_lora(X, W, W_quant, A, B, s, out = None):
dtype = X.dtype
W = fast_dequantize(W.t(), W_quant)
if X.dim() == 3:
batch, seq_len, d = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
pass
out = torch.matmul(X, W, out = out)
if W_quant is not None: del W
if A is not None:
# LoRA is enabled
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
pass
return out.view(batch, seq_len, -1) if reshape else out
pass
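# Illustrative usage (not part of the original file): a typical call site pulls
# the pieces off a PEFT projection and lets matmul_lora handle the dequantize
# plus the low-rank update. A minimal sketch; `proj` and `X` are assumed inputs
# and _demo_matmul_lora is a made-up name.
def _demo_matmul_lora(proj, X):
    W, W_quant, A, B, s = get_lora_parameters(proj)
    return matmul_lora(X, W, W_quant, A, B, s)
pass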
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .loader import FastLanguageModel
from .llama import FastLlamaModel
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from .dpo import PatchDPOTrainer
from ._utils import is_bfloat16_supported
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "2024.10.7"
__all__ = [
"prepare_model_for_kbit_training",
"xformers",
"xformers_attention",
"xformers_version",
"__version__",
"HAS_FLASH_ATTENTION",
"HAS_FLASH_ATTENTION_SOFTCAPPING",
"PRE_CHECK",
"platform_system",
"patch_tokenizer",
"get_statistics",
"Unsloth_Offloaded_Gradient_Checkpointer",
"offload_to_disk",
"offload_input_embeddings",
"offload_output_embeddings",
"is_bfloat16_supported",
"unsloth_offloaded_gradient_checkpoint",
"torch_compile_options",
"patch_linear_scaling",
"patch_llama_rope_scaling",
"check_nvidia",
"create_boolean_mask",
"torch_amp_custom_fwd",
"torch_amp_custom_bwd",
"accelerate_old_send_to_device",
"accelerate_new_send_to_device",
"patch_gradient_checkpointing",
"unpatch_gradient_checkpointing",
"patch_gradient_accumulation_fix",
]
import torch
from typing import Union, Optional, List, Any, Callable, Tuple
from platform import system as platform_system
platform_system = platform_system()
import numpy as np
import warnings, subprocess, re, inspect, psutil, os, math
from packaging.version import Version
# =============================================
# Disable some warnings which can get annoying
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
# Stop "Special tokens have been added in the vocabulary, ..."
import logging
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
# =============================================
# =============================================
# Edits all Config files to enable RoPE Scaling for all models
# Transformers had to be updated for Mistral Nemo 12b since its attention projection is now (5120, 4096).
def patch_mistral_nemo_config(config):
if "head_dim (" not in config:
add_head_dim = "If it is not specified, will default to `8`.\n"\
" head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"\
" The attention head dimension."
config = config.replace("If it is not specified, will default to `8`.", add_head_dim)
add_head_dim = "num_key_value_heads=8,\n head_dim=None,"
config = config.replace("num_key_value_heads=8,", add_head_dim)
add_head_dim = "self.sliding_window = sliding_window\n self.head_dim = head_dim or hidden_size // num_attention_heads\n"
config = config.replace("self.sliding_window = sliding_window", add_head_dim)
pass
return config
pass
from transformers import __version__ as transformers_version
from transformers import PretrainedConfig
model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2",]
for model_name in model_architectures:
config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
config_filename = f"{model_name.title()}Config"
exec(f"from {config_filepath} import {config_filename}", globals())
try:
config = inspect.getsource(eval(config_filename))
except:
continue
if "rope_scaling" in config: continue
config = re.sub(
r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
r"rope_scaling=None,"\
r"\n **kwargs):\n"\
r"\n self.rope_scaling = rope_scaling\n",
config,
)
# Just for Mistral Nemo
if model_name == "mistral":
if Version(transformers_version) <= Version("4.42.4"):
config = patch_mistral_nemo_config(config)
pass
exec(config, globals())
exec(f"import {config_filepath}", globals())
exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
pass
# =============================================
# =============================================
# torch.cuda.amp.custom_fwd is deprecated in torch >= 2.4
import torch
torch_version = torch.__version__
if Version(torch_version) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass
# =============================================
# =============================================
# Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'
import transformers.cache_utils
if hasattr(transformers.cache_utils, "DynamicCache") and \
transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__":
source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__)
start = source.find("def")
spaces = start*" "
source = source.split("\n")
source = "\n".join(x[start:] for x in source)
where = source.find("raise KeyError")
source = source[:where] + \
f"if len(self) == 0:\n{spaces}{spaces}"\
" raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \
f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:]
source = source.replace("__getitem__", "__cache_utils_getitem__", 1)
exec(source)
transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__
pass
# =============================================
# =============================================
# Weird Databricks errors
from transformers.utils import is_openai_available
if is_openai_available():
try:
from openai import OpenAI
except:
print("Unsloth: OpenAI failed to import - ignoring for now.")
import transformers.utils
def _is_openai_available(): return False
transformers.utils.is_openai_available = _is_openai_available
pass
pass
# =============================================
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
import bitsandbytes as bnb
from transformers import AutoTokenizer
from transformers.utils.import_utils import _is_package_available
major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = False
HAS_FLASH_ATTENTION = False
HAS_FLASH_ATTENTION_SOFTCAPPING = False
if major_version >= 8:
SUPPORTS_BFLOAT16 = True
if _is_package_available("flash_attn"):
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
try:
from flash_attn.flash_attn_interface import flash_attn_cuda
HAS_FLASH_ATTENTION = True
# Also check for softcapping
from flash_attn import __version__ as flash_attn_version
HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
if not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
"To update flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)
except:
print(
"Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
"A possible explanation is you have a new CUDA version which isn't\n"\
"yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
"We shall now use Xformers instead, which does not have any performance hits!\n"\
"We found this negligible impact by benchmarking on 1x A100."
)
# Stop Flash Attention from importing!
import transformers.utils.import_utils
transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
import transformers.utils
transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
HAS_FLASH_ATTENTION = False
pass
else:
HAS_FLASH_ATTENTION = False
else:
# Tri Dao's benchmark shows xformers is faster for now.
HAS_FLASH_ATTENTION = False
pass
from transformers.models.llama.modeling_llama import logger
# =============================================
# Get Xformers
from xformers import __version__ as xformers_version
# Temporarily disable 0.0.27 and higher - inference issues
if False: #Version(xformers_version) >= Version("0.0.27"):
raise ImportError(
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
"then press Disconnect Runtime and then Restart it.\n"\
"\n"\
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
'\n'\
f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"\
'Please downgrade xformers via `pip install --force-reinstall "xformers<=0.0.27"`'
)
pass
if Version(torch_version) < Version("2.2.0") and Version(xformers_version) >= Version("0.0.24"):
raise ImportError(
f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
f"Please install xformers < 0.0.24 for torch = {torch_version}."
)
elif Version(torch_version) < Version("2.3.0") and Version(xformers_version) >= Version("0.0.26"):
raise ImportError(
f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
f"Please install xformers < 0.0.26 for torch = {torch_version}."
)
elif Version(torch_version) < Version("2.4.0") and Version(xformers_version) > Version("0.0.27"):
raise ImportError(
f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
f"Please install xformers <= 0.0.27 for torch = {torch_version}."
)
pass
from xformers._cpp_lib import _register_extensions
try:
_register_extensions() # Check if C++ modules are loaded correctly
except Exception as error:
raise ImportError(
"Unsloth: Xformers was not installed correctly.\n"\
"Please install xformers separately first.\n"\
"Then confirm if it's correctly installed by running:\n"\
"python -m xformers.info\n\n"
"Longer error message:\n" + str(error)
)
pass
import xformers.ops.fmha as xformers
xformers_attention = xformers.memory_efficient_attention
# Check TRL version
from trl import __version__ as trl_version
# Unsloth now supports all TRL versions!
if False:#Version(trl_version) >= Version("0.9.0"):
raise ImportError(
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
"then press Disconnect Runtime and then Restart it.\n"\
"\n"\
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
'\n'\
f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"\
'Please downgrade TRL via `pip install --force-reinstall trl`'
)
pass
# =============================================
# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'
accelerate_old_send_to_device = None
accelerate_new_send_to_device = None
if Version(xformers_version) >= Version("0.0.27"):
import accelerate.utils.operations
if hasattr(accelerate.utils.operations, "send_to_device") and \
accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device":
accelerate_old_send_to_device = accelerate.utils.operations.send_to_device
from accelerate.utils.operations import *
send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device)
send_to_device = re.sub(
r"([ ]{4,})return tensor\.to\(device\)",
r"\1try: return tensor.to(device)\n\1except: return tensor",
send_to_device,
).replace("def send_to_device", "def _fixed_send_to_device")
exec(send_to_device)
# accelerate.utils.operations.send_to_device = _fixed_send_to_device
accelerate_new_send_to_device = _fixed_send_to_device
pass
pass
# Transformers 4.46 breaks dynamic caching. This is a hack
import transformers.generation.configuration_utils
if hasattr(transformers.generation.configuration_utils, "ALL_CACHE_IMPLEMENTATIONS"):
if type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS) is list:
transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("dynamic")
pass
pass
# =============================================
# =============================================
# Torch compile settings
# Just remove max_autotune_gemm warning
import functools
@functools.lru_cache(None)
def is_big_gpu(index):
sms = torch.cuda.get_device_properties(index).multi_processor_count
if sms < 80: # V100
# log.warning("not enough SMs to use max_autotune_gemm mode")
return False
return True
import torch._inductor.utils
torch._inductor.utils.is_big_gpu = is_big_gpu
# Torch compile arguments
torch_compile_arguments = [
"config.dce = True",
"config.memory_planning = True",
"config.memory_pool = 'combined'",
"config.coordinate_descent_tuning = True",
"config.max_autotune_gemm = False", # GEMM is unnecessary
"config.autotune_multi_device = False",
"config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster
"config.aggressive_fusion = False", # Careful changes results!
"config.cuda.enable_cuda_lto = True",
"config.cuda.use_fast_math = True",
"config.cuda.compile_opt_level = '-O2'",
]
# Torch dynamo arguments
torch_dynamo_arguments = [
"config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256
"config.suppress_errors = True", # Supress errors for now
"config.do_not_emit_runtime_asserts = True",
"config.cache_size_limit = 1024", # Flex Attention
"config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation
]
import torch._inductor.config as config
for _try_compile_argument in torch_compile_arguments:
try: exec(_try_compile_argument)
except: pass
pass
import torch._dynamo.config as config
for _try_dynamo_argument in torch_dynamo_arguments:
try: exec(_try_dynamo_argument)
except: pass
pass
torch_compile_options = {
"epilogue_fusion" : True,
"max_autotune" : True,
"shape_padding" : True,
"trace.enabled" : False, # Output Triton kernel outputs!
"triton.cudagraphs" : False,
}
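# Illustrative usage (not part of the original file): the inductor options above
# are meant to be handed to torch.compile through its `options` argument. A
# minimal sketch; `fn` is any function worth compiling and
# _demo_compile_with_options is a made-up name.
def _demo_compile_with_options(fn):
    return torch.compile(fn, fullgraph = False, options = torch_compile_options)
pass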
# =============================================
def prepare_model_for_kbit_training(
model : Any,
use_gradient_checkpointing : Optional = True,
use_reentrant : Optional[bool] = True,
) -> Any:
"""
Calculates where to place the gradient checkpoints given n_layers.
We also freeze the gradients of all other layers.
Args:
model: Any LlamaModel with layers.
use_gradient_checkpointing (`bool`, *optional*):
Default enabled. Provides memory savings by not saving all activations,
but only some.
use_reentrant (`bool`, *optional*):
https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
Optimal gradient checkpointing algorithm which will be the default in
future Pytorch versions.
"""
# Freeze all parameters except LoRA
with torch.no_grad():
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
param.requires_grad_(True)
# Also must be in float32!
if param.dtype != torch.float32:
name = name.replace("base_model", "model", 1)
layer_number = re.search(r"\.[\d]{1,}\.", name).group(0)
name = name.replace(layer_number, f"[{layer_number[1:-1]}].")
name = name.replace(".weight", "", 1)
exec(f"{name}.to(torch.float32)")
pass
else:
param.requires_grad_(False)
pass
pass
# Gradient checkpointing!
if use_gradient_checkpointing == "unsloth":
# Saves VRAM!
original_model = model
while hasattr(original_model, "model"):
original_model._offloaded_gradient_checkpointing = True
original_model = original_model.model
pass
original_model._offloaded_gradient_checkpointing = True
model.gradient_checkpointing_enable()
elif use_gradient_checkpointing == True:
model.gradient_checkpointing_enable()
pass
# If use_reentrant = True which is the Pytorch default, we just make the input requires_grad.
if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model
pass
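# Illustrative usage (not part of the original file): the intended call order is
# to attach the LoRA adapters first, then let the helper above freeze the base
# weights, upcast the LoRA matrices to float32 and enable gradient checkpointing.
# A minimal sketch, assuming `peft_model` already has adapters attached;
# _demo_prepare_for_training is a made-up name.
def _demo_prepare_for_training(peft_model):
    return prepare_model_for_kbit_training(
        peft_model,
        use_gradient_checkpointing = "unsloth",
        use_reentrant = True,
    )
pass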
def patch_tokenizer(model, tokenizer):
"""
Phi-3's pad_token isn't set, so we set it to <|placeholder...
Llama-3 uses <|reserved..., and Llama-2 uses <unk>.
Check that pad_token is not the same as eos_token, otherwise the loss will ignore it!
Fixes https://github.com/unslothai/unsloth/issues/5
"""
possible_reserved_tokens = (
"<|finetune_right_pad_id|>", # Llama-3.1
"<pad>", # Mistral Nemo
"<|reserved", # Llama-3
"<|placeholder", # Phi-3
"[control", # Mistral type models
)
joiner = "\1\0=+=\0\1"
number_repetitions = 3 - 1 # Number of reserved tokens needed
if model is not None:
model.config.update({"unsloth_version" : __version__})
bad_pad_token = False
if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None:
# Check that pad_token is not the same as eos_token, otherwise the loss will ignore it!
bad_pad_token = tokenizer.eos_token == tokenizer.pad_token
elif hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
bad_pad_token = True
else:
bad_pad_token = False
pass
if bad_pad_token:
# Find a better pad token
added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
all_added_tokens = joiner.join(added_tokens[::-1])
all_added_tokens += joiner
final_pad_token = None
final_good_match = False
for possible_reserved_token in possible_reserved_tokens:
possible_reserved_token = re.escape(possible_reserved_token)
found = re.finditer(f"{possible_reserved_token}", all_added_tokens)
first_match = None
good_match = False
for j, x in enumerate(found):
if j == 0: first_match = x
if j >= number_repetitions:
good_match = True
break
pass
pass
if first_match is None: continue
# If it ends with |> or > etc, then set it as a good pad token!
start = first_match.span(0)[0]
possible_pad_token = first_match.group(0)
end = all_added_tokens.find(joiner, start)
first_match = all_added_tokens[start:end]
if first_match is not None:
good_match = possible_pad_token.endswith((">", "|>", "]", ")"))
pass
possible_pad_token = first_match
# Replace current pad token if another exact match is found
if not final_good_match and good_match:
final_good_match = True
final_pad_token = possible_pad_token
break
else:
final_good_match = False
final_pad_token = possible_pad_token
pass
pass
possible_pad_token = final_pad_token
# Try unk_token
if possible_pad_token is None and hasattr(tokenizer, "unk_token"):
possible_pad_token = tokenizer.unk_token
pass
# Check pad token's id must be less than vocab size
if possible_pad_token is not None:
check_pad_token = tokenizer(possible_pad_token, add_special_tokens = False).input_ids
if len(check_pad_token) != 1:
possible_pad_token = None
if model is not None and check_pad_token[0] >= model.config.vocab_size:
possible_pad_token = None
pass
if possible_pad_token is None:
# Failure to find a good replacement!! We shall manually add one!
new_pad_token = "<|PAD_TOKEN|>"
while new_pad_token in tokenizer.get_vocab():
new_pad_token = f"<{new_pad_token}>"
pass
possible_pad_token = new_pad_token
pass
name = model.config._name_or_path if model is not None else "Model"
logger.warning_once(
f"{name} does not have a padding token! Will use pad_token = {possible_pad_token}."
)
# Edit pad_token
tokenizer.add_special_tokens({"pad_token" : possible_pad_token})
tokenizer.pad_token = possible_pad_token
if model is not None:
model.config.update({"pad_token_id" : tokenizer.pad_token_id})
if getattr(model, "generation_config") is not None:
model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
else:
if model is not None:
if model.config.pad_token_id is None:
model.config.update({"pad_token_id" : tokenizer.pad_token_id})
if getattr(model, "generation_config") is not None:
model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
pass
pass
if model is not None:
if getattr(model, "generation_config") is not None:
model.generation_config.update(max_length = model.config.max_position_embeddings)
return model, tokenizer
pass
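# Illustrative usage (not part of the original file): patch_tokenizer is meant
# to run right after loading, so that pad_token is guaranteed to differ from
# eos_token before training. A minimal sketch, assuming `model_name` points at
# any causal LM checkpoint; _demo_patch_tokenizer is a made-up name.
def _demo_patch_tokenizer(model_name):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model     = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model, tokenizer = patch_tokenizer(model, tokenizer)
    assert tokenizer.pad_token != tokenizer.eos_token
    return model, tokenizer
pass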
# =============================================
# Weirdly, LoraLayer.update_layer downcasts PEFT layers to float16.
# For mixed precision, we need it to be in float32 not float16.
from peft import __version__ as peft_version
if Version(peft_version) < Version("0.12.0"):
from peft.tuners.lora.layer import LoraLayer
try:
source = inspect.getsource(LoraLayer.update_layer)
text = "if weight is not None:\n"
start = source.find(text) + len(text)
end = source.find("self.to(weight.device)", start)
spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
source = source.replace(source[start : end], spaces)
spaces = len(re.match(r"[\s]{1,}", source).group(0))
lines = source.split("\n")
source = "\n".join(x[spaces:] for x in lines)
source = re.sub("([^\.])nn\.", r"\1torch.nn.", source)
source = source.replace("def update_layer", "def LoraLayer_update_layer")
exec(source, globals())
# Fix up incorrect downcasting of LoRA weights
from peft.tuners.lora.layer import LoraLayer
LoraLayer.update_layer = LoraLayer_update_layer
from peft.tuners.lora import LoraLayer
LoraLayer.update_layer = LoraLayer_update_layer
except:
logger.warning_once(
"Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"\
"Luckily, your training run will still work in the meantime!"
)
pass
pass
# =============================================
import psutil
def _get_statistics(statistics = None, force_download = True):
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
try:
n_cpus = psutil.cpu_count(logical = False)
keynames = "\n" + "\n".join(os.environ.keys())
if statistics is not None: pass
elif "\nCOLAB_" in keynames and n_cpus == 1: statistics = "colab"
elif "\nCOLAB_" in keynames: statistics = "colabpro"
elif "\nKAGGLE_" in keynames: statistics = "kaggle"
elif "\nRUNPOD_" in keynames: statistics = "runpod"
elif "\nAWS_" in keynames: statistics = "aws"
elif "\nAZURE_" in keynames: statistics = "azure"
# elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
elif "\nINVOCATION_ID" in keynames: statistics = "lambda"
# else: statistics = "other"
else:
def try_vllm_check():
vendor_files = (
"/sys/class/dmi/id/product_version",
"/sys/class/dmi/id/bios_vendor",
"/sys/class/dmi/id/product_name",
"/sys/class/dmi/id/chassis_asset_tag",
"/sys/class/dmi/id/sys_vendor",
)
from pathlib import Path
for vendor_file in vendor_files:
path = Path(vendor_file)
if path.is_file():
file_content = path.read_text().lower()
if "amazon" in file_content: return "aws"
elif "microsoft corporation" in file_content: return "azure"
elif "google" in file_content: return "gcp"
return "other"
pass
try: statistics = try_vllm_check()
except: statistics = "other"
pass
if statistics is not None:
from transformers import AutoModelForCausalLM
stats_model = AutoModelForCausalLM.from_pretrained(
f"unslothai/{statistics}",
force_download = force_download,
)
del stats_model
pass
except:
pass
pass
def get_statistics():
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
disabled = False
if not are_progress_bars_disabled():
disable_progress_bars()
disabled = True
pass
_get_statistics(None)
_get_statistics("repeat", force_download = False)
try:
vram = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024
if vram <= 8 : vram = 8
elif vram <= 16: vram = 16
elif vram <= 20: vram = 20
elif vram <= 24: vram = 24
elif vram <= 40: vram = 40
elif vram <= 48: vram = 48
elif vram <= 80: vram = 80
else: vram = 96
_get_statistics(f"vram-{vram}")
except:
pass
pass
try:
devices = torch.cuda.device_count()
_get_statistics(f"{devices if devices <= 8 else 9}")
except:
pass
if disabled: enable_progress_bars()
pass
def _calculate_n_gradient_checkpoints(
n_layers : int,
method : Optional[Union[str, int]] = "sqrt",
) -> List[int]:
assert(type(n_layers) is int and n_layers > 0)
if method is None: method = "sqrt"
if method == "sqrt":
n_checkpoints = int(n_layers**0.5)
elif type(method) is int and method > 0:
n_checkpoints = int(np.ceil(n_layers / method))
else:
raise ValueError("method must be 'sqrt' or an int >0 and <= n_layers.")
size = n_layers // n_checkpoints
sizes = np.full(n_checkpoints, size, dtype = int)
leftovers = n_layers % n_checkpoints
# We append leftovers from the right
for k in range(leftovers):
sizes[n_checkpoints-1-k] += 1
boundaries = np.hstack((0, np.cumsum(sizes)))
boundaries = boundaries.tolist()
return boundaries
pass
def calculate_n_gradient_checkpoints(
n_layers : int,
layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
) -> List[int]:
assert(type(n_layers) is int and n_layers > 0)
if layers_per_checkpoint is None or layers_per_checkpoint == 1:
return None
boundaries = _calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
assert(boundaries[0] == 0 and boundaries[-1] == n_layers)
assert(min(boundaries) == 0 and max(boundaries) == n_layers)
assert(np.diff(boundaries).min() >= 0)
return boundaries
pass
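# Illustrative example (not part of the original file): with the default "sqrt"
# method, 32 layers give int(sqrt(32)) = 5 checkpoints, and the 32 % 5 = 2
# leftover layers are appended from the right, so the boundaries come out as
# [0, 6, 12, 18, 25, 32]. A quick sketch of that expectation;
# _demo_checkpoint_boundaries is a made-up name.
def _demo_checkpoint_boundaries():
    assert calculate_n_gradient_checkpoints(32, "sqrt") == [0, 6, 12, 18, 25, 32]
    assert calculate_n_gradient_checkpoints(32, 1) is None   # plain Pytorch checkpointing
    assert calculate_n_gradient_checkpoints(32, 8) == [0, 8, 16, 24, 32]
pass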
def prepare_n_gradient_checkpoints(
model : Any,
layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
use_reentrant : Optional[bool] = True,
) -> None:
"""
Calculates where to place the gradient checkpoints given n_layers.
Args:
model: Any LlamaModel with layers.
layers_per_checkpoint (`Union[str, int]`, *optional*):
Can either be `sqrt` or an integer for how many layers per checkpoint you want.
The more layers per checkpoint, the lower the memory usage, but the slower it can be. Default is `sqrt`.
Choose 1 for Pytorch gradient checkpointing. 2 to wrap 2 layers in 1 module etc.
use_reentrant (`bool`, *optional*):
https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
The supposedly optimal algorithm `use_reentrant=False`, which will be the
default in future Pytorch versions, doesn't seem to work here.
"""
_model = None
if hasattr(model, "layers"):
_model = model
elif hasattr(model, "model"):
if hasattr(model.model, "layers"):
_model = model.model
if _model is None:
raise TypeError("`model` or `model.model` does not have attribute `layers`. Are you sure this is a model?")
pass
if use_reentrant is False:
use_reentrant = True
pass
n_layers = len(_model.layers)
boundaries = calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
_model._gradient_checkpointing_boundaries = boundaries
_model._gradient_checkpointing_use_reentrant = use_reentrant
pass
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
"""
Saves VRAM by smartly offloading activations to RAM.
Tiny performance hit, since we hide the data movement behind non-blocking copies.
"""
@staticmethod
@torch_amp_custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda:0", non_blocking = True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (None, hidden_states.grad,) + (None,)*len(ctx.args)
pass
pass
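# Illustrative usage (not part of the original file): the checkpointer wraps a
# decoder layer's forward so its input hidden states sit in CPU RAM between the
# forward and backward passes. A minimal sketch, assuming `decoder_layer`
# returns a one-element tuple (hidden_states,) as HF layers do here;
# _demo_offloaded_checkpoint is a made-up name.
def _demo_offloaded_checkpoint(decoder_layer, hidden_states, *layer_args):
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
        decoder_layer, hidden_states, *layer_args,
    )
pass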
@torch._disable_dynamo
def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs):
return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args)
pass
import torch.utils
old_checkpoint = torch.utils.checkpoint
def patch_gradient_checkpointing():
torch.utils.checkpoint = unsloth_offloaded_gradient_checkpoint
pass
def unpatch_gradient_checkpointing():
torch.utils.checkpoint = old_checkpoint
pass
# =============================================
# Fixes BitsAndBytesConfig to remove warnings about unused kwargs
from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
from inspect import getsource
from accelerate.utils.dataclasses import DistributedType
BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = re.sub(
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
"",
BitsAndBytesConfig__init__,
flags = re.MULTILINE,
)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
"__init__",
"_BitsAndBytesConfig__init__",
)
def _prepare_backend(
self, cpu: bool = False, sagemaker_dp = False, backend: str = None,
) -> tuple[str, DistributedType]:
return None, DistributedType.NO
pass
import accelerate.state
accelerate.state.PartialState._prepare_backend = _prepare_backend
import accelerate.accelerator
prepare = inspect.getsource(accelerate.accelerator.Accelerator.prepare)
prepare = prepare.split("\n")
spaces = prepare[0].find("def")
prepare = "\n".join(x[spaces:] for x in prepare)
x = "for obj in args:"
s = " "*spaces
prepare = prepare.replace(x, f'self.state.distributed_type = DistributedType.NO\n{s}{x}', 1)
exec(prepare, globals())
accelerate.accelerator.Accelerator.prepare = prepare
exec(BitsAndBytesConfig__init__, globals())
import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
# =============================================
# Offloading to disk for modules (lm_head, embed_tokens)
import pickle
def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_temporary_saved_buffers"):
file_location = os.path.join(temporary_location, model.config._name_or_path)
if not os.path.exists(file_location):
os.makedirs(file_location)
pass
filename = os.path.join(file_location, f"{name}.pt")
W = W.weight if hasattr(W, "weight") else W
torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
offloaded_W = torch.load(filename, map_location = "cpu", mmap = True)
offloaded_W._offloaded_file_location = filename
return offloaded_W
pass
def offload_input_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
offloaded_W = offload_to_disk(model.get_input_embeddings(), model, "input_embeddings", temporary_location)
new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W)
new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
model.set_input_embeddings(new_input_embeddings)
return
pass
def offload_output_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
offloaded_W = offload_to_disk(model.get_output_embeddings(), model, "output_embeddings", temporary_location)
new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
del new_output_embeddings.weight
new_output_embeddings.weight = offloaded_W
new_output_embeddings.in_features = offloaded_W.shape[1]
new_output_embeddings.out_features = offloaded_W.shape[0]
new_output_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
model.set_output_embeddings(new_output_embeddings)
return
pass
# Fixes a weird Torch 2.3 bug which says T4s have bfloat16
def is_bfloat16_supported():
return SUPPORTS_BFLOAT16
pass
# Patches models to add RoPE Scaling
def patch_linear_scaling(
model_name = "gemma2",
rope_module = None,
scaled_rope_module = None,
attention_module = None,
):
assert(rope_module is not None and scaled_rope_module is not None)
assert(attention_module is not None)
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
exec_code = \
f"import torch.nn as nn\n"\
f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
f"from {model_filepath} import logger, "\
f"{model_name.title()}Attention, {model_name.title()}Config"
try:
function = inspect.getsource(attention_module.__init__)
except:
# Most likely already patched!
return None, None
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
init_name = f"{model_name.title()}Attention__init__"
function = function.replace("def __init__", f"def {init_name}")
function = function.replace(
"super().__init__()",
f"super({model_name.title()}Attention, self).__init__()",
)
fix_rope_function = """
if getattr(self.config, "rope_scaling", None) is None:
self.rotary_emb = {rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = {scaled_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
)
rotary_emb = re.findall(
"self.rotary_emb = .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0: return None, function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
pass
# Patches for Llama-3 LlamaExtendedRotaryEmbedding
def patch_llama_rope_scaling(
model_name = "llama",
rope_module = None,
scaled_rope_module = None,
extended_rope_module = None,
attention_module = None,
longrope_module = None,
):
assert(\
rope_module is not None and \
scaled_rope_module is not None and \
extended_rope_module is not None
)
assert(attention_module is not None)
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
exec_code = \
f"import torch.nn as nn\n"\
f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
f"from {model_filepath} import logger, "\
f"{model_name.title()}Attention, {model_name.title()}Config"
try:
function = inspect.getsource(attention_module.__init__)
except:
# Most likely already patched!
return None, None
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
init_name = f"{model_name.title()}Attention__init__"
function = function.replace("def __init__", f"def {init_name}")
function = function.replace(
"super().__init__()",
f"super({model_name.title()}Attention, self).__init__()",
)
fix_rope_function = """
if getattr(self.config, "rope_scaling", None) is None:
self.rotary_emb = {rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type1 = self.config.rope_scaling.get("type", None)
scaling_type2 = self.config.rope_scaling.get("rope_type", None)
scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
scaling_factor = self.config.rope_scaling.get("factor")
if scaling_type == "linear":
self.rotary_emb = {scaled_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "llama3":
self.rotary_emb = {extended_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
elif scaling_type == "longrope":
self.rotary_emb = {longrope_rope_function}(
dim = self.head_dim,
max_position_embeddings = self.max_position_embeddings,
original_max_position_embeddings = self.config.original_max_position_embeddings,
base = self.rope_theta,
short_factor = self.config.rope_scaling['short_factor'],
long_factor = self.config.rope_scaling['long_factor' ],
)
else:
raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
extended_rope_function = extended_rope_module.__name__,
longrope_rope_function = \
(longrope_module if longrope_module is not None else rope_module).__name__
)
rotary_emb = re.findall(
"self.rotary_emb = .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0: return None, function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
pass
def check_nvidia():
# Unsloth doesn't work yet on AMD devices - we're working on it!
output = np.array([0,])
try:
output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True)
output = re.findall(rb'([\d]{1,})[\s]{1,}M', output)
output = np.array([int(x.decode('utf-8'))/1024 for x in output])
except:
if not torch.cuda.is_available():
raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!")
return output
pass
PRE_CHECK = check_nvidia()
def create_boolean_mask(n = 4096, sliding_window = 2048):
# Creates a boolean mask for attention
mask = torch.ones(n, n, dtype = torch.bool)
if sliding_window == 0:
return torch.triu(mask, diagonal = 1, out = mask)
pass
torch.triu(mask, diagonal = 0, out = mask)
torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
mask = mask.T
torch.logical_not(mask, out = mask)
return mask
pass
def test_mask_creation():
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
for n in range(2, 23):
for s in range(1, 23):
correct_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = s,
).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
correct_mask = (correct_mask == correct_mask.min())
our_mask = create_boolean_mask(n = n, sliding_window = s)
assert(torch.all(correct_mask == our_mask))
pass
correct_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = None,
).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
correct_mask = (correct_mask == correct_mask.min())
our_mask = create_boolean_mask(n = n, sliding_window = 0)
assert(torch.all(correct_mask == our_mask))
pass
pass
def _unsloth_get_batch_samples(self, epoch_iterator, num_batches):
batch_samples = []
num_items_in_batch = None
for _ in range(num_batches):
try:
batch_samples += [next(epoch_iterator)]
except StopIteration:
break
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
try:
num_items_in_batch = sum(
[torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples]
)
except TypeError:
pass
return batch_samples, num_items_in_batch
pass
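# Illustrative example (not part of the original file): num_items_in_batch counts
# the label tokens that actually contribute to the loss, i.e. the non -100
# entries after the usual one-token shift, summed over the accumulated
# micro-batches. A quick sketch; _demo_count_items is a made-up name.
def _demo_count_items():
    import torch
    labels = torch.tensor([[-100, -100, 5, 6, 7]])
    # After dropping the first position, four label slots remain, three of which are real.
    assert torch.count_nonzero(labels[..., 1:] != -100).item() == 3
pass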
def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
if "num_items_in_batch" in kwargs:
if "num_items_in_batch" not in inputs:
inputs["num_items_in_batch"] = kwargs["num_items_in_batch"]
pass
pass
return self._old_compute_loss(model, inputs, *args, **kwargs)
pass
def patch_gradient_accumulation_fix(Trainer):
# Fixes gradient accumulation
import inspect
if hasattr(Trainer, "get_batch_samples"):
if \
not inspect.getsource(Trainer.get_batch_samples).strip()\
.endswith("return batch_samples, num_items_in_batch"):
raise NotImplementedError("Unsloth: Please make a Github issue immediately!!")
else:
if Trainer.get_batch_samples.__name__ != "_unsloth_get_batch_samples":
Trainer.get_batch_samples = _unsloth_get_batch_samples
pass
# Also fix passing in num_items_in_batch
if not hasattr(Trainer, "_old_compute_loss"):
Trainer._old_compute_loss = Trainer.compute_loss
Trainer.compute_loss = _unsloth_pre_compute_loss
pass
pass
else:
logger.warning_once(
"Unsloth: We fixed a gradient accumulation bug, "\
"but it seems like you don't have the latest transformers version!\n"\
"Please update transformers, TRL and unsloth via:\n"\
'`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`'
)
pass
# Also fix up loss scaling, i.e. undo the unconditional loss *= self.args.gradient_accumulation_steps
if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters: return
function = inspect.getsource(Trainer.training_step)
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
# Import all variables that need importing
import transformers.trainer
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
# TODO: Support Deepspeed
if item.startswith(("deepspeed", "xm", "met", "smp")): continue
if item in function: good_items.append(item)
pass
exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
# Accelerate divides by self.args.gradient_accumulation_steps internally, so if we already
# summed the loss and did the division beforehand, we have to undo the extra scaling.
function = function.replace(
"loss *= self.args.gradient_accumulation_steps",
"if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
)
function = function.replace("def training_step", "def _unsloth_training_step", 1)
exec(function, globals())
Trainer.training_step = _unsloth_training_step
pass
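# Minimal usage sketch (the model / args / dataset objects below are placeholders, not part
# of this file):
# from transformers import Trainer
# patch_gradient_accumulation_fix(Trainer)  # patch the class before building any trainer
# trainer = Trainer(model = model, args = training_args, train_dataset = train_dataset)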
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from ._utils import __version__
try:
from transformers.models.cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer,
CohereModel,
CohereForCausalLM,
CohereRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
except:
from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.42"):
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\n"\
f"The minimum required version is 4.42.3.\n"\
f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
pass
pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
from transformers.models.cohere.modeling_cohere import (
CohereSdpaAttention,
CohereFlashAttention2,
)
except:
CohereSdpaAttention = CohereAttention
CohereFlashAttention2 = CohereAttention
pass
def fast_layernorm_inference(self, X, out_weight = None):
XX = X.to(torch.float32, copy = True)
XX -= X.mean(-1, keepdim = True)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
out_weight[:] = self.weight
XX *= out_weight
return XX.to(X.dtype)
pass
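# Reference math for the inference layernorm above (eager-mode sketch):
#   y = (X - mean(X)) / sqrt(var(X) + variance_epsilon) * weight, computed in float32,
# where out_weight is a pre-allocated float32 buffer the caller re-uses every decoding step
# so the weight cast does not allocate new memory.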
# QK norm in Cohere
def CohereAttention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
del self.q_norm_out_weight
del self.k_norm_out_weight
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
if self.use_qk_norm:
Q = fast_layernorm_compiled(self.q_norm, Q)
K = fast_layernorm_compiled(self.k_norm, K)
pass
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
# Xformers memory efficient attention
# Also has Flash Attention v2 dispatching
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Group query attention
if n_groups != 1:
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
if hidden_states.requires_grad:
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
else:
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
pass
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
A = flash_attn_func(Q, K, V, causal = True)
else:
# Grouped query attention
if n_groups != 1:
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
pass
# Must be contiguous or else the results are wrong!
# https://github.com/pytorch/pytorch/issues/112577
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass
attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def CohereDecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
):
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
hidden_states = fast_layernorm_inference(self.input_layernorm, hidden_states, out_weight)
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
# Fully Connected
hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states)
residual += hidden_states_attention
residual += hidden_states_mlp
hidden_states = residual
else:
residual = hidden_states
hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states)
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
# Fully Connected
hidden_states_mlp = self.mlp(hidden_states)
hidden_states = residual + hidden_states_attention + hidden_states_mlp
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
torch_matmul = torch.matmul
def CohereAttention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Mistral Nemo 12b has weird dimensions
if attention_size != self.hidden_size:
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
pass
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
# Cohere has QK layernorms
if self.use_qk_norm:
self.q_norm_out_weight = torch.empty(self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
self.k_norm_out_weight = torch.empty(self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
else:
self.q_norm_out_weight = None
self.k_norm_out_weight = None
pass
elif kv_seq_len >= self.paged_attention.shape[0]:
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
if self.use_qk_norm:
Qn = fast_layernorm_inference(self.q_norm, Qn, self.q_norm_out_weight)
Kn = fast_layernorm_inference(self.k_norm, Kn, self.k_norm_out_weight)
pass
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
RH_Q = self.RH_Q
RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:,:,:,:h] = Kn[:,:,:,h:]
RH_K[:,:,:,h:] = Kn[:,:,:,:h]
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
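# The two blocks above apply RoPE in place: Qn = Qn*cos + rotate_half(Qn)*sin and
# Kn = Kn*cos + rotate_half(Kn)*sin, where rotate_half([x1, x2]) = [-x2, x1] and
# RH_Q / RH_K are pre-allocated scratch buffers holding the rotated halves.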
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
# Handle sliding windows
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is not None and kv_seq_len > sliding_window:
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
slicing_tokens = 1 - sliding_window
Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
pass
# else:
# Knn, Vnn = Knn, Vnn
# pass
# Attention
if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
else:
A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def CohereModel_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = getattr(self.config, "sliding_window", None),
)
else:
attention_mask = None
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
residual = hidden_states
hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
hidden_states_mlp = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
residual += hidden_states_attention
residual += hidden_states_mlp
hidden_states = residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
class FastCohereModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "cohere",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = CohereAttention,
)
if init_name is not None:
exec(function, globals())
CohereAttention.__init__ = eval(init_name)
pass
CohereAttention .forward = CohereAttention_fast_forward
CohereSdpaAttention .forward = CohereAttention_fast_forward
CohereFlashAttention2.forward = CohereAttention_fast_forward
CohereDecoderLayer .forward = CohereDecoderLayer_fast_forward
CohereModel .forward = LlamaModel_fast_forward
CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference)
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(CohereForCausalLM)
import transformers.models.cohere.modeling_cohere
transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding
return
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
"PatchDPOTrainer",
]
try:
from transformers.utils.notebook import (
IntervalStrategy,
NotebookTrainingTracker,
NotebookProgressCallback,
)
HAS_NOTEBOOK = True
except:
HAS_NOTEBOOK = False
pass
import torch
from ._utils import torch_compile_options
import inspect
import torch.nn as nn
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
DPOTrainer_metrics = [
"rewards/chosen",
"rewards/rejected",
"rewards/accuracies",
"rewards/margins",
"logps/rejected",
"logps/chosen",
"logits/rejected",
"logits/chosen",
]
set_DPOTrainer_metrics = frozenset(DPOTrainer_metrics)
def NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs):
self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"
self.training_loss = 0
self.last_log = 0
column_names = [self.first_column] + ["Training Loss"]
if args.eval_strategy != IntervalStrategy.NO:
column_names.append("Validation Loss")
column_names += [x.replace("/", " / ") for x in DPOTrainer_metrics]
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
pass
def NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs):
# Only for when there is no evaluation
if args.eval_strategy == IntervalStrategy.NO and "loss" in logs:
values = {"Training Loss": logs["loss"]}
for metric in DPOTrainer_metrics:
values[metric.replace("/", " / ")] = logs[metric]
pass
# First column is necessarily Step since we're not in epoch eval strategy
values["Step"] = state.global_step
self.training_tracker.write_line(values)
pass
pass
def NotebookTrainingTracker_write_line(self, values):
"""
Write the values in the inner table.
Args:
values (`Dict[str, float]`): The values to display.
"""
if self.inner_table is None:
self.inner_table = [list(values.keys()), list(values.values())]
else:
columns = self.inner_table[0]
new_values = {}
for key, value in values.items():
lowered = key.lower()
if lowered in set_DPOTrainer_metrics:
new_values[lowered.replace("/", " / ")] = value
else:
new_values[key] = value
pass
values = new_values
self.inner_table[0] = columns
if len(self.inner_table) > 1:
last_values = self.inner_table[-1]
first_column = self.inner_table[0][0]
if last_values[0] != values[first_column]:
# write new line
self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
else:
# update last line
new_values = values
for c in columns:
if c not in new_values.keys():
new_values[c] = last_values[columns.index(c)]
self.inner_table[-1] = [new_values[c] for c in columns]
else:
# Edit for evaluation purposes
self.inner_table.append([values[c] if c in values else 0 for c in columns])
pass
pass
pass
def PatchDPOTrainer():
if HAS_NOTEBOOK:
from transformers.trainer import is_in_notebook
if is_in_notebook():
# Patch DPO notebook printing
NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line
from transformers.trainer import DEFAULT_PROGRESS_CALLBACK
DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin
DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log
pass
pass
pass
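# Minimal usage sketch (assumes trl's DPOTrainer is installed; the import path shown for
# PatchDPOTrainer is the usual top-level export and may differ):
# from unsloth import PatchDPOTrainer
# PatchDPOTrainer()              # patch notebook metric logging first
# trainer = DPOTrainer(...)      # then build the DPO trainer as usual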
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from ._utils import __version__
import math
try:
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaDecoderLayer,
GemmaModel,
GemmaForCausalLM,
GemmaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
except:
from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.38"):
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
f"The minimum required version is 4.38.\n"\
f'Try `pip install --upgrade "transformers>=4.38"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
pass
pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
from transformers.models.gemma.modeling_gemma import (
GemmaSdpaAttention,
GemmaFlashAttention2,
)
except:
GemmaSdpaAttention = GemmaAttention
GemmaFlashAttention2 = GemmaAttention
pass
torch_nn_functional_gelu = torch.nn.functional.gelu
def fast_geglu_inference(self, X):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
bsz, _, hd = X.shape
# mlp_size = self.config.intermediate_size
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
up = fast_linear_forward(self. up_proj, X)#, out = temp[1])
gate = torch_nn_functional_gelu(gate, approximate = "tanh")
gate *= up
# X = self.down_proj(gate)
down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
return down
pass
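# Eager-mode sketch of the fused GeGLU path above:
#   down = self.down_proj(gelu(self.gate_proj(X), approximate = "tanh") * self.up_proj(X))
# The fused version writes the down projection into a slice of the `up` buffer to avoid an
# extra allocation during single-token decoding.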
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def GemmaDecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
):
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
from math import sqrt as math_sqrt
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def GemmaModel_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
)
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
# Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
# Formulates cos and sin differently from Llama!
class GemmaFixedRotaryEmbedding(torch.nn.Module):
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None: return # [TODO] Hack to pass in config - need to remove later
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
# The difference is we do the division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
pass
def get_cached(self, seq_len = None):
return self.cos_cached, self.sin_cached
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = math.ceil(seq_len / 8192) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
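# In the cache above, for head dim d the timescales are base**(2i/d) for i = 0 .. d/2 - 1,
# radians[p, i] = position p / timescale_i, and cos/sin come from cat(radians, radians)
# along the last dim, matching the Gemma reference which divides by the timescale rather
# than multiplying by inverse frequencies as Llama does.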
class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
# The difference is we do the division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
positions = positions / self.scaling_factor
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
pass
pass
class FastGemmaModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "gemma",
rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
attention_module = GemmaAttention,
)
if init_name is not None:
exec(function, globals())
GemmaAttention.__init__ = eval(init_name)
pass
GemmaAttention .forward = LlamaAttention_fast_forward
GemmaSdpaAttention .forward = LlamaAttention_fast_forward
GemmaFlashAttention2.forward = LlamaAttention_fast_forward
GemmaDecoderLayer .forward = GemmaDecoderLayer_fast_forward
GemmaModel .forward = LlamaModel_fast_forward
GemmaForCausalLM .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(GemmaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA-graphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma.modeling_gemma
transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
return
pass
@staticmethod
def post_patch(model):
# Patch model for Gemma
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.2
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Gemma has tied weights! This means lm_head == embed_tokens
if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
pass
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
# RoPE must be done in float32 for Gemma
# if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
# and (module.cos_cached.dtype != correct_dtype):
# module.cos_cached = module.cos_cached.to(correct_dtype)
# module.sin_cached = module.sin_cached.to(correct_dtype)
# pass
# pass
pass
# Add 1 to weight
# return output * (1 + self.weight)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
# Freeze all parameters except LoRA
# We do this first since in-place += 1 is not allowed on tensors with requires_grad = True
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
pass
# Patch RMS Layernorm
for name, module in model.named_modules():
if isinstance(module, GemmaRMSNorm):
# Must be in float32
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
# module = module.to(torch.float32)
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
pass
# Clear deleted GPU items
import gc
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from ._utils import __version__
from .gemma import (
GemmaFixedRotaryEmbedding,
GemmaFixedLinearScalingRotaryEmbedding,
fast_geglu_inference,
)
try:
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2DecoderLayer,
Gemma2Model,
Gemma2ForCausalLM,
Gemma2RotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
except:
from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.42"):
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
f"The minimum required version is 4.42.3.\n"\
f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
pass
pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2SdpaAttention,
Gemma2FlashAttention2,
)
except:
Gemma2SdpaAttention = Gemma2Attention
Gemma2FlashAttention2 = Gemma2Attention
pass
if HAS_FLASH_ATTENTION_SOFTCAPPING:
from flash_attn import flash_attn_func
# [TODO] We must randomly use torch.compile?
# I checked the gradients and formulas and I'm sure it's correct.
# I'm stumped :(
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
old_dtype = X.dtype
X = X.float()
X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + layernorm.eps) * \
(1.0 + layernorm.weight.float())
return X.to(old_dtype)
pass
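# Reference formula for the compiled RMS layernorm above:
#   y = X / sqrt(mean(X**2, dim = -1) + eps) * (1 + weight), computed in float32,
# since Gemma-style RMSNorm adds 1 to the weight before scaling.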
# Logit softcapping
def Gemma2Attention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
# Only enable sliding windows if causal_mask was passed in as the boolean True
has_sliding_window = type(causal_mask) is bool and causal_mask is True
if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
window = (-1, -1)
if has_sliding_window:
sw = getattr(self.config, "sliding_window", None)
sw = kv_seq_len if (sw is None or sw == "null") else sw
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
pass
# FA uses 1 / sqrt for softmax_scale!
if not hasattr(self, "_flash_attention_softmax_scale"):
self._flash_attention_softmax_scale = 1.0 / (self.config.query_pre_attn_scalar**0.5)
pass
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
A = flash_attn_func(
Q, K, V,
causal = True,
softcap = self.config.attn_logit_softcapping,
softmax_scale = self._flash_attention_softmax_scale,
window_size = window,
)
A = A.reshape(bsz, q_len, n_heads*head_dim)
else:
fx = slow_inference_attention_softcapping \
if "_flag_for_generation" in kwargs else \
slow_attention_softcapping
A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len)
pass
A = self.apply_o(self, A)
return A, None, past_key_value
pass
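# The logit softcapping used above squashes attention scores as
#   scores = softcap * tanh(scores / softcap), with softcap = config.attn_logit_softcapping;
# the flash-attention path passes this via its softcap argument, while the eager inference
# path below applies the same transform explicitly before the softmax.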
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def Gemma2DecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
):
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
_flag_for_generation=True,
)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = self.mlp(hidden_states)
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
torch_matmul = torch.matmul
torch_tanh = torch.tanh
def Gemma2Attention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
use_sliding_window = False,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Only for Gemma2
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
# Gemma 9b should use 256 and not 224 (hidden_size / num_attention_heads). 27b uses the below
# We default to using the config file itself
# s = self.config.hidden_size // self.config.num_attention_heads
self.scalar = 1.0 / math_sqrt(self.config.query_pre_attn_scalar)
# self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
self.half_head_dim = head_dim // 2
self. t = self.config.attn_logit_softcapping
self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
elif kv_seq_len >= self.paged_attention.shape[0]:
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
h = self.half_head_dim
RH_Q = self.RH_Q
RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:,:,:,:h] = Kn[:,:,:,h:]
RH_K[:,:,:,h:] = Kn[:,:,:,:h]
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
# Handle sliding windows
sliding_window = self.config.sliding_window
if use_sliding_window and kv_seq_len > sliding_window:
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
slicing_tokens = 1 - sliding_window
Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
pass
# else:
# Knn, Vnn = Knn, Vnn
# pass
# Attention
# if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A *= self.reciprocal_t; torch_tanh(A, out = A); A *= self.t; # Logit softcapping
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
# else:
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
# pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def Gemma2Model_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
if HAS_FLASH_ATTENTION_SOFTCAPPING:
SWA = True
GA = False
else:
SWA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = self.config.sliding_window,
)
GA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
)
pass
else:
SWA = attention_mask
GA = attention_mask
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
use_sliding_window = idx % 2 == 0
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = SWA if use_sliding_window else GA,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
use_sliding_window = use_sliding_window,
)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
class FastGemma2Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "gemma2",
rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
attention_module = Gemma2Attention,
)
if init_name is not None:
exec(function, globals())
Gemma2Attention.__init__ = eval(init_name)
pass
Gemma2Attention .forward = Gemma2Attention_fast_forward
Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward
Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward
Gemma2Model .forward = LlamaModel_fast_forward
Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(Gemma2ForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDA-graphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma2.modeling_gemma2
transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding
return
pass
@staticmethod
def post_patch(model):
# Patch model for Gemma
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.2
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Gemma has tied weights! This means lm_head == embed_tokens
if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
pass
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
# RoPE must be done in float32 for Gemma
# if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
# and (module.cos_cached.dtype != correct_dtype):
# module.cos_cached = module.cos_cached.to(correct_dtype)
# module.sin_cached = module.sin_cached.to(correct_dtype)
# pass
# pass
pass
# Add 1 to weight
# return output * (1 + self.weight)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
# Freeze all parameters except LoRA
# We do this first since in-place += 1 is not allowed on tensors with requires_grad = True
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
pass
# Patch RMS Layernorm
for name, module in model.named_modules():
if isinstance(module, Gemma2RMSNorm):
# Must be in float32
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
# module = module.to(torch.float32)
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
pass
# Clear deleted GPU items
import gc
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
pass
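# Hedged illustration (a standalone, hypothetical helper, not part of the patch
# above) of the tied-weights check used in post_patch: two parameters share the
# same storage iff their data_ptr() values match, which is how lm_head is
# verified to reuse embed_tokens for Gemma.
def _reference_weights_are_tied(embed_tokens, lm_head):
    return embed_tokens.weight.data_ptr() == lm_head.weight.data_ptr()
pass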
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import gc
import math
from typing import Optional, Tuple, List, Union
from ._utils import *
from ._utils import __version__
from torch.nn.functional import scaled_dot_product_attention
from transformers import __version__ as transformers_version
from transformers.models.llama.modeling_llama import (
logger,
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ..kernels import *
from ..tokenizer_utils import *
if HAS_FLASH_ATTENTION:
from flash_attn import flash_attn_func
# Final patching code
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaModel,
LlamaForCausalLM,
)
# For Pytorch 2.1.1
try:
from transformers.models.llama.modeling_llama import (
LlamaSdpaAttention,
LlamaFlashAttention2,
)
except:
LlamaSdpaAttention = LlamaAttention
LlamaFlashAttention2 = LlamaAttention
pass
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
from transformers import set_seed as transformers_set_seed
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
from peft import PeftModelForCausalLM
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
from ..save import patch_saving_functions
import re, os, inspect, math, sys
try:
from huggingface_hub.utils import get_token
except:
# Old HF Hub versions <= 0.0.25
from huggingface_hub.utils._token import get_token
pass
def original_apply_qkv(self, X):
Q = self.q_proj(X)
K = self.k_proj(X)
V = self.v_proj(X)
return Q, K, V
pass
def original_apply_o(self, X):
O = self.o_proj(X)
return O
pass
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
# Fix new HF's inference code
def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,):
if "past_key_values" in kwargs:
input_ids = input_ids[:,[-1]]
kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]]
if "cache_position" in kwargs:
kwargs["position_ids"] = kwargs["cache_position"]
return { "input_ids" : input_ids, **kwargs, }
pass
def fix_prepare_inputs_for_generation(module):
# Fix prepare_inputs_for_generation
if hasattr(module, "prepare_inputs_for_generation"):
module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
pass
pass
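# Hedged illustration (hypothetical helper, not used by the patch above) of what
# the fast prepare_inputs_for_generation does once a KV cache ("past_key_values")
# exists: only the newest token and the newest attention-mask column are kept,
# since every earlier position is already stored in the cache.
def _reference_trim_for_decoding(input_ids, attention_mask):
    # (bsz, seq_len) -> (bsz, 1) for both tensors
    return input_ids[:, [-1]], attention_mask[:, [-1]]
pass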
torch_matmul = torch.matmul
def LlamaAttention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
):
"""
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
Fast inference using KV cache.
QK^T can be computed in 4 chunks
[Q, q] @ [K, k].T where q, k are the new tokens.
[QK^T, Qk^T]
[qK^T, qk^T]
Since the attention mask wipes Qk^T, we just get
[QK^T, 0]
[qK^T, qk^T]
Since softmax is row-wise, we get
softmax([QK^T, 0])
softmax([qK^T, qk^T])
We then multiply by [V]
[v]
softmax([QK^T, 0]) [softmax(QK^T)V] *
softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
But notice * [softmax(QK^T)V] is just the last attention.
We just need to compute the last final row.
This means we can pass in a row of Q, but we need to
remember K and V, which are called the KV cache.
"""
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Mistral Nemo 12b has weird dimensions
if attention_size != self.hidden_size:
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
pass
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
elif kv_seq_len >= self.paged_attention.shape[0]:
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
# Need to extend the RoPE cache 2 steps before the short KV cache fills up,
# or else we hit an out-of-bounds error
self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
RH_Q = self.RH_Q
RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:,:,:,:h] = Kn[:,:,:,h:]
RH_K[:,:,:,h:] = Kn[:,:,:,:h]
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
# Handle sliding windows
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is not None and kv_seq_len > sliding_window:
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
slicing_tokens = 1 - sliding_window
Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
pass
# else:
# Knn, Vnn = Knn, Vnn
# pass
# Attention
if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
else:
A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
pass
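# A minimal, self-contained sketch (illustrative only; this hypothetical helper is
# not used by the fast path above) checking the KV-cache identity from the
# docstring: attending a single new query row against the cached K/V gives the
# same result as the last row of full causal attention.
def _reference_kv_cache_identity_check(seq_len = 8, head_dim = 16):
    torch.manual_seed(0)
    Q = torch.randn(seq_len, head_dim)
    K = torch.randn(seq_len, head_dim)
    V = torch.randn(seq_len, head_dim)
    scale = 1.0 / math_sqrt(head_dim)
    # Full causal attention, then take the last row
    scores = (Q @ K.T) * scale
    causal = torch.full((seq_len, seq_len), float("-inf")).triu(1)
    full_last = torch.softmax(scores + causal, dim = -1)[-1] @ V
    # Incremental decode: one query row against the whole cache [K, k], [V, v]
    inc_last = torch.softmax((Q[-1:] @ K.T) * scale, dim = -1) @ V
    return torch.allclose(full_last, inc_last.squeeze(0), atol = 1e-5)
pass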
torch_nn_functional_silu = torch.nn.functional.silu
def fast_swiglu_inference(self, X):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
bsz, _, hd = X.shape
# mlp_size = self.config.intermediate_size
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
up = fast_linear_forward(self. up_proj, X)#, out = temp[1])
gate = torch_nn_functional_silu(gate, inplace = True)
gate *= up
# X = self.down_proj(gate)
down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
return down
pass
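# Hedged reference sketch (illustrative only; `_reference_swiglu` is a hypothetical
# helper, not part of the patched model) of the SwiGLU MLP that the fast path above
# computes in-place and with preallocated buffers: down(silu(gate(x)) * up(x)).
def _reference_swiglu(x, gate_proj, up_proj, down_proj):
    return down_proj(torch_nn_functional_silu(gate_proj(x)) * up_proj(x))
pass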
def fast_rms_layernorm_inference(self, X):
old_dtype = X.dtype
XX = X.to(torch.float32)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
X = XX.to(old_dtype) # Must preserve due to residual
X *= self.weight
return X
pass
def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
XX = X.to(torch.float32)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
if out_weight is None:
out_weight = self.weight + 1.0
else:
out_weight[:] = self.weight
out_weight += 1.0
pass
XX *= out_weight
return XX.to(X.dtype)
pass
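# Hedged mathematical reference (illustrative only, ignoring the exact dtype
# handling of the two fast paths above) for the RMSNorm variants:
#   Llama:  y = x / sqrt(mean(x^2) + eps) * w
#   Gemma:  y = x / sqrt(mean(x^2) + eps) * (1 + w)   (Gemma stores weights as w - 1)
def _reference_rms_layernorm(x, weight, eps, gemma = False):
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.square().mean(-1, keepdim = True) + eps)
    scale = (weight.float() + 1.0) if gemma else weight.float()
    return (normed * scale).to(x.dtype)
pass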
# Normal layernorm with mean removal
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def fast_layernorm_compiled(layernorm, X):
old_dtype = X.dtype
X = X.float()
mean = X.mean(-1, keepdim = True)
Xbar = X - mean
X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \
layernorm.variance_epsilon) * \
layernorm.weight.float()
return X.to(old_dtype)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
def LlamaAttention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# Extend RoPE dynamically to fit in VRAM
rotary_emb = self.rotary_emb
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
if position_ids is None:
# Useful for LongRoPE
cos, sin = rotary_emb.get_cached(kv_seq_len)
# cos = self.rotary_emb.cos_cached
# sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
# Xformers memory efficient attention
# Also has Flash Attention v2 dispatching
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Group query attention
if n_groups != 1:
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
if hidden_states.requires_grad:
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
else:
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
pass
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
A = flash_attn_func(Q, K, V, causal = True)
else:
# Grouped query attention
if n_groups != 1:
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
pass
# Must be contiguous or else results are incorrect!
# https://github.com/pytorch/pytorch/issues/112577
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass
attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def LlamaDecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if use_cache and hasattr(self, "_flag_for_generation"):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
# https://github.com/unslothai/unsloth/issues/404#issuecomment-2323473452
__DTYPE_MAP = {
"float32": torch.float32,
torch.float32: torch.float32,
"float16": torch.float16,
torch.float16: torch.float16,
"bfloat16": torch.bfloat16,
torch.bfloat16: torch.bfloat16,
}
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
def LlamaModel_fast_forward(
self,
input_ids: torch.LongTensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
assert(output_attentions is False)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
# Fix out of bounds tokenization
if hasattr(self, "max_seq_length"):
if seq_length > self.max_seq_length:
logger.warning_once(
f"Unsloth: Input IDs of length {seq_length} > the model's max sequence length of {self.max_seq_length}.\n"\
"We shall truncate it ourselves. It's imperative if you correct this issue first."
)
if input_ids is not None:
input_ids = input_ids[:,:self.max_seq_length]
elif inputs_embeds is not None:
inputs_embeds = inputs_embeds[:,:self.max_seq_length,:]
pass
pass
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
pass
# We already handle KV cache position_ids ourselves.
if False:#(past_key_values_length != 0):
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length,
dtype = torch.int32,
device = "cuda:0",
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
elif position_ids is not None:
position_ids = position_ids.view(-1, seq_length).to(torch.int32)#.long()
else:
position_ids = None
pass
if position_ids is not None:
if position_ids.shape[0] != batch_size:
position_ids = position_ids.repeat((batch_size, 1))
pass
# Embed positions
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None)
if torch_dtype is not None:
inputs_embeds = inputs_embeds.to(torch_dtype)
else:
raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!")
pass
# Normalized from Gemma
IS_GEMMA = self.config.model_type.startswith("gemma")
IS_GEMMA2 = self.config.model_type.startswith("gemma2")
IS_COHERE = self.config.model_type.startswith("cohere")
train_embed_tokens = self.embed_tokens.weight.requires_grad
if IS_GEMMA:
# Match Gemma exactly by casting to bfloat16 / float16
# inputs_embeds *= math_sqrt(self.config.hidden_size)
# i.e. 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# & 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
normalizer = torch.tensor(math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype)
if train_embed_tokens:
# Careful we must not do an inplace op!
inputs_embeds = inputs_embeds * normalizer
else:
inputs_requires_grad = inputs_embeds.requires_grad
if not inputs_embeds.is_leaf:
inputs_embeds = inputs_embeds.detach()
inputs_requires_grad = True
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
pass
inputs_embeds *= normalizer
# inputs_embeds *= math_sqrt(self.config.hidden_size)
if inputs_requires_grad: inputs_embeds.requires_grad_(True)
pass
pass
# Fix up attention mask by setting elements to 0
# Specifically for DPO
if self._has_no_labels and (attention_mask is not None) and (past_key_values is None) and \
(not train_embed_tokens):
# Careful for inference the attention_mask is size (1, kv_seq_len)
# Whilst the input_embeds is size (1, 1, 4096)
inputs_requires_grad = inputs_embeds.requires_grad
if not inputs_embeds.is_leaf:
inputs_embeds = inputs_embeds.detach()
inputs_requires_grad = True
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
pass
inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
if inputs_requires_grad: inputs_embeds.requires_grad_(True)
pass
# Ignore attention_mask
if attention_mask is None:
padding_mask = None
elif self.training:
attention_mask = None
padding_mask = None
else:
# if 0 in attention_mask:
# padding_mask = attention_mask
# else:
padding_mask = None
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window = getattr(self.config, "sliding_window", None),
)
pass
hidden_states = inputs_embeds
if past_key_values is None and self.training:
use_cache = False
# if use_cache:
# logger.warning_once(
# "Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
# )
# use_cache = False
pass
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# Gradient checkpointing methods (ie sqrt)
if hasattr(self, "_gradient_checkpointing_boundaries"):
boundaries = self._gradient_checkpointing_boundaries
else:
boundaries = None
pass
# Check checkpointing method
gradient_checkpointing = False
offloaded_gradient_checkpointing = False
if (self.gradient_checkpointing and self.training and not use_cache):
gradient_checkpointing = True
if output_attentions is False and hasattr(self, "_offloaded_gradient_checkpointing"):
offloaded_gradient_checkpointing = True
pass
# Gemma2 has alternating SWA and global attn
if IS_GEMMA2:
if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
self.SWA_mask = True
self.GA_mask = False
elif attention_mask is not None:
self.SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window = self.config.sliding_window,
)
self.GA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window = None,
)
elif not hasattr(self, "SWA_mask"):
if HAS_FLEX_ATTENTION:
# Use Flex Attention instead!
self.SWA_mask = create_flex_attention_sliding_window_mask(self.max_seq_length, self.config.sliding_window)
self.GA_mask = create_flex_attention_causal_mask(self.max_seq_length)
else:
n = self.max_seq_length # self.config.max_position_embeddings
# masked_fill is making stuff slower!
# self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
# self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
self.SWA_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = self.config.sliding_window,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
self.GA_mask = AttentionMaskConverter(
is_causal = True,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
pass
pass
pass
# Go through every layer!
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states: all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
mask = causal_mask
if IS_GEMMA2: mask = self.SWA_mask if (idx % 2 == 0) else self.GA_mask
if offloaded_gradient_checkpointing:
hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
hidden_states,
mask,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)[0]
elif gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask)
return custom_forward
pass
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
mask,
attention_mask,
position_ids,
use_reentrant = True,
preserve_rng_state = False,
)
hidden_states = layer_outputs[0]
else:
layer_outputs = decoder_layer(
hidden_states,
causal_mask=mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = layer_outputs[0]
pass
if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions: all_self_attns += (layer_outputs[1],)
pass
# Final layernorm
if use_cache:
hidden_states = \
(fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
(self.norm, hidden_states)
elif IS_COHERE:
hidden_states = self.norm(hidden_states)
else:
hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
pass
if output_hidden_states: all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
pass
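# Hedged illustration (hypothetical helper, not used above) of the alternating mask
# pattern for Gemma-2 seen in the layer loop: even-indexed layers use a sliding-window
# causal mask, odd-indexed layers a global causal mask. The boolean masks below are
# only for inspection; token i attends to at most the last `sliding_window` tokens
# (including itself) under the sliding-window mask.
def _reference_gemma2_masks(seq_len, sliding_window):
    causal = torch.ones(seq_len, seq_len, dtype = torch.bool).tril()
    window = causal & ~torch.ones(seq_len, seq_len, dtype = torch.bool).tril(-sliding_window)
    return window, causal
pass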
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
def LlamaModel_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = getattr(self.config, "sliding_window", None),
)
else:
attention_mask = None
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
def CausalLM_fast_forward(fast_forward_inference):
def _CausalLM_fast_forward(
self,
input_ids: torch.LongTensor = None,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: Optional[int] = 0,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if past_key_values is not None:
outputs = fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids = position_ids,
attention_mask = attention_mask,
)
else:
causal_mask = xformers.attn_bias.LowerTriangularMask()
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pass
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
lm_head = self.lm_head.weight
if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
logits = logits.unsqueeze(0).unsqueeze(0)
elif num_logits_to_keep != 0:
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))
else:
logits = self.lm_head(hidden_states.to(lm_head.dtype))
pass
torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None)
if torch_dtype is not None:
logits = logits.to(torch_dtype)
else:
raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!")
pass
loss = None
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
logit_scaling = getattr(self.config, "logit_scale", 0)
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
pass
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
logit_softcapping = logit_softcapping,
logit_scaling = logit_scaling,
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
)
else:
if logit_scaling != 0:
if logits.requires_grad:
logits = logit_scaling * logits
else:
logits *= logit_scaling
pass
pass
if logit_softcapping != 0:
if logits.requires_grad:
logits = (1.0 / logit_softcapping) * logits
logits = torch.tanh(logits)
logits = logit_softcapping * logits
else:
logits *= (1.0 / logit_softcapping)
torch.tanh(logits, out = logits)
logits *= logit_softcapping
pass
pass
pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
pass
return _CausalLM_fast_forward
pass
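# Hedged reference sketch (illustrative only; not the code path used above) of the
# logit post-processing applied in _CausalLM_fast_forward: logit scaling multiplies
# logits by the config's `logit_scale` (e.g. Cohere), while softcapping bounds them
# to (-cap, cap) via cap * tanh(logits / cap) using `final_logit_softcapping`
# (e.g. Gemma-2). A value of 0 disables either step.
def _reference_logit_postprocess(logits, logit_scaling = 0, logit_softcapping = 0):
    if logit_scaling != 0:
        logits = logits * logit_scaling
    if logit_softcapping != 0:
        logits = logit_softcapping * torch.tanh(logits / logit_softcapping)
    return logits
pass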
@torch._disable_dynamo
def PeftModelForCausalLM_fast_forward(
self,
input_ids=None,
causal_mask=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
task_ids=None,
num_logits_to_keep=0,
**kwargs,
):
return self.base_model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
num_logits_to_keep=num_logits_to_keep,
**kwargs,
)
pass
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
class LlamaRotaryEmbedding(torch.nn.Module):
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
# [TODO] Hack to pass in config - need to remove later
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
max_position_embeddings = config.max_position_embeddings
pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype = x.dtype),
self.sin_cached[:seq_len].to(dtype = x.dtype),
)
pass
def get_cached(self, seq_len = None):
return self.cos_cached, self.sin_cached
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
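# Hedged sketch (hypothetical helper, not used by the classes here) of the dynamic
# cache-growth arithmetic in extend_rope_embedding: round the requested sequence
# length up to the next multiple of 8192 before rebuilding the cos/sin cache.
def _reference_round_up_rope_size(seq_len, increment = 8192):
    return ((seq_len // increment) + ((seq_len % increment) != 0)) * increment
pass
# For example, _reference_round_up_rope_size(9000) == 16384 and
# _reference_round_up_rope_size(8192) == 8192.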
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.current_rope_size = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
t = t / self.scaling_factor
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
pass
pass
# See https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py#L736
# For Llama 3.1
class LlamaExtendedRotaryEmbedding(torch.nn.Module):
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
# [TODO] Hack to pass in config - need to remove later
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
max_position_embeddings = config.max_position_embeddings
pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
# Normal Llama-3 RoPE
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
inv_freq = self.apply_scaling(inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent = False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float()
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
pass
# From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41
def apply_scaling(self, freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype = x.dtype),
self.sin_cached[:seq_len].to(dtype = x.dtype),
)
pass
def get_cached(self, seq_len = None):
return self.cos_cached, self.sin_cached
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
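# Hedged, vectorized re-statement (illustrative only, not used by the class above)
# of the Llama 3.1 frequency scaling in apply_scaling: short wavelengths are kept,
# long wavelengths are divided by `scale_factor`, and wavelengths in between are
# smoothly interpolated between the two.
def _reference_llama31_scaling(freqs, scale_factor = 8, low_freq_factor = 1,
                               high_freq_factor = 4, old_context_len = 8192):
    wavelen = 2 * math.pi / freqs
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    return torch.where(wavelen < high_freq_wavelen, freqs,
           torch.where(wavelen > low_freq_wavelen, freqs / scale_factor,
                       (1 - smooth) * freqs / scale_factor + smooth * freqs))
pass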
class LongRopeRotaryEmbedding(torch.nn.Module):
# For Phi 3.5 128K https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/modeling_phi3.py
def __init__(self,
dim = None,
max_position_embeddings = 131072,
original_max_position_embeddings = 4096,
base = 10000,
short_factor = None,
long_factor = None,
device = None,
config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
assert(short_factor is not None)
assert(long_factor is not None)
assert(type(original_max_position_embeddings) is int)
if config is not None:
# [TODO] Hack to pass in config - need to remove later
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
max_position_embeddings = config.max_position_embeddings
pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings)
# Long RoPE similar to RoPE except short sequences have 1 cos / sin
# and long sequences have another cos / sin
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim
short_factor = torch.tensor(short_factor, device = "cpu", dtype = torch.float32)
long_factor = torch.tensor(long_factor, device = "cpu", dtype = torch.float32)
short_inv_freq = 1.0 / (short_factor * self.base**inv_freq_shape)
long_inv_freq = 1.0 / (long_factor * self.base**inv_freq_shape)
# Phi-3 Scale factor
scale = self.max_position_embeddings / self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
pass
self.scaling_factor = scaling_factor
# Short and long inv_freq
self.register_buffer("short_inv_freq", short_inv_freq, persistent = False)
self.register_buffer("long_inv_freq", long_inv_freq, persistent = False)
# Build here to make `torch.jit.trace` work.
# self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
# Short sequences
dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float()
freqs = torch.outer(t, self.short_inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
self.register_buffer("short_cos_cached", cos_cached, persistent=False)
self.register_buffer("short_sin_cached", sin_cached, persistent=False)
pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
t = torch.arange(self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64).float()
# Long sequences
freqs = torch.outer(t, self.long_inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
self.register_buffer("long_cos_cached", cos_cached, persistent=False)
self.register_buffer("long_sin_cached", sin_cached, persistent=False)
pass
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
if seq_len < self.original_max_position_embeddings:
return (
self.short_cos_cached[:seq_len].to(dtype = x.dtype),
self.short_sin_cached[:seq_len].to(dtype = x.dtype),
)
else:
return (
self.long_cos_cached[:seq_len].to(dtype = x.dtype),
self.long_sin_cached[:seq_len].to(dtype = x.dtype),
)
pass
pass
def get_cached(self, seq_len = None):
if seq_len < self.original_max_position_embeddings:
return self.short_cos_cached, self.short_sin_cached
return self.long_cos_cached, self.long_sin_cached
pass
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
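# Hedged numeric illustration (hypothetical helper, not used above) of the Phi-3
# LongRoPE attention scaling factor computed in __init__:
#   sqrt(1 + log(scale) / log(original_max_position_embeddings)),
# where scale = max_position_embeddings / original_max_position_embeddings.
def _reference_longrope_scaling_factor(max_pos = 131072, original_max_pos = 4096):
    scale = max_pos / original_max_pos
    if scale <= 1.0: return 1.0
    return math.sqrt(1 + math.log(scale) / math.log(original_max_pos))
pass
# For Phi 3.5's 131072 / 4096 = 32x extension this gives roughly 1.19.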
def _wrap_fast_inference(generate, device_type, dtype, model):
# Wraps inference with bfloat16 / float16
@torch.inference_mode
def _fast_generate(*args, **kwargs):
# Set a flag for generation!
internal_model = model
while hasattr(internal_model, "model"):
internal_model._flag_for_generation = True
internal_model = internal_model.model
pass
internal_model._flag_for_generation = True
# Must patch accelerate for Xformers
if accelerate_new_send_to_device is not None:
import accelerate.utils.operations
accelerate.utils.operations.send_to_device = accelerate_new_send_to_device
pass
# For newer HF
kwargs["cache_implementation"] = "dynamic"
# For num_logits_to_keep
kwargs["num_logits_to_keep"] = 1
# Remove token_type_ids
kwargs.pop("token_type_ids", None)
# Check pad_token
model_eos_token_id = getattr(model.config, "eos_token_id", None)
if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"):
model_eos_token_id = model_eos_token_id[0]
kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
# Set pad token
# old_pad_token_id = getattr(model.config, "pad_token_id", None)
# old_eos_token_id = getattr(model.config, "eos_token_id", None)
# model.config.pad_token_id = old_eos_token_id
# Autocasted
with torch.autocast(device_type = device_type, dtype = dtype):
output = generate(*args, **kwargs)
pass
# Revert
# model.config.pad_token_id = old_pad_token_id
# Unset a flag for generation!
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
internal_model = internal_model.model
pass
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
# Return accelerate back
if accelerate_new_send_to_device is not None:
accelerate.utils.operations.send_to_device = accelerate_old_send_to_device
pass
return output
pass
return _fast_generate
pass
class FastLlamaModel:
@staticmethod
def pre_patch():
init_name, function = patch_llama_rope_scaling(
model_name = "llama",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
extended_rope_module = LlamaExtendedRotaryEmbedding,
attention_module = LlamaAttention,
longrope_module = LongRopeRotaryEmbedding,
)
if init_name is not None:
exec(function, globals())
LlamaAttention.__init__ = eval(init_name)
pass
LlamaAttention .forward = LlamaAttention_fast_forward
LlamaSdpaAttention .forward = LlamaAttention_fast_forward
LlamaFlashAttention2.forward = LlamaAttention_fast_forward
LlamaDecoderLayer .forward = LlamaDecoderLayer_fast_forward
LlamaModel .forward = LlamaModel_fast_forward
LlamaForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(LlamaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding
return
pass
@staticmethod
def from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
if trust_remote_code:
print(
"Unsloth: WARNING `trust_remote_code` is True.\n"\
"Are you certain you want to do remote code execution?"
)
pass
if token is None: token = get_token()
if model_patcher is None: model_patcher = FastLlamaModel
SUPPORTS_BFLOAT16 = is_bfloat16_supported()
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
statistics = \
f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers = {transformers_version}.\n"\
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth'
print(statistics)
# Warn about fast transfers
old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0")
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1":
print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
pass
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
model_patcher.pre_patch()
get_statistics() # For debugging - we use a download counter to see if environments are not breaking
if dtype is None:
dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
logger.warning_once("Device does not support bfloat16. Will change to float16.")
dtype = torch.float16
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
# RoPE Scaling
model_config = AutoConfig.from_pretrained(model_name, token = token)
model_max_seq_length = model_config.max_position_embeddings
# Check if RoPE Scaling is even allowed
model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]
has_rope_scaling = False
try:
with open(inspect.getfile(model_function), "r") as file:
has_rope_scaling = "self.config.rope_scaling" in file.read()
except: pass
has_rope_scaling = True
# If max_seq_length is not specified, use the maximum from the config
if max_seq_length is None:
max_seq_length = model_max_seq_length
pass
if (rope_scaling is None) and (max_seq_length > model_max_seq_length):
rope_scaling = max_seq_length / model_max_seq_length
logger.warning_once(
f"Unsloth: {model_name} can only handle sequence lengths of at most "\
f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\
f"{round(rope_scaling, 3)}, it can be magically be extended to "\
f"{max_seq_length}!"
)
# Warn RoPE scaling isn't allowed
if not has_rope_scaling:
raise RuntimeError(
"However, {model_name} doesn't support RoPE Scaling!\n"\
"Please file a feature request at https://github.com/unslothai/unsloth."
)
pass
rope_scaling = {"type": "linear", "factor": rope_scaling,}
# Add to kwargs
kwargs["rope_scaling"] = rope_scaling
pass
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
pre_check = check_nvidia()
bnb_config = None
if load_in_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit = True,
bnb_4bit_use_double_quant = True,
bnb_4bit_quant_type = "nf4",
bnb_4bit_compute_dtype = dtype,
)
pass
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
# RoPE Scaling's max_position_embeddings must be updated
max_position_embeddings = max(max_seq_length, model_max_seq_length)
kwargs.pop("attn_implementation", None); # No need since we auto call it
# Cannot be None, since HF now checks for the config
if load_in_4bit: kwargs["quantization_config"] = bnb_config
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map = device_map,
torch_dtype = dtype,
# quantization_config = bnb_config,
token = token,
max_position_embeddings = max_position_embeddings,
trust_remote_code = trust_remote_code,
attn_implementation = "eager",
**kwargs,
)
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
post_check = check_nvidia()
# Counteract saved tokenizers
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
tokenizer = load_correct_tokenizer(
tokenizer_name = tokenizer_name,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
trust_remote_code = trust_remote_code,
fix_tokenizer = fix_tokenizer,
)
model, tokenizer = patch_tokenizer(model, tokenizer)
model = model_patcher.post_patch(model)
# Patch up QKV / O and MLP
for idx, layer in enumerate(model.model.layers):
layer.self_attn.apply_qkv = original_apply_qkv
layer.self_attn.apply_o = original_apply_o
pass
# Patch Trainer
from transformers.trainer import Trainer
try:
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
inner_training_loop = inspect.getsource(Trainer._inner_training_loop)
Trainer._original_training_loop = inner_training_loop
else:
inner_training_loop = Trainer._original_training_loop
except:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
pass
if ((post_check - pre_check) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
import transformers.trainer
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
# TODO: Support Deepspeed
if item.startswith(("deepspeed", "xm", "met", "smp")): continue
if item in inner_training_loop: good_items.append(item)
pass
exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
start = re.search('logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0]
end = inner_training_loop.find("\n\n", start)
original_debug = inner_training_loop[start:end]
spaces = re.search('\n([\s\t]{1,})', original_debug).group(0)[1:]
front_spaces = re.match('([\s\t]{1,})', inner_training_loop).group(0)
debug_info = """debug_info = \\
f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\
f" \\\\\\ /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\
f"O^O/ \\_/ \\ Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\
f"\\ / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\
f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}'
logger.warning(debug_info)
import subprocess, re, gc, numpy as np
a = np.array([0,])
try:
a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)
a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)
a = np.array([int(x.decode('utf-8'))/1024 for x in a])
except:
if not torch.cuda.is_available():
raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')
if ((a - PRE_CHECK) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()"""
debug_info = debug_info.split('\n')
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace(original_debug, debug_info)
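# The stock logging block inside Trainer._inner_training_loop has now been swapped
# for the Unsloth banner plus an nvidia-smi based single-GPU sanity check; the new
# source was re-indented above so it still compiles when exec'd further down.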
debug_info = """n_total_devices = total_train_batch_size // \\
args.gradient_accumulation_steps // self._train_batch_size
if n_total_devices > 1:
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
debug_info ="""
debug_info = debug_info.split('\n')
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace("debug_info =", debug_info, 1)
front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0)
inner_training_loop = re.sub(r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE)
inner_training_loop = inner_training_loop.replace(
"train_dataloader = tpu_spmd_dataloader(train_dataloader)",
"raise RuntimeError('Unsloth: TPUs are not yet supported!')"
)
inner_training_loop = inner_training_loop.replace(
"self.accelerator.free_memory()",
"self.accelerator.free_memory()\n" + \
front_spaces + "if self.is_deepspeed_enabled:"\
"raise RuntimeError('Unsloth: Deepspeed is not yet supported!')\n", 1,
)
check_batches = """train_dataloader = self.get_train_dataloader()
ga = args.gradient_accumulation_steps
bsz = self._train_batch_size
total_batches = bsz * ga * args.world_size
n_total_devices = total_batches // ga // bsz
if n_total_devices > 1:
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
divisor = n_total_devices / 1
bsz = self._train_batch_size = max(int(bsz / divisor), 1)
if total_batches // ga // bsz > 1:
divisor = n_total_devices / 1
ga = args.gradient_accumulation_steps = max(int(ga / divisor), 1)"""
check_batches = check_batches.split('\n')
check_batches = "\n".join([check_batches[0]] + [front_spaces + x[8:] for x in check_batches[1:]])
inner_training_loop = inner_training_loop.replace(
"train_dataloader = self.get_train_dataloader()",
check_batches, 1,
)
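# check_batches rewrites the dataloader setup: if more than one device is detected,
# the per-device batch size and gradient accumulation steps are scaled back down to
# a single-device configuration (with a warning).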
inner_training_loop = inner_training_loop.replace(
"_inner_training_loop",
"_fast_inner_training_loop", 1,
)
exec(inner_training_loop, globals())
Trainer._inner_training_loop = _fast_inner_training_loop
inner_training_loop = inner_training_loop.replace(
"is_torch_tpu_available()",
"False",
)
if "n_total_devices >" not in inner_training_loop:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
pass
inner_training_loop = inner_training_loop.replace(
"is_sagemaker_mp_enabled()",
"False",
)
exec(inner_training_loop, globals())
Trainer._inner_training_loop = _fast_inner_training_loop
# Save max_seq_length
model.max_seq_length = max_position_embeddings
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_position_embeddings
internal_model = internal_model.model
pass
internal_model.max_seq_length = max_position_embeddings
# We check the tokenizer first for errors
if fix_tokenizer:
tokenizer = check_tokenizer(
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
)
pass
patch_saving_functions(tokenizer)
# Fix up config for transformers uploading PEFT
# Not necessary anymore since we require transformers>=4.37!
if False:
name = model.config._name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
name = name[:len(name) - len("-bnb-4bit")]
model.config.update({"_name_or_path" : name})
pass
pass
# Log Unsloth version for future fastpaths for inference
model.config.update({"unsloth_version" : __version__})
# Add save modules
patch_saving_functions(model)
Trainer._inner_training_loop = _fast_inner_training_loop
# Fix gradient accumulation
patch_gradient_accumulation_fix(Trainer)
# Save tokenizer for inference purposes
tokenizer.padding_side = "left" # Force inference
internal_model = model
while hasattr(internal_model, "model"):
internal_model._saved_temp_tokenizer = tokenizer
internal_model = internal_model.model
pass
internal_model._saved_temp_tokenizer = tokenizer
# Also fix torch_dtype
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "config"):
if internal_model.config.torch_dtype == "float32":
internal_model.config.torch_dtype = torch.float32
elif internal_model.config.torch_dtype == "bfloat16":
internal_model.config.torch_dtype = torch.bfloat16
elif internal_model.config.torch_dtype == "float16":
internal_model.config.torch_dtype = torch.float16
pass
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "config"):
if internal_model.config.torch_dtype == "float32":
internal_model.config.torch_dtype = torch.float32
elif internal_model.config.torch_dtype == "bfloat16":
internal_model.config.torch_dtype = torch.bfloat16
elif internal_model.config.torch_dtype == "float16":
internal_model.config.torch_dtype = torch.float16
pass
pass
return model, tokenizer
pass
@staticmethod
def post_patch(model):
# Patch model
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.
model.set_input_embeddings(torch.nn.Embedding.from_pretrained(model.get_input_embeddings().weight))
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.get_output_embeddings().weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
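# The rebuilt lm_head shares the original output embedding tensor (no copy is made),
# so this only normalizes the module wrapper and its reported in/out features.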
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")):
if hasattr(module, "cos_cached") and \
(module.cos_cached.dtype != correct_dtype):
module.cos_cached = module.cos_cached.to(correct_dtype)
module.sin_cached = module.sin_cached.to(correct_dtype)
elif hasattr(module, "short_cos_cached") and \
(module.short_cos_cached.dtype != correct_dtype):
module.short_cos_cached = module.short_cos_cached.to(correct_dtype)
module.short_sin_cached = module.short_sin_cached.to(correct_dtype)
pass
pass
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
@staticmethod
def get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
layers_to_transform = None,
layers_pattern = None,
use_gradient_checkpointing = True,
random_state = 3407,
max_seq_length = 2048, # not used anymore
use_rslora = False,
modules_to_save = None,
init_lora_weights = True,
loftq_config = {},
temporary_location = "_unsloth_temporary_saved_buffers",
**kwargs,
):
transformers_set_seed(random_state)
if isinstance(model, PeftModelForCausalLM):
# Check if exactly the same and then pass through!
assert(hasattr(model, "peft_config"))
peft_config = model.peft_config["default"].to_dict()
check_parameters = [
"r", "lora_alpha", "lora_dropout",
"bias", "layers_to_transform", "layers_pattern",
"use_rslora", "init_lora_weights",
]
check_all = True
for param in check_parameters:
check_all = check_all and (peft_config[param] == eval(param))
pass
# Check save_modules
old_target_modules = list(peft_config["target_modules"])
modules_to_save = peft_config["modules_to_save"]
if modules_to_save is None: modules_to_save = {}
modules_to_save = list(modules_to_save)
old_target_modules += modules_to_save
# Combine all
new_target_modules = list(target_modules) + \
list(modules_to_save if modules_to_save is not None else [])
# Now check!
new_target_modules = set(new_target_modules)
check_all = check_all and (
len(set(old_target_modules) ^ new_target_modules) == 0
)
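# ^ is the set symmetric difference: empty means the old and new target modules
# (including modules_to_save) match exactly.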
check_all = check_all and (
(loftq_config == {} or loftq_config is None) and \
(peft_config["loftq_config"] == {} or peft_config["loftq_config"] is None)
)
if check_all:
# Simply pass through!
logger.warning(
"Unsloth: Already have LoRA adapters! We shall skip this step."
)
# Offload!
# [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!)
if "embed_tokens" in new_target_modules:
print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype
model.model.model.embed_tokens.modules_to_save.default\
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
# [TODO] Move old embed_tokens to CPU - should be disk!
model.model.model.embed_tokens.original_module\
.to(device = "cpu", non_blocking = True)
model.model.model.embed_tokens.original_module.requires_grad_(False)
pass
if "lm_head" in new_target_modules:
print("Unsloth: Training lm_head in mixed precision to save VRAM")
dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype
model.model.lm_head.modules_to_save.default\
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)
# [TODO] Move old lm_head to CPU - should be disk!
model.model.lm_head.original_module\
.to(device = "cpu", non_blocking = True)
model.model.lm_head.original_module.requires_grad_(False)
pass
return model
else:
raise TypeError(
"Unsloth: Your model already has LoRA adapters. Your new parameters are different."
)
pass
pass
if loftq_config is None: loftq_config = {}
signature = str(inspect.signature(LoraConfig))
SUPPORTS_LOFTQ = "loftq_config" in signature
SUPPORTS_RSLORA = "use_rslora" in signature
assert(max_seq_length <= model.max_seq_length)
if lora_dropout != 0:
logger.warning_once(
f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"\
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
pass
if bias != "none":
logger.warning_once(
f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"\
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
pass
if not (type(init_lora_weights) is bool or \
init_lora_weights == "gaussian" or init_lora_weights == "loftq"):
raise ValueError(
'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq"].'
)
pass
if init_lora_weights == "loftq":
if not SUPPORTS_LOFTQ:
import peft
raise RuntimeError(
f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"\
"Please install PEFT 0.7.2 or higher.\n"\
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
)
pass
if loftq_config == {}:
from peft import LoftQConfig
logger.warning_once(
"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
)
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
pass
if hasattr(model.config, "quantization_config"):
raise ValueError(
"Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\
"Reload your model without any quantization by setting `load_in_4bit = False`."
)
pass
pass
assert(type(use_rslora) is bool)
if use_rslora:
if not SUPPORTS_RSLORA:
# We manually check for PEFT
import peft
raise RuntimeError(
f"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\n"\
"Please install PEFT 0.7.2 or higher.\n"\
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
)
pass
pass
accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",),)
model.config.update({"unsloth_version" : __version__})
if type(modules_to_save) is tuple:
modules_to_save = list(modules_to_save)
pass
train_lm_head = False
train_embed_tokens = False
final_modules = []
for module in target_modules:
if module == "lm_head":
# logger.warning_once(
# "Unsloth: `lm_head` should be placed in `modules_to_save` and not `target_modules`. "\
# "Luckily, we shall do it for you!"
# )
train_lm_head = True
if modules_to_save is None: modules_to_save = ["lm_head"]
else: modules_to_save.append("lm_head")
elif module == "embed_tokens":
# logger.warning_once(
# "Unsloth: `embed_tokens` should be placed in `modules_to_save` and not `target_modules`. "\
# "Luckily, we shall do it for you!"
# )
train_embed_tokens = True
if modules_to_save is None: modules_to_save = ["embed_tokens"]
else: modules_to_save.append("embed_tokens")
else:
try:
assert(module in accepted_modules)
final_modules.append(module)
except AssertionError as e:
final_modules.append(module)
print(
"Unsloth: You added custom modules, but Unsloth hasn't optimized for this.\n"\
"Beware - your finetuning might be noticeably slower!"
)
pass
pass
pass
# Check if we added new tokens!
if hasattr(model, "_need_to_train_embeddings"):
if not train_lm_head or not train_embed_tokens:
print(
"Unsloth: You added new tokens but did not specify if you wanted to "\
"train the lm_head and embed_tokens.\nWe must turn it on for you."
)
train_lm_head = True
train_embed_tokens = True
if modules_to_save is None: modules_to_save = ["embed_tokens"]
else: modules_to_save.append("embed_tokens")
if modules_to_save is None: modules_to_save = ["lm_head"]
else: modules_to_save.append("lm_head")
pass
pass
# Check for Llama-3
# if hasattr(model._saved_temp_tokenizer, "_using_llama3_template"):
# if not train_embed_tokens and not train_lm_head:
# raise RuntimeError("")
# First fix untrained tokens
# Wrong - can cause reserved tokens to pop out!!
# if train_embed_tokens or train_lm_head:
# fix_untrained_tokens(model, eps = 1e-16)
# pass
# Check modules_to_save
if modules_to_save is not None:
for module in modules_to_save:
if module == "lm_head":
train_lm_head = True
elif module == "embed_tokens":
train_embed_tokens = True
else:
raise TypeError(
f"Unsloth: Module = {module} is not allowed. Only 'lm_head' and 'embed_tokens' is allowed."
)
pass
pass
if isinstance(modules_to_save, (tuple, list)):
modules_to_save = list(set(modules_to_save))
pass
# Get LoRA
arguments = dict(
r = r,
lora_alpha = lora_alpha,
target_modules = final_modules,
lora_dropout = lora_dropout,
bias = bias,
task_type = TaskType.CAUSAL_LM,
layers_to_transform = layers_to_transform,
init_lora_weights = init_lora_weights,
loftq_config = loftq_config,
use_rslora = use_rslora,
modules_to_save = modules_to_save,
**kwargs,
)
if not SUPPORTS_LOFTQ: del arguments["loftq_config"]
if not SUPPORTS_RSLORA: del arguments["use_rslora"]
_saved_temp_tokenizer = model._saved_temp_tokenizer
lora_config = LoraConfig(**arguments)
# First offload lm_head and embed_tokens to disk
input_embeddings_device = model. get_input_embeddings().weight.device
output_embeddings_device = model.get_output_embeddings().weight.device
if use_gradient_checkpointing == "unsloth":
if train_embed_tokens:
print("Unsloth: Offloading input_embeddings to disk to save VRAM")
offload_input_embeddings(model, temporary_location)
pass
# Remove old items to save VRAM
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
if train_lm_head:
print("Unsloth: Offloading output_embeddings to disk to save VRAM")
offload_output_embeddings(model, temporary_location)
pass
# Remove old items to save VRAM
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
pass
model = _get_peft_model(model, lora_config)
model._saved_temp_tokenizer = _saved_temp_tokenizer
model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)
# Now patch lm_head and embed_tokens
if train_embed_tokens:
print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
assert(hasattr(model.model.model.embed_tokens, "modules_to_save"))
dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype
model.model.model.embed_tokens.modules_to_save.default\
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
pass
if train_lm_head:
print("Unsloth: Training lm_head in mixed precision to save VRAM")
assert(hasattr(model.model.lm_head, "modules_to_save"))
dtype = model.model.lm_head.modules_to_save.default.weight.dtype
model.model.lm_head.modules_to_save.default\
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)
pass
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
return model
pass
@staticmethod
def patch_peft_model(
model,
use_gradient_checkpointing = True,
):
if not isinstance(model, PeftModelForCausalLM):
raise TypeError(
"Unsloth: Your model needs to call `.get_peft_model` first!"
)
pass
# Get activation function
model_type = model.config.model_type
if model_type == "llama": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx
elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx
elif model_type == "cohere": apply_lora_mlp = apply_lora_mlp_swiglu
else:
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
pass
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing = use_gradient_checkpointing,
use_reentrant = True,
)
# Fix up config for transformers uploading PEFT
for active_adapter in model.peft_config.keys():
# Not necessary since we require transformers >= 4.37
if False:
name = model.peft_config[active_adapter].base_model_name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
name = name[:len(name) - len("-bnb-4bit")]
model.peft_config[active_adapter].base_model_name_or_path = name
pass
# Add revision to enable future fast inference paths
# [TODO] Bugs out! See https://github.com/unslothai/unsloth/issues/492
# model.peft_config[active_adapter].revision = f"unsloth"
pass
from transformers.trainer import Trainer
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
raise RuntimeError(
'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\
'enabling it will require much more work, so we have to prioritize. Please understand!\n'\
'We do have a separate beta version, which you can contact us about!\n'\
'Thank you for your understanding and we appreciate it immensely!'
)
pass
# Fix loftq issues
# loftq_config must not be None, but rather {}
all_configs = model.peft_config
for key, current_config in all_configs.items():
if hasattr(current_config, "loftq_config") and current_config.loftq_config is None:
new_args = current_config.__dict__
new_args["loftq_config"] = {}
current_config = current_config.__class__(**new_args)
all_configs[key] = current_config
pass
pass
# Do patching
n_mlp = 0
n_qkv = 0
n_o = 0
import types
active_adapter = model.active_adapters[0] if \
hasattr(model, "active_adapters") else model.active_adapter
# Get dropout and bias
lora_dropout = model.peft_config[active_adapter].lora_dropout
bias = model.peft_config[active_adapter].bias
# We also do not inplace edit the MLP for Cohere!
from functools import partial
_apply_lora_mlp = \
partial(apply_lora_mlp, inplace = False) \
if model_type == "cohere" else \
apply_lora_mlp
pass
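# A layer is only patched with the fast autograd path when every projection has
# LoRA A/B matrices, its base layer has no bias, and no (DoRA-style) magnitude
# vector is attached; otherwise the stock PEFT forward is kept and we only warn.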
if lora_dropout == 0 and bias == "none":
for idx, layer in enumerate(model.model.model.layers):
# MLP patching
gate_proj = layer.mlp.gate_proj
up_proj = layer.mlp. up_proj
down_proj = layer.mlp.down_proj
if hasattr(gate_proj, "lora_A") and \
hasattr( up_proj, "lora_A") and \
hasattr(down_proj, "lora_A") and \
(getattr(gate_proj, "base_layer", gate_proj).bias is None) and \
(getattr( up_proj, "base_layer", up_proj).bias is None) and \
(getattr(down_proj, "base_layer", down_proj).bias is None) and \
(len(getattr(gate_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr( up_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0):
# https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp)
n_mlp += 1
else:
logger.warning_once(
"Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
pass
# QKV attention patching
q_proj = layer.self_attn.q_proj
k_proj = layer.self_attn.k_proj
v_proj = layer.self_attn.v_proj
if hasattr(q_proj, "lora_A") and \
hasattr(k_proj, "lora_A") and \
hasattr(v_proj, "lora_A") and \
(getattr(q_proj, "base_layer", q_proj).bias is None) and \
(getattr(k_proj, "base_layer", k_proj).bias is None) and \
(getattr(v_proj, "base_layer", v_proj).bias is None) and \
(len(getattr(q_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \
(len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0):
layer.self_attn.apply_qkv = apply_lora_qkv
n_qkv += 1
else:
if model_type != "qwen2":
logger.warning_once(
"Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
pass
pass
# O attention patching
o_proj = layer.self_attn.o_proj
if hasattr(o_proj, "lora_A") and \
(getattr(o_proj, "base_layer", o_proj).bias is None) and \
(len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0):
layer.self_attn.apply_o = apply_lora_o
n_o += 1
else:
logger.warning_once(
"Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
pass
pass
pass
logger.warning_once(
f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "\
f"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.",
)
patch_saving_functions(model)
# Patch cross entropy loss labels
# Fixes https://github.com/unslothai/unsloth/issues/10
max_seq_length = model.max_seq_length
extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
model.model.extra_ignored_labels = extra_ignored_labels
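# A (max_seq_length, 1) column of -100 (the cross entropy ignore_index) kept on the
# GPU; presumably appended to the shifted labels by the fast causal-LM loss so the
# label and logit lengths stay aligned (see the linked issue above).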
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_seq_length
internal_model = internal_model.model
pass
internal_model.max_seq_length = max_seq_length
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
return model
pass
@staticmethod
def for_inference(model):
# if model.config.model_type == "qwen2":
# FastLlamaModel.for_training(model)
# return
# pass
internal_model = model
internal_model.gradient_checkpointing = False
internal_model.training = False
while hasattr(internal_model, "model"):
internal_model = internal_model.model
internal_model.gradient_checkpointing = False
internal_model.training = False
pass
if hasattr(internal_model, "training"):
internal_model.training = False
pass
# Also check if lm_head / embeddings are trained
internal_model = model
while not hasattr(internal_model, "lm_head"):
internal_model = internal_model.model
pass
lm_head = internal_model.lm_head.weight
device_type = lm_head.device.type
dtype = model.config.torch_dtype
if type(dtype) is str:
if dtype == "float16": dtype = torch.float16
elif dtype == "bfloat16": dtype = torch.bfloat16
pass
# Wrap model.generate
if model.generate.__name__ != "_fast_generate":
model._unwrapped_old_generate = model.generate
model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)
pass
# Patch tokenizer to pad to the left
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "left"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "left"
pass
# Also disable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
embeddings = model.get_input_embeddings()
if hasattr(embeddings, "training"): embeddings.training = False
pass
if hasattr(model, "get_output_embeddings"):
embeddings = model.get_output_embeddings()
if hasattr(embeddings, "training"): embeddings.training = False
pass
return model
pass
@staticmethod
def for_training(model, use_gradient_checkpointing = True):
internal_model = model
internal_model.gradient_checkpointing = use_gradient_checkpointing
internal_model.training = True
# Delete all fast inference loras
for param in model.parameters():
if hasattr(param, "_fast_lora"):
del param._fast_lora
pass
while hasattr(internal_model, "model"):
internal_model = internal_model.model
internal_model.gradient_checkpointing = use_gradient_checkpointing
internal_model.training = True
pass
if hasattr(internal_model, "training"):
internal_model.training = True
pass
# Also revert model.generate
if hasattr(model, "_unwrapped_old_generate"):
model.generate = model._unwrapped_old_generate
del model._unwrapped_old_generate
pass
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
# Also re-enable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
embeddings = model.get_input_embeddings()
if hasattr(embeddings, "training"): embeddings.training = True
pass
if hasattr(model, "get_output_embeddings"):
embeddings = model.get_output_embeddings()
if hasattr(embeddings, "training"): embeddings.training = True
pass
return model
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._utils import is_bfloat16_supported, HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING
from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from .cohere import FastCohereModel
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
import os
try:
from huggingface_hub.utils import get_token
except:
# Old HF Hub versions <= 0.0.25
from huggingface_hub.utils._token import get_token
pass
from huggingface_hub import HfFileSystem
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from packaging.version import Version
transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
SUPPORTS_GEMMA = transformers_version >= Version("4.38")
SUPPORTS_GEMMA2 = transformers_version >= Version("4.42")
SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2")
SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0")
if SUPPORTS_GEMMA:
from .gemma import FastGemmaModel
if SUPPORTS_GEMMA2:
from .gemma2 import FastGemma2Model
pass
def __get_model_name(
model_name,
load_in_4bit = True,
INT_TO_FLOAT_MAPPER = None,
FLOAT_TO_INT_MAPPER = None,
MAP_TO_UNSLOTH_16bit = None,
):
model_name = str(model_name)
lower_model_name = model_name.lower()
if not SUPPORTS_FOURBIT and lower_model_name in INT_TO_FLOAT_MAPPER:
model_name = INT_TO_FLOAT_MAPPER[lower_model_name]
logger.warning_once(
f"Unsloth: Your transformers version of {transformers_version} does not support native "\
f"4bit loading.\nThe minimum required version is 4.37.\n"\
f'Try `pip install --upgrade "transformers>=4.37"`\n'\
f"to obtain the latest transformers build, then restart this session.\n"\
f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)."
)
return model_name
elif not load_in_4bit and lower_model_name in INT_TO_FLOAT_MAPPER:
new_model_name = INT_TO_FLOAT_MAPPER[lower_model_name]
# logger.warning_once(
# f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\
# f"`load_in_4bit = False`. We shall load `{new_model_name}` instead."
# )
return new_model_name
elif not load_in_4bit and lower_model_name in MAP_TO_UNSLOTH_16bit:
new_model_name = MAP_TO_UNSLOTH_16bit[lower_model_name]
return new_model_name
elif load_in_4bit and SUPPORTS_FOURBIT and lower_model_name in FLOAT_TO_INT_MAPPER:
new_model_name = FLOAT_TO_INT_MAPPER[lower_model_name]
# logger.warning_once(
# f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\
# f"We shall load `{new_model_name}` for 4x faster loading."
# )
return new_model_name
pass
return None
pass
def _get_new_mapper():
try:
import requests
new_mapper = "https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/models/mapper.py"
with requests.get(new_mapper, timeout = 3) as new_mapper: new_mapper = new_mapper.text
new_mapper = new_mapper[new_mapper.find("__INT_TO_FLOAT_MAPPER"):]
new_mapper = new_mapper\
.replace("INT_TO_FLOAT_MAPPER", "NEW_INT_TO_FLOAT_MAPPER")\
.replace("FLOAT_TO_INT_MAPPER", "NEW_FLOAT_TO_INT_MAPPER")\
.replace("MAP_TO_UNSLOTH_16bit", "NEW_MAP_TO_UNSLOTH_16bit")
exec(new_mapper, globals())
return NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit
except:
return {}, {}, {}
pass
pass
def get_model_name(model_name, load_in_4bit = True):
new_model_name = __get_model_name(
model_name = model_name,
load_in_4bit = load_in_4bit,
INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER,
FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER,
MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit,
)
if new_model_name is None and model_name.count("/") == 1 and model_name[0].isalnum():
# Try checking if a new Unsloth version allows it!
NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = _get_new_mapper()
upgraded_model_name = __get_model_name(
model_name = model_name,
load_in_4bit = load_in_4bit,
INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER,
FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER,
MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit,
)
if upgraded_model_name is not None:
raise NotImplementedError(
f"Unsloth: {model_name} is not supported in your current Unsloth version! Please update Unsloth via:\n\n"\
'pip uninstall unsloth -y\n'\
'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"'
)
pass
pass
return new_model_name if new_model_name is not None else model_name
pass
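# Illustrative example (not part of the library), assuming transformers >= 4.37 so
# SUPPORTS_FOURBIT is True: with the mappers defined further below,
# get_model_name("meta-llama/Meta-Llama-3-8B", load_in_4bit = True) resolves to
# "unsloth/llama-3-8b-bnb-4bit", since lookups are done on lowercased names.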
class FastLanguageModel(FastLlamaModel):
@staticmethod
def from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
trust_remote_code = False,
use_gradient_checkpointing = "unsloth",
resize_model_vocab = None,
revision = None,
*args, **kwargs,
):
if token is None: token = get_token()
old_model_name = model_name
model_name = get_model_name(model_name, load_in_4bit)
# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
disable_progress_bars()
autoconfig_error = None
peft_error = None
try:
model_config = AutoConfig.from_pretrained(
model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_model = True
except Exception as error:
autoconfig_error = str(error)
is_model = False
try:
peft_config = PeftConfig.from_pretrained(
model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_peft = True
except Exception as error:
peft_error = str(error)
is_peft = False
pass
# config.json and adapter_config.json must not both exist!
# Old transformers versions check
both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32
# New transformers need to check manually.
if SUPPORTS_LLAMA32:
# Check if folder exists locally
if os.path.isdir(model_name):
exist_adapter_config = os.path.exists(os.path.join(model_name, "adapter_config.json"))
exist_config = os.path.exists(os.path.join(model_name, "config.json"))
both_exist = exist_adapter_config and exist_config
else:
files = HfFileSystem(token = token).glob(os.path.join(model_name, "*.json"))
files = (os.path.split(x)[-1] for x in files)
if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2:
both_exist = True
pass
pass
pass
# Error out if both LoRA and normal model config exists.
if both_exist:
raise RuntimeError(
"Unsloth: Your repo has a LoRA adapter and a base model.\n"\
"You have 2 files `config.json` and `adapter_config.json`.\n"\
"We must only allow one config file.\n"\
"Please separate the LoRA and base models to 2 repos."
)
elif not is_model and not is_peft:
error = autoconfig_error or peft_error
# Old transformers version
if "rope_scaling" in error.lower() and not SUPPORTS_LLAMA31:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"\
f"This includes Llama 3.1. The minimum required version is 4.43.2\n"\
f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
raise RuntimeError(autoconfig_error or peft_error)
pass
# Get base model for PEFT:
if is_peft:
# Check base model again for PEFT
model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit)
model_config = AutoConfig.from_pretrained(
model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
pass
if not was_disabled: enable_progress_bars()
model_type = model_config.model_type
if model_type == "llama":
scaling_type = None
if getattr(model_config, "rope_scaling", None) is not None:
scaling_type1 = model_config.rope_scaling.get("type", None)
scaling_type2 = model_config.rope_scaling.get("rope_type", None)
scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
pass
if scaling_type == "llama3" and not SUPPORTS_LLAMA31:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Llama 3.1.\n"\
f"The minimum required version is 4.43.2\n"\
f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastLlamaModel
elif model_type == "mistral": dispatch_model = FastMistralModel
elif model_type == "gemma":
if not SUPPORTS_GEMMA:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
f"The minimum required version is 4.38.\n"\
f'Try `pip install --upgrade "transformers>=4.38"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemmaModel
elif model_type == "gemma2":
if not SUPPORTS_GEMMA2:
raise ImportError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
f"The minimum required version is 4.42.3.\n"\
f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
# Also check for softcapping support in flash-attn which is faster!
if is_bfloat16_supported() and not HAS_FLASH_ATTENTION:
print(
"Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!\n"\
"To install flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)
elif HAS_FLASH_ATTENTION and not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
"To update flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)
dispatch_model = FastGemma2Model
elif model_type == "qwen2":
dispatch_model = FastQwen2Model
elif model_type == "cohere":
dispatch_model = FastCohereModel
else:
raise NotImplementedError(
f"Unsloth: {model_name} not supported yet!\n"\
"Make an issue to https://github.com/unslothai/unsloth!",
)
pass
# Check if this is local model since the tokenizer gets overwritten
if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \
os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \
os.path.exists(os.path.join(old_model_name, "special_tokens_map.json")):
tokenizer_name = old_model_name
else:
tokenizer_name = None
pass
model, tokenizer = dispatch_model.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = dispatch_model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
revision = revision if not is_peft else None,
*args, **kwargs,
)
if resize_model_vocab is not None:
model.resize_token_embeddings(resize_model_vocab)
pass
# In case the model supports tagging, add the unsloth tag.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["unsloth",])
pass
if hasattr(tokenizer, "add_model_tags"):
tokenizer.add_model_tags(["unsloth",])
pass
if load_in_4bit:
# Fix up bitsandbytes config
quantization_config = \
{
# Sometimes torch_dtype is not a string!!
"bnb_4bit_compute_dtype" : model.config.to_dict()["torch_dtype"],
"bnb_4bit_quant_type" : "nf4",
"bnb_4bit_use_double_quant" : True,
"llm_int8_enable_fp32_cpu_offload" : False,
"llm_int8_has_fp16_weight" : False,
"llm_int8_skip_modules" : None,
"llm_int8_threshold" : 6.0,
"load_in_4bit" : True,
"load_in_8bit" : False,
"quant_method" : "bitsandbytes",
}
model.config.update({"quantization_config" : quantization_config})
pass
if is_peft:
# From https://github.com/huggingface/peft/issues/184
# Now add PEFT adapters
model.enable_input_require_grads()
model = PeftModel.from_pretrained(
model,
old_model_name,
token = token,
revision = revision,
is_trainable = True,
trust_remote_code = trust_remote_code,
)
# Patch it as well!
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
pass
return model, tokenizer
pass
pass
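# Illustrative end-to-end sketch (not part of the library) of the entry points defined
# above, assuming a CUDA machine and transformers >= 4.37:
#
#     model, tokenizer = FastLanguageModel.from_pretrained(
#         model_name = "unsloth/llama-3-8b-bnb-4bit",
#         max_seq_length = 2048,
#         load_in_4bit = True,
#     )
#     model = FastLanguageModel.get_peft_model(
#         model,
#         r = 16,
#         lora_alpha = 16,
#         use_gradient_checkpointing = "unsloth",
#     )
#     # ... train with transformers' Trainer ...
#     FastLanguageModel.for_inference(model) # left-pads the tokenizer and wraps generate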
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
"INT_TO_FLOAT_MAPPER",
"FLOAT_TO_INT_MAPPER",
]
__INT_TO_FLOAT_MAPPER = \
{
"unsloth/mistral-7b-bnb-4bit" : (
"unsloth/mistral-7b",
"mistralai/Mistral-7B-v0.1",
),
"unsloth/llama-2-7b-bnb-4bit" : (
"unsloth/llama-2-7b",
"meta-llama/Llama-2-7b-hf",
),
"unsloth/llama-2-13b-bnb-4bit" : (
"unsloth/llama-2-13b",
"meta-llama/Llama-2-13b-hf",
),
"unsloth/codellama-34b-bnb-4bit" : (
"codellama/CodeLlama-34b-hf",
),
"unsloth/zephyr-sft-bnb-4bit" : (
"unsloth/zephyr-sft",
"HuggingFaceH4/mistral-7b-sft-beta",
),
"unsloth/tinyllama-bnb-4bit" : (
"unsloth/tinyllama",
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
),
"unsloth/tinyllama-chat-bnb-4bit" : (
"unsloth/tinyllama-chat",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
),
"unsloth/mistral-7b-instruct-v0.1-bnb-4bit" : (
"unsloth/mistral-7b-instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.1",
),
"unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : (
"unsloth/mistral-7b-instruct-v0.2",
"mistralai/Mistral-7B-Instruct-v0.2",
),
"unsloth/llama-2-7b-chat-bnb-4bit" : (
"unsloth/llama-2-7b-chat",
"meta-llama/Llama-2-7b-chat-hf",
),
"unsloth/llama-2-7b-chat-bnb-4bit" : (
"unsloth/llama-2-7b-chat",
"meta-llama/Llama-2-7b-chat-hf",
),
"unsloth/codellama-7b-bnb-4bit" : (
"unsloth/codellama-7b",
"codellama/CodeLlama-7b-hf",
),
"unsloth/codellama-13b-bnb-4bit" : (
"codellama/CodeLlama-13b-hf",
),
"unsloth/yi-6b-bnb-4bit" : (
"unsloth/yi-6b",
"01-ai/Yi-6B",
),
"unsloth/solar-10.7b-bnb-4bit" : (
"upstage/SOLAR-10.7B-v1.0",
),
"unsloth/gemma-7b-bnb-4bit" : (
"unsloth/gemma-7b",
"google/gemma-7b",
),
"unsloth/gemma-2b-bnb-4bit" : (
"unsloth/gemma-2b",
"google/gemma-2b",
),
"unsloth/gemma-7b-it-bnb-4bit" : (
"unsloth/gemma-7b-it",
"google/gemma-7b-it",
),
"unsloth/gemma-2b-bnb-4bit" : (
"unsloth/gemma-2b-it",
"google/gemma-2b-it",
),
"unsloth/mistral-7b-v0.2-bnb-4bit" : (
"unsloth/mistral-7b-v0.2",
"alpindale/Mistral-7B-v0.2-hf",
),
"unsloth/gemma-1.1-2b-it-bnb-4bit" : (
"unsloth/gemma-1.1-2b-it",
"google/gemma-1.1-2b-it",
),
"unsloth/gemma-1.1-7b-it-bnb-4bit" : (
"unsloth/gemma-1.1-7b-it",
"google/gemma-1.1-7b-it",
),
"unsloth/Starling-LM-7B-beta-bnb-4bit" : (
"unsloth/Starling-LM-7B-beta",
"Nexusflow/Starling-LM-7B-beta",
),
"unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit" : (
"unsloth/Hermes-2-Pro-Mistral-7B",
"NousResearch/Hermes-2-Pro-Mistral-7B",
),
"unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit" : (
"unsloth/OpenHermes-2.5-Mistral-7B",
"teknium/OpenHermes-2.5-Mistral-7B",
),
"unsloth/codegemma-2b-bnb-4bit" : (
"unsloth/codegemma-2b",
"google/codegemma-2b",
),
"unsloth/codegemma-7b-bnb-4bit" : (
"unsloth/codegemma-7b",
"google/codegemma-7b",
),
"unsloth/codegemma-7b-it-bnb-4bit" : (
"unsloth/codegemma-7b-it",
"google/codegemma-7b-it",
),
"unsloth/llama-3-8b-bnb-4bit" : (
"unsloth/llama-3-8b",
"meta-llama/Meta-Llama-3-8B",
),
"unsloth/llama-3-8b-Instruct-bnb-4bit" : (
"unsloth/llama-3-8b-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
),
"unsloth/llama-3-70b-bnb-4bit" : (
"meta-llama/Meta-Llama-3-70B",
),
"unsloth/llama-3-70b-Instruct-bnb-4bit" : (
"meta-llama/Meta-Llama-3-70B-Instruct",
),
"unsloth/Phi-3-mini-4k-instruct-bnb-4bit" : (
"unsloth/Phi-3-mini-4k-instruct",
"microsoft/Phi-3-mini-4k-instruct",
),
"unsloth/mistral-7b-v0.3-bnb-4bit" : (
"unsloth/mistral-7b-v0.3",
"mistralai/Mistral-7B-v0.3",
),
"unsloth/mistral-7b-instruct-v0.3-bnb-4bit" : (
"unsloth/mistral-7b-instruct-v0.3",
"mistralai/Mistral-7B-Instruct-v0.3",
),
"unsloth/Phi-3-medium-4k-instruct-bnb-4bit" : (
"unsloth/Phi-3-medium-4k-instruct",
"microsoft/Phi-3-medium-4k-instruct",
),
"unsloth/Qwen2-0.5B-bnb-4bit" : (
"unsloth/Qwen2-0.5B",
"Qwen/Qwen2-0.5B",
),
"unsloth/Qwen2-0.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2-0.5B-Instruct",
"Qwen/Qwen2-0.5B-Instruct",
),
"unsloth/Qwen2-1.5B-bnb-4bit" : (
"unsloth/Qwen2-1.5B",
"Qwen/Qwen2-1.5B",
),
"unsloth/Qwen2-1.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2-1.5B-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
),
"unsloth/Qwen2-7B-bnb-4bit" : (
"unsloth/Qwen2-7B",
"Qwen/Qwen2-7B",
),
"unsloth/Qwen2-7B-Instruct-bnb-4bit" : (
"unsloth/Qwen2-7B-Instruct",
"Qwen/Qwen2-7B-Instruct",
),
"unsloth/Qwen2-70B-bnb-4bit" : (
"Qwen/Qwen2-70B",
),
"unsloth/Qwen2-70B-Instruct-bnb-4bit" : (
"Qwen/Qwen2-70B-Instruct",
),
"mistralai/Codestral-22B-v0.1" : (
"mistral-community/Codestral-22B-v0.1",
),
"unsloth/gemma-2-9b-bnb-4bit" : (
"unsloth/gemma-2-9b",
"google/gemma-2-9b",
),
"unsloth/gemma-2-27b-bnb-4bit" : (
"unsloth/gemma-2-27b",
"google/gemma-2-27b",
),
"unsloth/gemma-2-9b-it-bnb-4bit" : (
"unsloth/gemma-2-9b-it",
"google/gemma-2-9b-it",
),
"unsloth/gemma-2-27b-it-bnb-4bit" : (
"unsloth/gemma-2-27b-it",
"google/gemma-2-27b-it",
),
"unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit" : ( # Old Phi pre July
"unsloth/Phi-3-mini-4k-instruct-v0",
),
"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit" : ( # New 12b Mistral models
"unsloth/Mistral-Nemo-Instruct-2407",
"mistralai/Mistral-Nemo-Instruct-2407",
),
"unsloth/Mistral-Nemo-Base-2407-bnb-4bit" : ( # New 12b Mistral models
"unsloth/Mistral-Nemo-Base-2407",
"mistralai/Mistral-Nemo-Base-2407",
),
"unsloth/Meta-Llama-3.1-8B-bnb-4bit" : (
"unsloth/Meta-Llama-3.1-8B",
"meta-llama/Meta-Llama-3.1-8B",
),
"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" : (
"unsloth/Meta-Llama-3.1-8B-Instruct",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
),
"unsloth/Meta-Llama-3.1-70B-bnb-4bit" : (
"unsloth/Meta-Llama-3.1-70B",
"meta-llama/Meta-Llama-3.1-70B",
),
"unsloth/Meta-Llama-3.1-405B-bnb-4bit" : (
"meta-llama/Meta-Llama-3.1-405B",
),
"unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit" : (
"meta-llama/Meta-Llama-3.1-405B-Instruct",
),
"unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit" : (
"unsloth/Meta-Llama-3.1-70B-Instruct",
"meta-llama/Meta-Llama-3.1-70B-Instruct",
),
"unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : (
"mistralai/Mistral-Large-Instruct-2407",
),
"unsloth/gemma-2-2b-bnb-4bit" : (
"unsloth/gemma-2-2b",
"google/gemma-2-2b",
),
"unsloth/gemma-2-2b-it-bnb-4bit" : (
"unsloth/gemma-2-2b-it",
"google/gemma-2-2b-it",
),
"unsloth/Phi-3.5-mini-instruct-bnb-4bit" : (
"unsloth/Phi-3.5-mini-instruct",
"microsoft/Phi-3.5-mini-instruct",
),
"unsloth/c4ai-command-r-08-2024-bnb-4bit" : (
"CohereForAI/c4ai-command-r-08-2024",
),
"unsloth/c4ai-command-r-plus-08-2024-bnb-4bit" : (
"CohereForAI/c4ai-command-r-plus-08-2024",
),
"unsloth/Llama-3.1-Storm-8B-bnb-4bit" : (
"unsloth/Llama-3.1-Storm-8B",
"akjindal53244/Llama-3.1-Storm-8B",
),
"unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit" : (
"unsloth/Hermes-3-Llama-3.1-8B",
"NousResearch/Hermes-3-Llama-3.1-8B",
),
"unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit" : (
"unsloth/Hermes-3-Llama-3.1-70B",
"NousResearch/Hermes-3-Llama-3.1-70B",
),
"unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit" : (
"NousResearch/Hermes-3-Llama-3.1-405B",
),
"unsloth/SmolLM-135M-bnb-4bit" : (
"unsloth/SmolLM-135M",
"HuggingFaceTB/SmolLM-135M",
),
"unsloth/SmolLM-360M-bnb-4bit" : (
"unsloth/SmolLM-360M",
"HuggingFaceTB/SmolLM-360M",
),
"unsloth/SmolLM-1.7B-bnb-4bit" : (
"unsloth/SmolLM-1.7B",
"HuggingFaceTB/SmolLM-1.7B",
),
"unsloth/SmolLM-135M-Instruct-bnb-4bit" : (
"unsloth/SmolLM-135M-Instruct",
"HuggingFaceTB/SmolLM-135M-Instruct",
),
"unsloth/SmolLM-360M-Instruct-bnb-4bit" : (
"unsloth/SmolLM-360M-Instruct",
"HuggingFaceTB/SmolLM-360M-Instruct",
),
"unsloth/SmolLM-1.7B-Instruct-bnb-4bit" : (
"unsloth/SmolLM-1.7B-Instruct",
"HuggingFaceTB/SmolLM-1.7B-Instruct",
),
"unsloth/Mistral-Small-Instruct-2409-bnb-4bit" : (
"unsloth/Mistral-Small-Instruct-2409",
"mistralai/Mistral-Small-Instruct-2409",
),
"unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-0.5B-Instruct",
"Qwen/Qwen2.5-0.5B-Instruct",
),
"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-1.5B-Instruct",
"Qwen/Qwen2.5-1.5B-Instruct",
),
"unsloth/Qwen2.5-3B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-3B-Instruct",
"Qwen/Qwen2.5-3B-Instruct",
),
"unsloth/Qwen2.5-7B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-7B-Instruct",
),
"unsloth/Qwen2.5-14B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-14B-Instruct",
"Qwen/Qwen2.5-14B-Instruct",
),
"unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-32B-Instruct",
"Qwen/Qwen2.5-32B-Instruct",
),
"unsloth/Qwen2.5-72B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-72B-Instruct",
"Qwen/Qwen2.5-72B-Instruct",
),
"unsloth/Qwen2.5-0.5B-bnb-4bit" : (
"unsloth/Qwen2.5-0.5B",
"Qwen/Qwen2.5-0.5B",
),
"unsloth/Qwen2.5-1.5B-bnb-4bit" : (
"unsloth/Qwen2.5-1.5B",
"Qwen/Qwen2.5-1.5B",
),
"unsloth/Qwen2.5-3B-bnb-4bit" : (
"unsloth/Qwen2.5-3B",
"Qwen/Qwen2.5-3B",
),
"unsloth/Qwen2.5-7B-bnb-4bit" : (
"unsloth/Qwen2.5-7B",
"Qwen/Qwen2.5-7B",
),
"unsloth/Qwen2.5-14B-bnb-4bit" : (
"unsloth/Qwen2.5-14B",
"Qwen/Qwen2.5-14B",
),
"unsloth/Qwen2.5-32B-bnb-4bit" : (
"unsloth/Qwen2.5-32B",
"Qwen/Qwen2.5-32B",
),
"unsloth/Qwen2.5-72B-bnb-4bit" : (
"unsloth/Qwen2.5-72B",
"Qwen/Qwen2.5-72B",
),
"unsloth/Qwen2.5-Math-1.5B-bnb-4bit" : (
"unsloth/Qwen2.5-Math-1.5B",
"Qwen/Qwen2.5-Math-1.5B",
),
"unsloth/Qwen2.5-Math-7B-bnb-4bit" : (
"unsloth/Qwen2.5-Math-7B",
"Qwen/Qwen2.5-Math-7B",
),
"unsloth/Qwen2.5-Math-72B-bnb-4bit" : (
"unsloth/Qwen2.5-Math-72B",
"Qwen/Qwen2.5-Math-72B",
),
"unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Math-1.5B-Instruct",
"Qwen/Qwen2.5-Math-1.5B-Instruct",
),
"unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Math-7B-Instruct",
"Qwen/Qwen2.5-Math-7B-Instruct",
),
"unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Math-72B-Instruct",
"Qwen/Qwen2.5-Math-72B-Instruct",
),
"unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-1.5B",
"Qwen/Qwen2.5-Coder-1.5B",
),
"unsloth/Qwen2.5-Coder-7B-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-7B",
"Qwen/Qwen2.5-Coder-7B",
),
"unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-Instruct-1.5B",
"Qwen/Qwen2.5-Coder-Instruct-1.5B",
),
"unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-7B-Instruct",
"Qwen/Qwen2.5-Coder-7B-Instruct",
),
"unsloth/Llama-3.2-1B-bnb-4bit" : (
"unsloth/Llama-3.2-1B",
"meta-llama/Llama-3.2-1B",
),
"unsloth/Llama-3.2-3B-bnb-4bit" : (
"unsloth/Llama-3.2-3B",
"meta-llama/Llama-3.2-3B",
),
"unsloth/Llama-3.2-1B-Instruct-bnb-4bit" : (
"unsloth/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct",
),
"unsloth/Llama-3.2-3B-Instruct-bnb-4bit" : (
"unsloth/Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct",
),
"unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : (
"unsloth/Llama-3.1-Nemotron-70B-Instruct",
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
),
}
INT_TO_FLOAT_MAPPER = {}
FLOAT_TO_INT_MAPPER = {}
MAP_TO_UNSLOTH_16bit = {}
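# Built from __INT_TO_FLOAT_MAPPER above:
#   INT_TO_FLOAT_MAPPER  : 4bit repo name -> first 16bit repo name
#   FLOAT_TO_INT_MAPPER  : any 16bit repo name -> 4bit repo name
#   MAP_TO_UNSLOTH_16bit : official 16bit repo name -> Unsloth 16bit mirror
# Lowercased variants of every key / value are added as well.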
for key, values in __INT_TO_FLOAT_MAPPER.items():
INT_TO_FLOAT_MAPPER[key] = values[0]
for value in values:
FLOAT_TO_INT_MAPPER[value] = key
pass
# Map to Unsloth version for 16bit versions
if len(values) == 2:
if values[0].startswith("unsloth"):
MAP_TO_UNSLOTH_16bit[values[1]] = values[0]
MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0]
pass
pass
# Get lowercased
lowered_key = key.lower()
INT_TO_FLOAT_MAPPER[lowered_key] = values[0].lower()
for value in values:
FLOAT_TO_INT_MAPPER[value.lower()] = lowered_key
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
import os
from ._utils import __version__
from .llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
)
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralDecoderLayer,
MistralModel,
MistralForCausalLM,
)
# For Pytorch 2.1.1
try:
from transformers.models.mistral.modeling_mistral import (
MistralSdpaAttention,
MistralFlashAttention2,
)
except:
MistralSdpaAttention = MistralAttention
MistralFlashAttention2 = MistralAttention
pass
def MistralAttention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
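# Shapes after the transpose: Q is (bsz, n_heads, q_len, head_dim), while K and V are
# (bsz, n_kv_heads, q_len, head_dim) since Mistral uses grouped-query attention.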
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# Extend RoPE dynamically to fit in VRAM
self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
# Attention module
if (not HAS_FLASH_ATTENTION and attention_mask is None):
# Xformers memory efficient attention
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
K_M = V_M = bsz * kv_seq_len
Q_M = bsz * q_len
has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
# Group query attention
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
if hidden_states.requires_grad:
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
if has_swa:
Q = Q.view(1, Q_M, n_heads, head_dim)
K = K.view(1, K_M, n_heads, head_dim)
V = V.view(1, V_M, n_heads, head_dim)
pass
else:
# Xformers does support the forward pass though
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
if has_swa:
Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
pass
pass
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
sw = getattr(self.config, "sliding_window", None)
sw = kv_seq_len if (sw is None or sw == "null") else sw
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
else:
# Grouped query attention
# if n_groups != 1:
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
# pass
        # Must be contiguous or else results are wrong!
# https://github.com/pytorch/pytorch/issues/112577
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
        # is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass
attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
pass
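# A minimal sketch (illustrative only, not used by the fast path above) showing that the
# grouped query attention expand + reshape done in MistralAttention_fast_forward is
# equivalent to repeating each KV head n_groups times along the head dimension.
def _example_gqa_repeat_kv():
    bsz, n_kv_heads, n_groups, kv_seq_len, head_dim = 2, 4, 2, 3, 8
    K = torch.randn(bsz, n_kv_heads, kv_seq_len, head_dim)
    expanded = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
    expanded = expanded.reshape(bsz, n_kv_heads * n_groups, kv_seq_len, head_dim)
    reference = torch.repeat_interleave(K, n_groups, dim = 1)
    assert torch.equal(expanded, reference)
pass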
def MistralForCausalLM_fast_forward(
self,
input_ids: torch.LongTensor = None,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: Optional[int] = 0,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if causal_mask is None and past_key_values is None:
bsz, q_len = input_ids.shape
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is None or sliding_window == "null" or sliding_window <= 0:
causal_mask = xformers.attn_bias.LowerTriangularMask()
elif q_len <= sliding_window:
causal_mask = xformers.attn_bias.LowerTriangularMask()
else:
# Fix from https://github.com/Rypo
causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\
.from_seqlens([q_len]*bsz)\
.make_local_attention(window_size = sliding_window)
pass
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
if past_key_values is not None:
outputs = LlamaModel_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids = position_ids,
attention_mask = attention_mask,
)
else:
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pass
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
lm_head = self.lm_head.weight
if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
logits = logits.unsqueeze(0).unsqueeze(0)
elif num_logits_to_keep != 0:
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))
else:
logits = self.lm_head(hidden_states.to(lm_head.dtype))
pass
logits = logits.to(self.config.torch_dtype)
loss = None
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
pass
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
)
pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
pass
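# Minimal sketch (illustrative) of the mask selection above: a plain lower-triangular
# causal mask when the prompt fits inside the sliding window, otherwise a block-diagonal
# causal mask restricted to `sliding_window` tokens of local attention per query.
def _example_mistral_sliding_window_mask(bsz = 2, q_len = 8192, sliding_window = 4096):
    if sliding_window is None or q_len <= sliding_window:
        return xformers.attn_bias.LowerTriangularMask()
    return xformers.attn_bias.BlockDiagonalCausalMask\
        .from_seqlens([q_len] * bsz)\
        .make_local_attention(window_size = sliding_window)
pass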
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_attention(function):
function = function.replace(
"(self.head_dim * self.num_heads) != self.hidden_size",
"False",
)
function = function.replace(
"self.head_dim = self.hidden_size // self.num_heads",
"self.head_dim = config.head_dim",
)
function = function.replace(
"self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)",
"self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)",
)
return function
pass
class FastMistralModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "mistral",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = MistralAttention,
)
# Just for Mistral Nemo models!
if function is not None:
function = patch_mistral_nemo_attention(function)
# if True:#init_name is not None:
exec(function, globals())
MistralAttention.__init__ = eval(init_name)
pass
MistralAttention .forward = MistralAttention_fast_forward
MistralSdpaAttention .forward = MistralAttention_fast_forward
MistralFlashAttention2.forward = MistralAttention_fast_forward
MistralDecoderLayer .forward = LlamaDecoderLayer_fast_forward
MistralModel .forward = LlamaModel_fast_forward
MistralForCausalLM .forward = MistralForCausalLM_fast_forward
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(MistralForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.mistral.modeling_mistral
transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = LlamaRotaryEmbedding
return
pass
@staticmethod
def from_pretrained(
model_name = "unsloth/mistral-7b-bnb-4bit",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None, # Mistral does not support RoPE scaling
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastMistralModel,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)
pass
pass
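# Illustrative usage sketch: the model name is the default argument above, max_seq_length
# and load_in_4bit are example values, and the (model, tokenizer) return pair follows
# FastLlamaModel.from_pretrained, which this class delegates to.
def _example_load_fast_mistral():
    model, tokenizer = FastMistralModel.from_pretrained(
        model_name     = "unsloth/mistral-7b-bnb-4bit",
        max_seq_length = 4096,
        load_in_4bit   = True,
    )
    return model, tokenizer
pass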
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from .llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
)
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2Model,
Qwen2ForCausalLM,
)
# For Pytorch 2.1.1
try:
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2SdpaAttention,
Qwen2FlashAttention2,
)
except:
Qwen2SdpaAttention = Qwen2Attention
Qwen2FlashAttention2 = Qwen2Attention
pass
class FastQwen2Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "qwen2",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = Qwen2Attention,
)
if init_name is not None:
exec(function, globals())
Qwen2Attention.__init__ = eval(init_name)
pass
Qwen2Attention .forward = LlamaAttention_fast_forward
Qwen2SdpaAttention .forward = LlamaAttention_fast_forward
Qwen2FlashAttention2.forward = LlamaAttention_fast_forward
Qwen2DecoderLayer .forward = LlamaDecoderLayer_fast_forward
Qwen2Model .forward = LlamaModel_fast_forward
Qwen2ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(Qwen2ForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.qwen2.modeling_qwen2
transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = LlamaRotaryEmbedding
return
pass
@staticmethod
def from_pretrained(
model_name = "Qwen/Qwen2-7B",
max_seq_length = 4096,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None, # Qwen2 does not support RoPE scaling
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastQwen2Model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)
pass
pass
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from ..kernels import patch_layernorm, unpatch_layernorm
from ..kernels import patch_rms_layernorm, unpatch_rms_layernorm
from ..kernels import patch_llama_for_causal_lm, unpatch_llama_for_causal_lm
from ._utils import patch_gradient_checkpointing
from transformers import AutoProcessor
try:
from transformers import MllamaForConditionalGeneration
except:
raise ImportError(
"Unsloth: Please update your transformers version to 4.46.0 for Llama 3.2 support!"
)
pass
class FastVisionModel:
    @staticmethod
    def pre_patch():
        # Static so the @staticmethod from_pretrained below can call it without an instance
        patch_gradient_checkpointing()
        patch_layernorm()
        patch_rms_layernorm()
        patch_llama_for_causal_lm()
    pass
    @staticmethod
    def post_unpatch():
        unpatch_layernorm()
        unpatch_rms_layernorm()
        unpatch_llama_for_causal_lm()
    pass
@staticmethod
def from_pretrained(
model_name = "llava-hf/llava-1.5-7b-hf",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
trust_remote_code = False,
**kwargs,
):
if trust_remote_code:
print(
"Unsloth: WARNING `trust_remote_code` is True.\n"\
"Are you certain you want to do remote code execution?"
)
pass
if token is None: token = get_token()
SUPPORTS_BFLOAT16 = is_bfloat16_supported()
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
statistics = \
f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers = {transformers_version}.\n"\
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth'
print(statistics)
# Warn about fast transfers
old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0")
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1":
print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
pass
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
get_statistics() # For debugging - we use a download counter to see if environments are not breaking
if dtype is None:
dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
logger.warning_once("Device does not support bfloat16. Will change to float16.")
dtype = torch.float16
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
pre_check = check_nvidia()
bnb_config = None
if load_in_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit = True,
bnb_4bit_use_double_quant = True,
bnb_4bit_quant_type = "nf4",
bnb_4bit_compute_dtype = dtype,
)
pass
# Cannot be None, since HF now checks for the config
if load_in_4bit: kwargs["quantization_config"] = bnb_config
        # max_position_embeddings is not derived from the model config here yet,
        # so fall back to the user supplied max_seq_length (an assumption).
        max_position_embeddings = max_seq_length
        FastVisionModel.pre_patch()
model = MllamaForConditionalGeneration.from_pretrained(
model_name,
device_map = device_map,
torch_dtype = dtype,
# quantization_config = bnb_config,
token = token,
max_position_embeddings = max_position_embeddings,
trust_remote_code = trust_remote_code,
attn_implementation = "sdpa",
**kwargs,
)
        FastVisionModel.post_unpatch()
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
post_check = check_nvidia()
# Counteract saved tokenizers
tokenizer = AutoProcessor.from_pretrained(
model_name,
)
model = FastVisionModel.post_patch(model)
# Patch Trainer
from transformers.trainer import Trainer
try:
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
inner_training_loop = inspect.getsource(Trainer._inner_training_loop)
Trainer._original_training_loop = inner_training_loop
else:
inner_training_loop = Trainer._original_training_loop
except:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
pass
if ((post_check - pre_check) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
import transformers.trainer
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
# TODO: Support Deepspeed
if item.startswith(("deepspeed", "xm", "met", "smp")): continue
if item in inner_training_loop: good_items.append(item)
pass
exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
start = re.search('logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0]
end = inner_training_loop.find("\n\n", start)
original_debug = inner_training_loop[start:end]
spaces = re.search('\n([\s\t]{1,})', original_debug).group(0)[1:]
front_spaces = re.match('([\s\t]{1,})', inner_training_loop).group(0)
debug_info = """debug_info = \\
f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\
f" \\\\\\ /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\
f"O^O/ \\_/ \\ Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\
f"\\ / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\
f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}'
logger.warning(debug_info)
import subprocess, re, gc, numpy as np
a = np.array([0,])
try:
a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)
a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)
a = np.array([int(x.decode('utf-8'))/1024 for x in a])
except:
if not torch.cuda.is_available():
raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')
if ((a - PRE_CHECK) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()"""
debug_info = debug_info.split('\n')
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace(original_debug, debug_info)
debug_info = """n_total_devices = total_train_batch_size // \\
args.gradient_accumulation_steps // self._train_batch_size
if n_total_devices > 1:
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
debug_info ="""
debug_info = debug_info.split('\n')
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace("debug_info =", debug_info, 1)
front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0)
inner_training_loop = re.sub(r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE)
inner_training_loop = inner_training_loop.replace(
"train_dataloader = tpu_spmd_dataloader(train_dataloader)",
"raise RuntimeError('Unsloth: TPUs are not yet supported!')"
)
inner_training_loop = inner_training_loop.replace(
"self.accelerator.free_memory()",
"self.accelerator.free_memory()\n" + \
front_spaces + "if self.is_deepspeed_enabled:"\
"raise RuntimeError('Unsloth: Deepspeed is not yet supported!')\n", 1,
)
check_batches = """train_dataloader = self.get_train_dataloader()
ga = args.gradient_accumulation_steps
bsz = self._train_batch_size
total_batches = bsz * ga * args.world_size
n_total_devices = total_batches // ga // bsz
if n_total_devices > 1:
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!')
divisor = n_total_devices / 1
bsz = self._train_batch_size = max(int(bsz / divisor), 1)
if total_batches // ga // bsz > 1:
divisor = n_total_devices / 1
ga = args.gradient_accumulation_steps = max(int(ga / divisor), 1)"""
check_batches = check_batches.split('\n')
check_batches = "\n".join([check_batches[0]] + [front_spaces + x[8:] for x in check_batches[1:]])
inner_training_loop = inner_training_loop.replace(
"train_dataloader = self.get_train_dataloader()",
check_batches, 1,
)
inner_training_loop = inner_training_loop.replace(
"_inner_training_loop",
"_fast_inner_training_loop", 1,
)
exec(inner_training_loop, globals())
Trainer._inner_training_loop = _fast_inner_training_loop
inner_training_loop = inner_training_loop.replace(
"is_torch_tpu_available()",
"False",
)
if "n_total_devices >" not in inner_training_loop:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
pass
inner_training_loop = inner_training_loop.replace(
"is_sagemaker_mp_enabled()",
"False",
)
exec(inner_training_loop, globals())
Trainer._inner_training_loop = _fast_inner_training_loop
# Save max_seq_length
model.max_seq_length = max_position_embeddings
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_position_embeddings
internal_model = internal_model.model
pass
internal_model.max_seq_length = max_position_embeddings
# Fix up config for transformers uploading PEFT
# Not necessary anymore since we require transformers>=4.37!
if False:
name = model.config._name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
name = name[:len(name) - len("-bnb-4bit")]
model.config.update({"_name_or_path" : name})
pass
pass
# Log Unsloth version for future fastpaths for inference
model.config.update({"unsloth_version" : __version__})
# Add save modules
patch_saving_functions(model)
Trainer._inner_training_loop = _fast_inner_training_loop
# Also fix torch_dtype
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "config"):
if internal_model.config.torch_dtype == "float32":
internal_model.config.torch_dtype = torch.float32
elif internal_model.config.torch_dtype == "bfloat16":
internal_model.config.torch_dtype = torch.bfloat16
elif internal_model.config.torch_dtype == "float16":
internal_model.config.torch_dtype = torch.float16
pass
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "config"):
if internal_model.config.torch_dtype == "float32":
internal_model.config.torch_dtype = torch.float32
elif internal_model.config.torch_dtype == "bfloat16":
internal_model.config.torch_dtype = torch.bfloat16
elif internal_model.config.torch_dtype == "float16":
internal_model.config.torch_dtype = torch.float16
pass
pass
return model, tokenizer
pass
@staticmethod
def post_patch(model):
# Patch model
layers = model.model.layers
lm_head = model.get_output_embeddings().weight
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
        correct_dtype = lm_head.dtype # lm_head above is already the output embedding weight tensor
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
@staticmethod
def get_peft_model(
model,
r = 16,
target_modules = "all-linear",
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
layers_to_transform = None,
layers_pattern = None,
use_gradient_checkpointing = True,
random_state = 3407,
max_seq_length = 2048, # not used anymore
use_rslora = False,
modules_to_save = None,
init_lora_weights = True,
loftq_config = {},
temporary_location = "_unsloth_temporary_saved_buffers",
**kwargs,
):
transformers_set_seed(random_state)
# Get LoRA
arguments = dict(
r = r,
lora_alpha = lora_alpha,
target_modules = target_modules,
lora_dropout = lora_dropout,
bias = bias,
layers_to_transform = layers_to_transform,
init_lora_weights = init_lora_weights,
# loftq_config = loftq_config,
# use_rslora = use_rslora,
modules_to_save = modules_to_save,
**kwargs,
)
lora_config = LoraConfig(**arguments)
model = _get_peft_model(model, lora_config)
model = FastVisionModel.patch_peft_model(model, use_gradient_checkpointing)
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
return model
pass
@staticmethod
def patch_peft_model(
model,
use_gradient_checkpointing = True,
):
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing = use_gradient_checkpointing,
use_reentrant = True,
)
# Fix up config for transformers uploading PEFT
for active_adapter in model.peft_config.keys():
            # Not necessary since we require transformers >= 4.37
if False:
name = model.peft_config[active_adapter].base_model_name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
name = name[:len(name) - len("-bnb-4bit")]
model.peft_config[active_adapter].base_model_name_or_path = name
pass
# Add revision to enable future fast inference paths
            # [TODO] Bugs out! See https://github.com/unslothai/unsloth/issues/492
# model.peft_config[active_adapter].revision = f"unsloth"
pass
from transformers.trainer import Trainer
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
raise RuntimeError(
'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\
'enabling it will require much more work, so we have to prioritize. Please understand!\n'\
'We do have a separate beta version, which you can contact us about!\n'\
'Thank you for your understanding and we appreciate it immensely!'
)
pass
        # The per-module QKV / O / MLP counts are not computed for vision models here,
        # so only report the number of patched layers.
        logger.warning_once(
            f"Unsloth {__version__} patched {len(model.model.model.layers)} layers.",
        )
patch_saving_functions(model)
# Patch cross entropy loss labels
# Fixes https://github.com/unslothai/unsloth/issues/10
max_seq_length = model.max_seq_length
extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
model.model.extra_ignored_labels = extra_ignored_labels
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_seq_length
internal_model = internal_model.model
pass
internal_model.max_seq_length = max_seq_length
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
internal_model = internal_model.model
pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
pass
return model
pass
@staticmethod
def for_inference(model):
# if model.config.model_type == "qwen2":
# FastLlamaModel.for_training(model)
# return
# pass
internal_model = model
internal_model.gradient_checkpointing = False
internal_model.training = False
while hasattr(internal_model, "model"):
internal_model = internal_model.model
internal_model.gradient_checkpointing = False
internal_model.training = False
pass
if hasattr(internal_model, "training"):
internal_model.training = False
pass
# Also check if lm_head / embeddings are trained
internal_model = model
while not hasattr(internal_model, "lm_head"):
internal_model = internal_model.model
pass
lm_head = internal_model.lm_head.weight
device_type = lm_head.device.type
dtype = model.config.torch_dtype
if type(dtype) is str:
if dtype == "float16": dtype = torch.float16
elif dtype == "bfloat16": dtype = torch.bfloat16
pass
# Also disable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
embeddings = model.get_input_embeddings()
if hasattr(embeddings, "training"): embeddings.training = False
pass
if hasattr(model, "get_output_embeddings"):
embeddings = model.get_output_embeddings()
if hasattr(embeddings, "training"): embeddings.training = False
pass
return model
pass
@staticmethod
def for_training(model, use_gradient_checkpointing = True):
internal_model = model
internal_model.gradient_checkpointing = use_gradient_checkpointing
internal_model.training = True
# Delete all fast inference loras
for param in model.parameters():
if hasattr(param, "_fast_lora"):
del param._fast_lora
pass
while hasattr(internal_model, "model"):
internal_model = internal_model.model
internal_model.gradient_checkpointing = use_gradient_checkpointing
internal_model.training = True
pass
if hasattr(internal_model, "training"):
internal_model.training = True
pass
# Also re-enable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
embeddings = model.get_input_embeddings()
if hasattr(embeddings, "training"): embeddings.training = True
pass
if hasattr(model, "get_output_embeddings"):
embeddings = model.get_output_embeddings()
if hasattr(embeddings, "training"): embeddings.training = True
pass
return model
pass
pass
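# Illustrative sketch of the intended call order for FastVisionModel (the model name is
# the default above; r and lora_alpha are example values): load the base model, attach
# LoRA adapters, then toggle between training and inference modes.
def _example_fast_vision_workflow():
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name   = "llava-hf/llava-1.5-7b-hf",
        load_in_4bit = True,
    )
    model = FastVisionModel.get_peft_model(model, r = 16, lora_alpha = 16)
    model = FastVisionModel.for_training(model)
    # ... run training here ...
    model = FastVisionModel.for_inference(model)
    return model, tokenizer
pass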
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
from peft.tuners.lora import Linear as Peft_Linear
from typing import Optional, Callable, Union, List
import torch
import os
import shutil
import pickle
import gc
from transformers.models.llama.modeling_llama import logger
from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters_bias
import subprocess
import psutil
import re
from transformers.models.llama.modeling_llama import logger
from .tokenizer_utils import fix_sentencepiece_gguf
from huggingface_hub import HfApi
try:
from huggingface_hub.utils import get_token
except:
# Old HF Hub versions <= 0.0.25
from huggingface_hub.utils._token import get_token
pass
from pathlib import Path
__all__ = [
"print_quantization_methods",
"unsloth_save_model",
"save_to_gguf",
"patch_saving_functions",
"create_huggingface_repo",
]
# Check environments
keynames = "\n" + "\n".join(os.environ.keys())
IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames
IS_KAGGLE_ENVIRONMENT = "\nKAGGLE_" in keynames
del keynames
# Weights
LLAMA_WEIGHTS = (
"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj",
"mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
)
LLAMA_LAYERNORMS = (
"input_layernorm", "post_attention_layernorm",
"pre_feedforward_layernorm", "post_feedforward_layernorm",
)
# https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19
# From https://mlabonne.github.io/blog/posts/Quantize_Llama_2_models_using_ggml.html
ALLOWED_QUANTS = \
{
"not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
"fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
"quantized" : "Recommended. Slow conversion. Fast inference, small files.",
"f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
"bf16" : "Bfloat16 - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
"f16" : "Float16 - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
"q8_0" : "Fast conversion. High resource use, but generally acceptable.",
"q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
"q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
"q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
"q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
"q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
"q3_k_s" : "Uses Q3_K for all tensors",
"q4_0" : "Original quant method, 4-bit.",
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
"q4_k_s" : "Uses Q4_K for all tensors",
"q4_k" : "alias for q4_k_m",
"q5_k" : "alias for q5_k_m",
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
"q5_k_s" : "Uses Q5_K for all tensors",
"q6_k" : "Uses Q8_K for all tensors",
# "iq2_xxs" : "2.06 bpw quantization", # Not supported sadly
# "iq2_xs" : "2.31 bpw quantization",
# "iq3_xxs" : "3.06 bpw quantization",
"q3_k_xs" : "3-bit extra small quantization",
}
def print_quantization_methods():
for key, value in ALLOWED_QUANTS.items():
print(f'"{key}" ==> {value}')
pass
pass
def check_if_sentencepiece_model(model, temporary_location = "_unsloth_sentencepiece_temp"):
if not hasattr(model, "_saved_temp_tokenizer"): return False
temp_tokenizer = model._saved_temp_tokenizer
sentencepiece_model = False
file_location = os.path.join(temporary_location, temp_tokenizer.name_or_path)
created_folder = False
if not os.path.exists(file_location):
created_folder = True
os.makedirs(file_location)
pass
temp_tokenizer.save_pretrained(file_location)
if os.path.isfile(f"{file_location}/tokenizer.model"):
sentencepiece_model = True
pass
if created_folder:
shutil.rmtree(file_location, ignore_errors = True)
return sentencepiece_model
pass
def _free_cached_model(model):
from huggingface_hub import scan_cache_dir
cached_repos = list(scan_cache_dir().repos)
# Go through every cached repo, and delete the one that matches the model we want to save.
# Can save 4GB of disk space - useful for Kaggle systems.
for cached_repo in cached_repos:
if cached_repo.repo_id == model.config._name_or_path:
remove_cache_commit = list(cached_repo.revisions)[0].commit_hash
delete_strategy = scan_cache_dir().delete_revisions(remove_cache_commit,)
logger.warning_once(
"Unsloth: Will remove a cached repo with size " + \
delete_strategy.expected_freed_size_str,
)
delete_strategy.execute()
pass
pass
pass
def _merge_lora(layer, name):
bias = getattr(layer, "bias", None)
if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
# Is LoRA so we need to merge!
W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer)
if quant_state is not None:
dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
W = fast_dequantize(W, quant_state)
else:
dtype = W.dtype
W = W.to(torch.float32).t()
# W = W.t()
if A is not None:
# sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
# W += sAB
W.addmm_(A.t().to(torch.float32), B.t().to(torch.float32), alpha = s)
# W.addmm_(A.t().to(W.dtype), B.t().to(W.dtype), alpha = s)
# if not torch.isfinite(W).all():
maximum_element = torch.max(W.min().abs(), W.max())
if not torch.isfinite(maximum_element).item():
raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
pass
W = W.t().to(dtype)
else:
W = layer.weight
return W, bias
pass
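# A minimal numeric sketch (plain dense tensors, no quantization) of the identity used in
# _merge_lora above: merging LoRA gives W_merged = W + s * (B @ A), and calling
# W_t.addmm_(A.t(), B.t(), alpha = s) computes exactly that on the transposed weight.
def _example_lora_merge_identity():
    torch.manual_seed(0)
    out_features, in_features, rank, s = 8, 4, 2, 0.5
    W = torch.randn(out_features, in_features)
    A = torch.randn(rank, in_features)   # LoRA A has shape (rank, in_features)
    B = torch.randn(out_features, rank)  # LoRA B has shape (out_features, rank)
    W_t = W.t().contiguous()             # work on the transposed copy like _merge_lora does
    W_t.addmm_(A.t(), B.t(), alpha = s)  # W_t += s * (B @ A).t()
    assert torch.allclose(W_t.t(), W + s * (B @ A), atol = 1e-6)
pass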
def fast_save_pickle(shard, name):
# Use this if # CPUs is <= 2
print(f"Unsloth: Saving {name}...")
torch.save(
shard,
name,
# HIGHEST_PROTOCOL seems to not work with Pytorch!
# pickle_module = pickle,
# pickle_protocol = pickle.HIGHEST_PROTOCOL,
)
return
pass
@torch.inference_mode
def unsloth_save_model(
model,
tokenizer,
save_directory : Union[str, os.PathLike],
save_method : str = "lora", # ["lora", "merged_16bit", "merged_4bit"]
push_to_hub : bool = False,
token : Optional[Union[str, bool]] = None,
is_main_process : bool = True,
state_dict : Optional[dict] = None,
save_function : Callable = torch.save,
max_shard_size : Union[int, str] = "5GB",
safe_serialization : bool = True,
variant : Optional[str] = None,
save_peft_format : bool = True,
# Push to hub
use_temp_dir : Optional[bool] = None,
commit_message : Optional[str] = "Trained with Unsloth",
private : Optional[bool] = None,
create_pr : bool = False,
revision : str = None,
commit_description : str = "Upload model trained with Unsloth 2x faster",
tags : List[str] = None,
# Our functions
temporary_location : str = "_unsloth_temporary_saved_buffers",
maximum_memory_usage : float = 0.9,
):
if token is None: token = get_token()
if commit_message is None: commit_message = ""
if "Unsloth" not in commit_message:
commit_message += " (Trained with Unsloth)"
commit_message = commit_message.lstrip()
if commit_description is None:
commit_description = "Upload model trained with Unsloth 2x faster"
elif "Unsloth 2x faster" not in commit_description:
commit_description += " (Trained with Unsloth 2x faster)"
pass
if save_method == "merged_4bit":
raise RuntimeError(
"Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"\
"to merge to GGUF or others later on. I suggest you to do this as a final step\n"\
"if you're planning to do multiple saves.\n"\
"If you are certain, change `save_method` to `merged_4bit_forced`."
)
elif save_method == "merged_4bit_forced":
save_method = "merged_4bit"
pass
save_pretrained_settings = dict(locals())
for deletion in ("model", "tokenizer", "save_method", "temporary_location", "maximum_memory_usage"):
del save_pretrained_settings[deletion]
pass
# First check for a token!
if push_to_hub:
from huggingface_hub import whoami
try:
username = whoami(token = token)["name"]
except:
raise RuntimeError(
"Unsloth: Please supply a token!\n"\
"Go to https://huggingface.co/settings/tokens"
)
pass
pass
assert(maximum_memory_usage > 0 and maximum_memory_usage <= 0.95)
# Clean memory up first
for _ in range(3):
torch.cuda.empty_cache()
gc.collect()
pass
save_method = save_method.lower().replace(" ", "_")
if save_method != "lora" and save_method != "merged_16bit" and save_method != "merged_4bit":
raise RuntimeError(
"Unsloth: You must select one of 3 options when saving models:\n"\
'"lora" ==> This is the fastest and easiet. Just saves LoRA modules.\n'\
'"merged_16bit" ==> This merges LoRA weights and saves to float16. Needed for llama.cpp / GGUF.\n'\
'"merged_4bit" ==> This merges LoRA weights and saves to 4bit. Useful for DPO / inference.'
)
pass
if save_method == "merged_4bit":
print("Unsloth: Merging 4bit and LoRA weights to 4bit...")
print("This might take 5 minutes...")
# Counteract no LoRA adapters!
if hasattr(model, "merge_and_unload"):
model = model.merge_and_unload()
pass
print("Done.")
pass
if tags is not None:
assert(isinstance(tags, (list, tuple)))
tags = list(tags) + ["unsloth",]
else:
tags = ["unsloth",]
pass
save_pretrained_settings["tags"] = tags
if ((save_method == "lora") or (save_method == "merged_4bit")) and push_to_hub:
if token is None:
raise RuntimeError(
"Unsloth: Pushing to HF requires a token. Pass `token = 'hf_....'`\n"\
"Go to https://huggingface.co/settings/tokens."
)
pass
if save_method == "lora":
print("Unsloth: Saving LoRA adapters. Please wait...")
elif save_method == "merged_4bit":
print("Unsloth: Saving 4bit Bitsandbytes model. Please wait...")
pass
# Update model tag
_ = upload_to_huggingface(
model, save_directory, token,
"finetuned", "trl", file_location = None,
old_username = None, private = private,
)
getattr(model, "original_push_to_hub", tokenizer.push_to_hub)\
(
repo_id = save_directory,
use_temp_dir = use_temp_dir,
commit_message = commit_message,
private = private,
token = token,
max_shard_size = max_shard_size,
create_pr = create_pr,
safe_serialization = safe_serialization,
revision = revision,
commit_description = commit_description,
tags = tags,
)
if tokenizer is not None:
# Set padding side to left for inference
old_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"
getattr(tokenizer, "original_push_to_hub", tokenizer.push_to_hub)\
(
repo_id = save_directory,
use_temp_dir = use_temp_dir,
commit_message = commit_message,
private = private,
token = token,
max_shard_size = max_shard_size,
create_pr = create_pr,
safe_serialization = safe_serialization,
revision = revision,
commit_description = commit_description,
tags = tags,
)
# Revert back padding side
tokenizer.padding_side = old_padding_side
pass
if hasattr(model, "config"):
print(f"Saved {save_method} model to https://huggingface.co/" + save_directory)
pass
return save_directory, None
pass
# Tokenizer has different saving arguments
tokenizer_save_settings = \
{
"save_directory" : save_pretrained_settings["save_directory"],
"legacy_format" : None,
"filename_prefix" : None,
"push_to_hub" : save_pretrained_settings["push_to_hub"],
"private" : save_pretrained_settings["private"],
"token" : save_pretrained_settings["token"],
}
# Check if PEFT Model or not - if yes, 3 levels. If not 2 levels.
from peft import PeftModelForCausalLM
if isinstance(model, PeftModelForCausalLM):
internal_model = model.model
else:
internal_model = model
pass
# Cannot be converted properly!
if (save_method == "merged_4bit") or (save_method == "lora") or (
not hasattr(model, "model") or \
not hasattr(internal_model.model, "layers")
):
# Do general saving
# Edit save_pretrained_settings
# [TODO] _create_repo has errors due to **kwargs getting accepted
# commit_description does not seem to work?
what_to_delete = ("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",) \
if save_pretrained_settings["push_to_hub"] is False else \
("use_temp_dir", "create_pr", "revision", "tags", "commit_description",)
for deletion in what_to_delete:
del save_pretrained_settings[deletion]
pass
if hasattr(model, "add_model_tags"):
model.add_model_tags(["unsloth",])
# Update model tag
if push_to_hub:
_ = upload_to_huggingface(
model, save_pretrained_settings["save_directory"], token,
"finetuned", "trl", file_location = None,
old_username = None, private = private,
)
pass
if tokenizer is not None:
print("Unsloth: Saving tokenizer...", end = "")
# Set padding side to left for inference
old_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"
tokenizer.save_pretrained(**tokenizer_save_settings)
# Revert back padding side
tokenizer.padding_side = old_padding_side
print(" Done.")
else:
print()
print("Unsloth: Saving model...", end = "")
if save_method != "lora": print(" This might take 10 minutes for Llama-7b...", end = "")
# [TODO] Is this correct?
if save_method == "lora":
save_pretrained_settings["selected_adapters"] = None
pass
model.save_pretrained(**save_pretrained_settings)
if push_to_hub and hasattr(model, "config"):
print("Saved to https://huggingface.co/" + save_pretrained_settings["save_directory"])
pass
print(" Done.")
return save_directory, None
pass
# If push_to_hub, we must remove the .../ part of a repo
username = None
if push_to_hub and "/" in save_directory:
# +1 solves absolute path issues
username = save_directory[:save_directory.find("/")]
new_save_directory = save_directory[save_directory.find("/")+1:]
logger.warning_once(
f"Unsloth: You are pushing to hub, but you passed your HF username = {username}.\n"\
f"We shall truncate {save_directory} to {new_save_directory}"
)
save_pretrained_settings["save_directory"] = new_save_directory
tokenizer_save_settings ["save_directory"] = new_save_directory
save_directory = new_save_directory
pass
print("Unsloth: Merging 4bit and LoRA weights to 16bit...")
# Determine max RAM usage minus sharding
max_ram = psutil.virtual_memory().available
sharded_ram_usage = 5 * 1024 * 1024 * 1024
if type(max_shard_size) is str:
gb_found = re.match("([0-9]{1,})[\s]{0,}GB", max_shard_size, flags = re.IGNORECASE)
mb_found = re.match("([0-9]{1,})[\s]{0,}MB", max_shard_size, flags = re.IGNORECASE)
if gb_found: sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024
elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024
elif type(max_shard_size) is int:
        sharded_ram_usage = max_shard_size # An integer max_shard_size is already in bytes
pass
# Switch to our fast saving modules if it's a slow PC!
n_cpus = psutil.cpu_count(logical = False)
if n_cpus is None: n_cpus = psutil.cpu_count()
if n_cpus is None: n_cpus = 1
if safe_serialization is None:
safe_serialization = True
save_pretrained_settings["safe_serialization"] = safe_serialization
elif safe_serialization and (n_cpus <= 2):
logger.warning_once(
f"Unsloth: You have {n_cpus} CPUs. Using `safe_serialization` is 10x slower.\n"\
f"We shall switch to Pytorch saving, which will take 3 minutes and not 30 minutes.\n"\
f"To force `safe_serialization`, set it to `None` instead.",
)
safe_serialization = False
save_function = fast_save_pickle
save_pretrained_settings["safe_serialization"] = safe_serialization
save_pretrained_settings["save_function"] = save_function
pass
# Only safe_serialization uses more RAM
if safe_serialization:
max_ram -= sharded_ram_usage
else:
max_ram -= sharded_ram_usage*0.25 # Uses much less
pass
max_ram = int(max(0, max_ram) * maximum_memory_usage)
print(f"Unsloth: Will use up to "\
f"{round(max_ram/1024/1024/1024, 2)} out of "\
f"{round(psutil.virtual_memory().total/1024/1024/1024, 2)} RAM for saving.")
# Max directory for disk saving
if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
pass
# Check if Kaggle or Colab, since only 20GB of Disk space allowed.
if IS_KAGGLE_ENVIRONMENT or IS_COLAB_ENVIRONMENT:
# We free up 4GB of space
logger.warning_once(
"Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded\n"\
"model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab."
)
_free_cached_model(internal_model)
pass
# HF also uses a OrderedDict
from collections import OrderedDict
state_dict = OrderedDict()
torch_dtype = internal_model.config.torch_dtype
if type(torch_dtype) is str:
if torch_dtype == "float16": torch_dtype = torch.float16
elif torch_dtype == "bfloat16": torch_dtype = torch.bfloat16
pass
# Check modules to save float32 dtype
state_dict["model.embed_tokens.weight"] = internal_model.model.embed_tokens.weight.data.to(torch_dtype)
max_vram = int(torch.cuda.get_device_properties(0).total_memory * maximum_memory_usage)
from tqdm import tqdm as ProgressBar
for j, layer in enumerate(ProgressBar(internal_model.model.layers)):
for item in LLAMA_WEIGHTS:
proj = eval(f"layer.{item}")
name = f"model.layers.{j}.{item}.weight"
W, bias = _merge_lora(proj, name)
# Bias term
if bias is not None:
state_dict[f"model.layers.{j}.{item}.bias"] = bias
pass
if (torch.cuda.memory_allocated() + W.nbytes) < max_vram:
# Save to GPU memory
state_dict[name] = W
# [TODO] Saving to RAM seems to leak memory???
# elif (max_ram - W.nbytes) > 0:
# # Save to CPU memory
# logger.warning_once(f"We will save to RAM and not VRAM now.")
# state_dict[name] = W.to("cpu", non_blocking = True, copy = True)
# max_ram = max(max_ram - W.nbytes, 0)
else:
# Save to Disk
logger.warning_once("We will save to Disk and not RAM now.")
filename = os.path.join(temporary_location, f"{name}.pt")
torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
# weights_only = True weirdly fails?
state_dict[name] = torch.load(filename, map_location = "cpu", mmap = True, weights_only = False)
pass
for item in LLAMA_LAYERNORMS:
try:
# Skip for Gemma 2
state_dict[f"model.layers.{j}.{item}.weight"] = eval(f"layer.{item}.weight.data")
except:
continue
pass
pass
state_dict["model.norm.weight"] = internal_model.model.norm.weight.data
# Check for modules_to_save float32 dtype
# Check for tied weights
if internal_model.model.embed_tokens.weight.data_ptr() != internal_model.lm_head.weight.data_ptr():
state_dict["lm_head.weight"] = internal_model.lm_head.weight.data.to(torch_dtype)
pass
# All tensors MUST be type torch.Tensor and not torch.nn.parameter.Parameter
for key, value in state_dict.items():
if hasattr(value, "data"): state_dict[key] = value = value.data
if type(value) is not torch.Tensor:
logger.warning_once(f"Unsloth: {key} is not a Tensor but a {type(value)}.")
pass
pass
# Edit save_pretrained_settings
# [TODO] _create_repo has errors due to **kwargs getting accepted
save_pretrained_settings["state_dict"] = state_dict
# commit_description does not seem to work?
what_to_delete = ("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",) \
if not push_to_hub else \
("use_temp_dir", "create_pr", "revision", "tags", "commit_description",)
for deletion in what_to_delete:
del save_pretrained_settings[deletion]
pass
if hasattr(model, "add_model_tags"):
model.add_model_tags(["unsloth",])
# Update model tag
if push_to_hub:
_ = upload_to_huggingface(
model, save_pretrained_settings["save_directory"], token,
"finetuned", "trl", file_location = None,
old_username = username, private = private,
)
pass
# First check if we're pushing to an organization!
save_directory = save_pretrained_settings["save_directory"]
if save_pretrained_settings["push_to_hub"]:
new_save_directory, new_username = _determine_username(save_directory, username, token)
if token is not None:
from huggingface_hub import whoami
actual_username = whoami(token = token)["name"]
else:
actual_username = username
pass
# Check if pushing to an organization
if save_pretrained_settings["push_to_hub"] and (username != actual_username):
print(f"Unsloth: Saving to organization with address {new_save_directory}")
# We upload everything at the end!
tokenizer_save_settings["push_to_hub"] = False
tokenizer_save_settings["save_directory"] = new_save_directory
pass
# Save tokenizer
if tokenizer is not None:
print("Unsloth: Saving tokenizer...", end = "")
# Set padding side to left for inference
old_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"
tokenizer.save_pretrained(**tokenizer_save_settings)
# Revert back padding side
tokenizer.padding_side = old_padding_side
print(" Done.")
else:
print()
pass
print("Unsloth: Saving model... This might take 5 minutes for Llama-7b...")
# Since merged, edit quantization_config
old_config = model.config
new_config = model.config.to_dict()
if "quantization_config" in new_config:
del new_config["quantization_config"]
original_model = model
new_config = type(model.config).from_dict(new_config)
while hasattr(original_model, "model"):
original_model = original_model.model
original_model.config = new_config
model.config = new_config
# Save!
# [TODO] --> is this correct?
# save_pretrained_settings["selected_adapters"] = None
# Check if pushing to an organization
if save_pretrained_settings["push_to_hub"] and (username != actual_username):
print(f"Unsloth: Saving to organization with address {new_save_directory}")
# Pushing to organization!
# Sadly .save_pretrained doesn't work :(
# We first save it via .save_pretrained, then upload manually!
save_pretrained_settings["save_directory"] = new_save_directory
save_pretrained_settings["push_to_hub"] = False
internal_model.save_pretrained(**save_pretrained_settings)
# Now manually go through each file and upload them manually!
filenames = os.listdir(new_save_directory)
hf_api = HfApi(token = save_pretrained_settings["token"])
print("Unsloth: Uploading all files... Please wait...")
hf_api.upload_folder(
folder_path = new_save_directory,
path_in_repo = ".",
repo_id = new_save_directory,
repo_type = "model",
commit_message = "(Trained with Unsloth)",
ignore_patterns = "*.md",
)
else:
internal_model.save_pretrained(**save_pretrained_settings)
pass
# Revert config back
original_model = model
while hasattr(original_model, "model"):
original_model = original_model.model
original_model.config = old_config
model.config = old_config
print("Done.")
if push_to_hub and hasattr(model, "config"):
print(f"Saved merged model to https://huggingface.co/{username}/{save_directory.lstrip('/')}")
pass
save_pretrained_settings["state_dict"] = None
for j, (key, value) in enumerate(state_dict.items()):
state_dict[key] = None
if j % 10 == 0:
torch.cuda.empty_cache()
gc.collect()
pass
pass
state_dict = None
del state_dict
torch.cuda.empty_cache()
gc.collect()
# Remove temporary location
import shutil
shutil.rmtree(temporary_location, ignore_errors = True)
for _ in range(3):
torch.cuda.empty_cache()
gc.collect()
return save_directory, username
pass
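# Illustrative usage sketch (the output directory, repo id and token are placeholders):
# save LoRA adapters locally, or merge to 16bit and push to the Hub, using the
# save_method options documented in the error message above.
def _example_unsloth_save_model(model, tokenizer):
    # Local LoRA-only save - fastest and smallest files
    unsloth_save_model(model, tokenizer, "outputs/lora_model", save_method = "lora")
    # Merge LoRA into 16bit weights and push to the Hugging Face Hub
    unsloth_save_model(
        model, tokenizer, "your-username/your-model",
        save_method = "merged_16bit", push_to_hub = True, token = "hf_...",
    )
pass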
def install_llama_cpp_clone_non_blocking():
full_command = ["git", "clone", "--recursive", "https://github.com/ggerganov/llama.cpp"]
run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
return run_installer
pass
def install_llama_cpp_make_non_blocking():
# https://github.com/ggerganov/llama.cpp/issues/7062
# Weirdly GPU conversion for GGUF breaks??
# env = { **os.environ, "LLAMA_CUDA": "1", }
n_jobs = max(int(psutil.cpu_count()*1.5), 1)
# Force make clean
os.system("make clean -C llama.cpp")
full_command = ["make", "all", "-j"+str(n_jobs), "-C", "llama.cpp"]
# https://github.com/ggerganov/llama.cpp/issues/7062
# Weirdly GPU conversion for GGUF breaks??
# run_installer = subprocess.Popen(full_command, env = env, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
return run_installer
pass
def install_python_non_blocking(packages = []):
full_command = ["pip", "install"] + packages
run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
return run_installer
pass
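# Illustrative sketch: the *_non_blocking helpers above return subprocess.Popen handles,
# so cloning llama.cpp, building it, and installing Python dependencies can overlap
# other work, with .wait() called only when the results are actually needed.
def _example_install_llama_cpp_async():
    clone = install_llama_cpp_clone_non_blocking()
    deps  = install_python_non_blocking(["gguf", "protobuf"])
    clone.wait()   # the clone must finish before make can run
    build = install_llama_cpp_make_non_blocking()
    build.wait()   # returns the make exit code
    deps.wait()
pass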
def install_llama_cpp_old(version = -10):
# Download the 10th latest release since the latest might be broken!
# FALLBACK mechanism
releases = subprocess.check_output(["git", "ls-remote", "--tags", "https://github.com/ggerganov/llama.cpp.git"])
releases = releases.decode("utf-8").replace("\t", " ").split("\n")
for i, x in enumerate(releases):
if "refs/tags/b" not in x: break
releases = releases[:i]
latest = releases[-1]
version = releases[version].split(" ")[0]
# Check if the llama.cpp exists
if os.path.exists("llama.cpp"):
print(
"**[WARNING]** You have a llama.cpp old directory which is broken.\n"\
"Unsloth will DELETE the broken directory and install a new one.\n"\
"Press CTRL + C / cancel this if this is wrong. We shall wait 10 seconds.\n"
)
import time
for i in range(10):
print(f"**[WARNING]** Deleting llama.cpp directory... {10-i} seconds left.")
time.sleep(1)
import shutil
shutil.rmtree("llama.cpp", ignore_errors = True)
pass
# Clone a specific commit
# Also don't use the GPU!
commands = [
"git clone --recursive https://github.com/ggerganov/llama.cpp",
f"cd llama.cpp && git reset --hard {version} && git clean -df",
"make clean -C llama.cpp",
f"make all -j{psutil.cpu_count()*2} -C llama.cpp",
]
for command in commands:
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
for line in sp.stdout:
line = line.decode("utf-8", errors = "replace")
if "undefined reference" in line:
raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!")
print(line, flush = True, end = "")
pass
pass
# Check if successful
if not os.path.exists("llama.cpp/quantize") and not os.path.exists("llama.cpp/llama-quantize"):
raise RuntimeError(
"Unsloth: The file 'llama.cpp/llama-quantize' or `llama.cpp/quantize` does not exist.\n"\
"But we expect this file to exist! Maybe the llama.cpp developers changed the name?"
)
pass
pass
def install_llama_cpp_blocking(use_cuda = False):
# https://github.com/ggerganov/llama.cpp/issues/7062
# Weirdly GPU conversion for GGUF breaks??
# use_cuda = "LLAMA_CUDA=1" if use_cuda else ""
commands = [
"git clone --recursive https://github.com/ggerganov/llama.cpp",
"make clean -C llama.cpp",
# https://github.com/ggerganov/llama.cpp/issues/7062
# Weirdly GPU conversion for GGUF breaks??
# f"{use_cuda} make all -j{psutil.cpu_count()*2} -C llama.cpp",
f"make all -j{psutil.cpu_count()*2} -C llama.cpp",
"pip install gguf protobuf",
]
if os.path.exists("llama.cpp"): return
for command in commands:
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
for line in sp.stdout:
line = line.decode("utf-8", errors = "replace")
if "undefined reference" in line:
raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!")
print(line, flush = True, end = "")
pass
pass
pass
def get_executable(executables):
# Get system locations (System Path).split(system separator)
system_directories = os.environ.get("PATH").split(os.pathsep)
for directory in system_directories:
for executable in executables:
path = os.path.join(directory, executable)
# Check if the executable exists and is executable
if os.path.exists(path) and os.access(path, os.X_OK): return path
pass
pass
return None
pass
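# Illustrative usage: get_executable scans PATH in order and returns the first match,
# mirroring how save_to_gguf below checks for an already installed llama.cpp binary.
def _example_get_executable():
    quantize_location = get_executable(["llama-quantize", "quantize"])
    print("Found llama.cpp quantize at:", quantize_location)  # None if not installed
pass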
def save_to_gguf(
model_type : str,
model_dtype : str,
is_sentencepiece : bool = False,
model_directory : str = "unsloth_finetuned_model",
quantization_method = "fast_quantized", # Can be a list of options! ["q4_k_m", "q8_0", "q5_k_m"]
first_conversion : str = None,
_run_installer = None, # Non blocking install of llama.cpp
):
# logger.warning(
# "NOTICE: llama.cpp GGUF conversion is currently unstable, since llama.cpp is\n"\
# "undergoing some major bug fixes as at 5th of May 2024. This is not an Unsloth issue.\n"\
# "Please be patient - GGUF saving should still work, but might not work as well."
# )
assert(model_dtype == "float16" or model_dtype == "bfloat16")
model_dtype = "f16" if model_dtype == "float16" else "bf16"
# Convert quantization_method to list
if isinstance(quantization_method, list): pass
elif isinstance(quantization_method, str): quantization_method = [ quantization_method, ]
elif isinstance(quantization_method, tuple): quantization_method = list(quantization_method)
else:
raise TypeError("Unsloth: quantization_method can only be a string or a list of strings")
pass
# Check if bfloat16 is supported
if model_dtype == "bf16" and not torch.cuda.is_bf16_supported():
logger.warning(
"Unsloth: Cannot convert to bf16 GGUF since your computer doesn't support it.\n"\
"We shall switch instead to f16."
)
model_dtype = "f16"
pass
# Check first_conversion as well
if first_conversion is None:
first_conversion = model_dtype
pass
# Check I quants
for quant_method in quantization_method:
if quant_method.startswith("iq2"):
raise RuntimeError("Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!")
pass
# Careful: convert.py is only for Llama / Mistral based archs
use_fast_convert = False
if not is_sentencepiece: use_fast_convert = False # Llama-3
elif model_type == "llama": use_fast_convert = True
elif model_type == "mistral": use_fast_convert = True
pass
logger.warning_once(f"Unsloth: Converting {model_type} model. Can use fast conversion = {use_fast_convert}.")
# Map quant methods
new_quantization_method = []
for quant_method in quantization_method:
if quant_method == "not_quantized": quant_method = model_dtype
elif quant_method == "fast_quantized": quant_method = "q8_0"
elif quant_method == "quantized": quant_method = "q4_k_m"
elif quant_method is None: quant_method = "q8_0"
# Check if wrong method
if quant_method not in ALLOWED_QUANTS.keys():
error = f"Unsloth: Quant method = [{quant_method}] not supported. Choose from below:\n"
for key, value in ALLOWED_QUANTS.items():
error += f"[{key}] => {value}\n"
raise RuntimeError(error)
pass
new_quantization_method.append(quant_method)
pass
quantization_method = new_quantization_method
print_info = \
f"==((====))== Unsloth: Conversion from QLoRA to GGUF information\n"\
f" \\\ /| [0] Installing llama.cpp will take 3 minutes.\n"\
f"O^O/ \_/ \\ [1] Converting HF to GGUF 16bits will take 3 minutes.\n"\
f"\ / [2] Converting GGUF 16bits to {quantization_method} will take 10 minutes each.\n"\
f' "-____-" In total, you will have to wait at least 16 minutes.\n'
print(print_info)
# Check first_conversion format
if first_conversion == "f16" : pass
elif first_conversion == "bf16" : pass
elif first_conversion == "f32" : pass
elif first_conversion == "q8_0" : pass
else:
raise RuntimeError(
f"Unsloth: `first_conversion` can only be one of ['f16', 'bf16', 'f32', 'q8_0'] and not `{first_conversion}`."
)
pass
# Determine whether the system already has llama.cpp installed and the scripts are executable
quantize_location = get_executable(["llama-quantize", "quantize"])
convert_location = get_executable(["convert-hf-to-gguf.py", "convert_hf_to_gguf.py"])
if quantize_location is not None and convert_location is not None:
print("Unsloth: llama.cpp found in the system. We shall skip installation.")
else:
print("Unsloth: [0] Installing llama.cpp. This will take 3 minutes...")
if _run_installer is not None:
error = _run_installer.wait()
else:
error = 0
install_llama_cpp_blocking()
pass
# Check if successful. If not install 10th latest release
# Careful llama.cpp/quantize changed to llama.cpp/llama-quantize
# and llama.cpp/main changed to llama.cpp/llama-cli
# See https://github.com/ggerganov/llama.cpp/pull/7809
quantize_location = None
if os.path.exists("llama.cpp/quantize"):
quantize_location = "llama.cpp/quantize"
elif os.path.exists("llama.cpp/llama-quantize"):
quantize_location = "llama.cpp/llama-quantize"
else:
raise RuntimeError(
"Unsloth: The file 'llama.cpp/llama-quantize' or 'llama.cpp/quantize' does not exist.\n"\
"But we expect this file to exist! Maybe the llama.cpp developers changed the name?"
)
pass
# See https://github.com/unslothai/unsloth/pull/730
# Filenames changed again!
convert_location = None
if os.path.exists("llama.cpp/convert-hf-to-gguf.py"):
convert_location = "llama.cpp/convert-hf-to-gguf.py"
elif os.path.exists("llama.cpp/convert_hf_to_gguf.py"):
convert_location = "llama.cpp/convert_hf_to_gguf.py"
else:
raise RuntimeError(
"Unsloth: The file 'llama.cpp/convert-hf-to-gguf.py' or 'llama.cpp/convert_hf_to_gguf.py' does not exist.\n"\
"But we expect this file to exist! Maybe the llama.cpp developers changed the name?"
)
pass
if error != 0 or quantize_location is None or convert_location is None:
print(f"Unsloth: llama.cpp error code = {error}.")
install_llama_cpp_old(-10)
pass
pass
# Determine maximum first_conversion state
if first_conversion == "f32" : strength = 3
elif first_conversion == "f16" : strength = 2
elif first_conversion == "bf16" : strength = 1
elif first_conversion == "q8_0" : strength = 0
for quant_method in quantization_method:
if quant_method == "f32": strength = max(strength, 3)
elif quant_method == "f16": strength = max(strength, 2)
elif quant_method == "bf16": strength = max(strength, 1)
elif quant_method == "q8_0": strength = max(strength, 0)
else:
# Quantized models must have f16 as the default argument
if first_conversion == "f32" : pass
elif first_conversion == "f16" : pass
elif first_conversion == "bf16" : pass
elif first_conversion == "q8_0":
logger.warning_once(
"Unsloth: Using q8_0 for the `first_conversion` will lose a bit of accuracy, "\
"but saves disk space!"
)
# first_conversion = "f16"
pass
pass
pass
# If only q8_0:
if len(quantization_method) == 1 and quantization_method[0] == "q8_0":
strength = 0
pass
if strength >= 3: first_conversion = "f32"
elif strength >= 2: first_conversion = "f16"
elif strength >= 1: first_conversion = "bf16"
else: first_conversion = "q8_0"
# Non llama/mistral archs can only use f32 or f16
if not use_fast_convert and \
first_conversion not in ("f16", "bf16", "f32"):
pass
# Latest llama.cpp works for all models for q8_0!
# logger.warning_once("Unsloth: We must use f16 for non Llama and Mistral models.")
# first_conversion = "f16"
pass
# Check if bfloat16 is supported
if first_conversion == "bf16" and not torch.cuda.is_bf16_supported():
logger.warning(
"Unsloth: Cannot convert to bf16 GGUF since your computer doesn't support it.\n"\
"We shall switch instead to f16."
)
first_conversion = "f16"
pass
n_cpus = psutil.cpu_count()
if n_cpus is None: n_cpus = 1
n_cpus *= 2
# Concurrency from https://rentry.org/llama-cpp-conversions#merging-loras-into-a-model
final_location = str((Path(model_directory) / f"unsloth.{first_conversion.upper()}.gguf").absolute())
print(f"Unsloth: [1] Converting model at {model_directory} into {first_conversion} GGUF format.\n"\
f"The output location will be {final_location}\n"\
"This will take 3 minutes...")
# We first check if tokenizer.model exists in the model_directory
if os.path.exists(f"{model_directory}/tokenizer.model"):
vocab_type = "spm,hfft,bpe"
# Fix Sentencepiece model as well!
fix_sentencepiece_gguf(model_directory)
else:
vocab_type = "bpe"
pass
# convert.py is deprecated!
use_fast_convert = False
if use_fast_convert:
command = f"python llama.cpp/convert.py {model_directory} "\
f"--outfile {final_location} --vocab-type {vocab_type} "\
f"--outtype {first_conversion} --concurrency {n_cpus} --pad-vocab"
else:
command = f"python {convert_location} {model_directory} "\
f"--outfile {final_location} "\
f"--outtype {first_conversion}"
pass
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
for line in sp.stdout:
line = line.decode("utf-8", errors = "replace")
if "undefined reference" in line:
raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!")
print(line, flush = True, end = "")
if sp.returncode is not None and sp.returncode != 0:
raise subprocess.CalledProcessError(sp.returncode, sp.args)
pass
# Check if quantization succeeded!
if not os.path.isfile(final_location):
if IS_KAGGLE_ENVIRONMENT:
raise RuntimeError(
f"Unsloth: Quantization failed for {final_location}\n"\
"You are in a Kaggle environment, which might be the reason this is failing.\n"\
"Kaggle only provides 20GB of disk space. Merging to 16bit for 7b models use 16GB of space.\n"\
"This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\
"`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\
"I suggest you to save the 16bit model first, then use manual llama.cpp conversion."
)
else:
raise RuntimeError(
f"Unsloth: Quantization failed for {final_location}\n"\
"You might have to compile llama.cpp yourself, then run this again.\n"\
"You do not need to close this Python program. Run the following commands in a new terminal:\n"\
"You must run this in the same folder as you're saving your model.\n"\
"git clone --recursive https://github.com/ggerganov/llama.cpp\n"\
"cd llama.cpp && make clean && make all -j\n"\
"Once that's done, redo the quantization."
)
pass
pass
print(f"Unsloth: Conversion completed! Output location: {final_location}")
full_precision_location = final_location
all_saved_locations = [full_precision_location,]
# Convert each type!
for quant_method in quantization_method:
if quant_method != first_conversion:
print(f"Unsloth: [2] Converting GGUF 16bit into {quant_method}. This will take 20 minutes...")
final_location = str((Path(model_directory) / f"unsloth.{quant_method.upper()}.gguf").absolute())
command = f"./{quantize_location} {full_precision_location} "\
f"{final_location} {quant_method} {n_cpus}"
# quantize uses stderr
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
for line in sp.stdout:
line = line.decode("utf-8", errors = "replace")
if "undefined reference" in line:
raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!")
print(line, flush = True, end = "")
if sp.returncode is not None and sp.returncode != 0:
raise subprocess.CalledProcessError(sp.returncode, sp.args)
pass
# Check if quantization succeeded!
if not os.path.isfile(final_location):
if IS_KAGGLE_ENVIRONMENT:
raise RuntimeError(
f"Unsloth: Quantization failed for {final_location}\n"\
"You are in a Kaggle environment, which might be the reason this is failing.\n"\
"Kaggle only provides 20GB of disk space. Merging to 16bit for 7b models use 16GB of space.\n"\
"This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\
"`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\
"I suggest you to save the 16bit model first, then use manual llama.cpp conversion."
)
else:
raise RuntimeError(
"Unsloth: Quantization failed! You might have to compile llama.cpp yourself, then run this again.\n"\
"You do not need to close this Python program. Run the following commands in a new terminal:\n"\
"You must run this in the same folder as you're saving your model.\n"\
"git clone --recursive https://github.com/ggerganov/llama.cpp\n"\
"cd llama.cpp && make clean && make all -j\n"\
"Once that's done, redo the quantization."
)
pass
pass
print(f"Unsloth: Conversion completed! Output location: {final_location}")
all_saved_locations.append(final_location)
pass
pass
# Finally check if first_conversion (f16, bf16 etc) was in the list of actual quant methods
full_precision_seen = first_conversion in frozenset(quantization_method)
return all_saved_locations, full_precision_seen
pass
def unsloth_save_pretrained_merged(
self,
save_directory : Union[str, os.PathLike],
tokenizer = None,
save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
push_to_hub : bool = False,
token : Optional[Union[str, bool]] = None,
is_main_process : bool = True,
state_dict : Optional[dict] = None,
save_function : Callable = torch.save,
max_shard_size : Union[int, str] = "5GB",
safe_serialization : bool = True,
variant : Optional[str] = None,
save_peft_format : bool = True,
tags : List[str] = None,
temporary_location : str = "_unsloth_temporary_saved_buffers",
maximum_memory_usage : float = 0.75,
):
"""
Same as .save_pretrained(...) except 4bit weights are auto
converted to float16 with as little overhead as possible.
Choose `save_method` to be one of:
1. `merged_16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
2. `merged_4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
"""
if tokenizer is None:
logger.warning_once(
"Unsloth: You're not saving a tokenizer as well?\n"\
"You can do it separately via `tokenizer.save_pretrained(...)`"
)
pass
arguments = dict(locals())
arguments["model"] = self
del arguments["self"]
unsloth_save_model(**arguments)
for _ in range(3):
gc.collect()
pass
def unsloth_push_to_hub_merged(
self,
repo_id : str,
tokenizer = None,
save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
use_temp_dir : Optional[bool] = None,
commit_message : Optional[str] = "Trained with Unsloth",
private : Optional[bool] = None,
token : Union[bool, str, None] = None,
max_shard_size : Union[int, str, None] = "5GB",
create_pr : bool = False,
safe_serialization : bool = True,
revision : str = None,
commit_description : str = "Upload model trained with Unsloth 2x faster",
tags : Optional[List[str]] = None,
temporary_location : str = "_unsloth_temporary_saved_buffers",
maximum_memory_usage : float = 0.75,
):
"""
Same as .push_to_hub(...) except 4bit weights are auto
converted to float16 with as little overhead as possible.
Choose `save_method` to be one of:
1. `merged_16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
2. `merged_4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
"""
if tokenizer is None:
logger.warning_once(
"Unsloth: You're not saving a tokenizer as well?\n"\
"You can do it separately via `tokenizer.push_to_hub(...)`"
)
pass
arguments = dict(locals())
arguments["model"] = self
arguments["save_directory"] = repo_id
arguments["push_to_hub"] = True
del arguments["self"]
del arguments["repo_id"]
unsloth_save_model(**arguments)
for _ in range(3):
gc.collect()
pass
MODEL_CARD = \
"""---
base_model: {base_model}
tags:
- text-generation-inference
- transformers
- unsloth
- {model_type}
- {extra}
license: apache-2.0
language:
- en
---
# Uploaded {method} model
- **Developed by:** {username}
- **License:** apache-2.0
- **Finetuned from model :** {base_model}
This {model_type} model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
"""
def _determine_username(save_directory, old_username, token):
username = ""
save_directory = save_directory.lstrip("./")
if "/" not in save_directory:
from huggingface_hub import whoami
try:
username = whoami(token = token)["name"]
if type(old_username) is str and username != old_username:
username = old_username
pass
save_directory = f"{username}/{save_directory}"
except:
raise RuntimeError(f"Unsloth: {save_directory} is not a Huggingface directory.")
else:
username = save_directory.split("/")[0]
pass
return save_directory, username
pass
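# Illustrative behaviour (usernames are assumptions):
#   _determine_username("my-model", "", token)          -> ("your-username/my-model", "your-username")
#   _determine_username("someone/my-model", "", token)  -> ("someone/my-model", "someone")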
def create_huggingface_repo(
model,
save_directory,
token = None,
private = False,
):
if token is None :
token = get_token()
pass
save_directory, username = _determine_username(save_directory, "", token)
from huggingface_hub import create_repo
try:
create_repo(
repo_id = save_directory,
token = token,
repo_type = "model",
exist_ok = False,
private = private,
)
# Create model card
from huggingface_hub import ModelCard
content = MODEL_CARD.format(
username = username,
base_model = model.config._name_or_path,
model_type = model.config.model_type,
method = "",
extra = "unsloth",
)
card = ModelCard(content)
card.push_to_hub(save_directory, token = token)
except:
pass
hf_api = HfApi(token = token)
return save_directory, hf_api
pass
def upload_to_huggingface(
model,
save_directory,
token,
method,
extra = "",
file_location = None,
old_username = None,
private = None,
create_config = True,
):
save_directory, username = _determine_username(save_directory, old_username, token)
from huggingface_hub import create_repo
try:
create_repo(
repo_id = save_directory,
token = token,
repo_type = "model",
exist_ok = False,
private = private,
)
# Create model card
from huggingface_hub import ModelCard
content = MODEL_CARD.format(
username = username,
base_model = model.config._name_or_path,
model_type = model.config.model_type,
method = "",
extra = extra,
)
card = ModelCard(content)
card.push_to_hub(save_directory, token = token)
except:
pass
if file_location is not None:
# Now upload file
hf_api = HfApi(token = token)
if "/" in file_location:
uploaded_location = file_location[file_location.rfind("/")+1:]
else:
uploaded_location = file_location
pass
# find ftevent file from tensorboard and upload it
import glob
ftevent_files = glob.glob("*out.tfevents*", recursive = True)
if len(ftevent_files) > 0:
print("Unsloth: Uploading tensorboard files... Please wait...", file_location + "*out.tfevents*")
for ftevent_file in ftevent_files:
hf_api.upload_file(
path_or_fileobj = ftevent_file,
path_in_repo = ftevent_file.replace(file_location, ""),
repo_id = save_directory,
repo_type = "model",
commit_message = "(Trained with Unsloth)",
)
pass
pass
hf_api.upload_file(
path_or_fileobj = file_location,
path_in_repo = uploaded_location,
repo_id = save_directory,
repo_type = "model",
commit_message = "(Trained with Unsloth)",
)
# We also upload a config.json file
if create_config:
import json
with open("_temporary_unsloth_config.json", "w") as file:
json.dump({"model_type" : model.config.model_type}, file, indent = 4)
pass
hf_api.upload_file(
path_or_fileobj = "_temporary_unsloth_config.json",
path_in_repo = "config.json",
repo_id = save_directory,
repo_type = "model",
commit_message = "(Trained with Unsloth)",
)
os.remove("_temporary_unsloth_config.json")
pass
pass
return username
pass
def fix_tokenizer_bos_token(tokenizer):
# Check if BOS added already, then warn
fix_bos_token = False
chat_template = getattr(tokenizer, "chat_template", None)
if (tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None)):
if chat_template is not None and \
(
tokenizer.bos_token in chat_template or \
"{bos_token}" in chat_template.replace(" ", "") or \
"{bos_token+" in chat_template.replace(" ", "")
):
fix_bos_token = True
logger.warning(
"Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### Your chat template has a BOS token. We shall remove it temporarily."
)
# Remove {{bos_token}}
new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\}[\s]{0,}\}", "", chat_template)
# Remove {{bos_token +
new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\+[\s]{0,}", "", new_chat_template)
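# Illustrative effect of the two substitutions above (template snippets are assumptions):
#   "{{ bos_token }}{% for message in messages %}..."  -> "{% for message in messages %}..."
#   "{{ bos_token + '[INST] ' }}..."                   -> "'[INST] ' }}..."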
tokenizer.chat_template = new_chat_template
pass
pass
return fix_bos_token, chat_template
pass
def create_ollama_modelfile(tokenizer, gguf_location):
"""
Creates an Ollama Modelfile.
Use ollama.create(model = "new_ollama_model", modelfile = modelfile)
"""
modelfile = getattr(tokenizer, "_ollama_modelfile", None)
if modelfile is None: return None
FILE_LOCATION_REPLACER = "⚫@✅#🦥__FILE_LOCATION__⚡@🦥#⛵"
EOS_TOKEN_REPLACER = "⚫@✅#🦥__EOS_TOKEN__⚡@🦥#⛵"
LEFT_BRACKET_REPLACER = "⚫@✅#🦥"
RIGHT_BRACKET_REPLACER = "⚡@🦥#⛵"
# Fixes https://github.com/unslothai/unsloth/issues/1087
# We must convert all {'s and }'s but keep {__FILE_LOCATION__} intact
modelfile = modelfile\
.replace("{__FILE_LOCATION__}", FILE_LOCATION_REPLACER)\
.replace("{__EOS_TOKEN__}", EOS_TOKEN_REPLACER)\
.replace("{", LEFT_BRACKET_REPLACER)\
.replace("}", RIGHT_BRACKET_REPLACER)
# Revert {__FILE_LOCATION__} back
modelfile = modelfile\
.replace(FILE_LOCATION_REPLACER, "{__FILE_LOCATION__}")\
.replace(EOS_TOKEN_REPLACER, "{__EOS_TOKEN__}")
if "__EOS_TOKEN__" in modelfile:
modelfile = modelfile.format(
__FILE_LOCATION__ = gguf_location,
__EOS_TOKEN__ = tokenizer.eos_token,
)
else:
modelfile = modelfile.format(
__FILE_LOCATION__ = gguf_location,
)
pass
modelfile = modelfile\
.replace("⚫@✅#🦥", "{")\
.replace("⚡@🦥#⛵", "}")\
.rstrip()
return modelfile
pass
def unsloth_save_pretrained_gguf(
self,
save_directory : Union[str, os.PathLike],
tokenizer = None,
quantization_method : str = "fast_quantized",
first_conversion : str = None,
push_to_hub : bool = False,
token : Optional[Union[str, bool]] = None,
private : Optional[bool] = None,
is_main_process : bool = True,
state_dict : Optional[dict] = None,
save_function : Callable = torch.save,
max_shard_size : Union[int, str] = "5GB",
safe_serialization : bool = True,
variant : Optional[str] = None,
save_peft_format : bool = True,
tags : List[str] = None,
temporary_location : str = "_unsloth_temporary_saved_buffers",
maximum_memory_usage : float = 0.85,
):
"""
Same as .save_pretrained(...) except 4bit weights are auto
converted to float16 then converted to GGUF / llama.cpp format.
Choose `quantization_method` to be one of:
"not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
"fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
"quantized" : "Recommended. Slow conversion. Fast inference, small files.",
"f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
"f16" : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
"q8_0" : "Fast conversion. High resource use, but generally acceptable.",
"q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
"q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
"q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
"q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
"q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
"q3_k_s" : "Uses Q3_K for all tensors",
"q4_0" : "Original quant method, 4-bit.",
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
"q4_k_s" : "Uses Q4_K for all tensors",
"q4_k" : "alias for q4_k_m",
"q5_k" : "alias for q5_k_m",
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
"q5_k_s" : "Uses Q5_K for all tensors",
"q6_k" : "Uses Q8_K for all tensors",
"iq2_xxs" : "2.06 bpw quantization",
"iq2_xs" : "2.31 bpw quantization",
"iq3_xxs" : "3.06 bpw quantization",
"q3_k_xs" : "3-bit extra small quantization",
"""
if tokenizer is None:
raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
arguments = dict(locals())
arguments["model"] = self
arguments["tokenizer"] = tokenizer
arguments["push_to_hub"] = False # We save ourselves
arguments["save_method"] = "merged_16bit" # Must be 16bit
del arguments["self"]
del arguments["quantization_method"]
del arguments["first_conversion"]
# Fix tokenizer adding an extra BOS token at the front
fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer)
# Non blocking install GGUF first
if not os.path.exists("llama.cpp"):
if IS_KAGGLE_ENVIRONMENT:
# Kaggle is weird - no blocking installs, and no CUDA?
python_install = install_python_non_blocking(["gguf", "protobuf"])
python_install.wait()
install_llama_cpp_blocking(use_cuda = False)
new_save_directory, old_username = unsloth_save_model(**arguments)
makefile = None
else:
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory, old_username = unsloth_save_model(**arguments)
python_install.wait()
pass
else:
try:
new_save_directory, old_username = unsloth_save_model(**arguments)
makefile = None
except:
# Retry by recloning llama.cpp
if IS_KAGGLE_ENVIRONMENT:
# Kaggle is weird - no blocking installs, and no CUDA?
python_install = install_python_non_blocking(["gguf", "protobuf"])
python_install.wait()
install_llama_cpp_blocking(use_cuda = False)
new_save_directory, old_username = unsloth_save_model(**arguments)
makefile = None
else:
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory, old_username = unsloth_save_model(**arguments)
python_install.wait()
pass
pass
pass
# Use old chat template if the bos is removed
if fix_bos_token:
tokenizer.chat_template = old_chat_template
pass
for _ in range(3):
gc.collect()
model_dtype = self.config.torch_dtype
model_type = self.config.model_type
if type(model_dtype) is str:
assert(model_dtype == "float16" or model_dtype == "bfloat16")
elif model_dtype == torch.float16:
model_dtype = "float16"
elif model_dtype == torch.bfloat16:
model_dtype = "bfloat16"
else:
raise TypeError("Unsloth: Model dtype can only be float16 or bfloat16")
pass
is_sentencepiece_model = check_if_sentencepiece_model(self)
# Save to GGUF
all_file_locations, want_full_precision = save_to_gguf(
model_type, model_dtype, is_sentencepiece_model,
new_save_directory, quantization_method, first_conversion, makefile,
)
# Save Ollama modelfile
modelfile = create_ollama_modelfile(tokenizer, all_file_locations[0])
modelfile_location = None
if modelfile is not None:
modelfile_location = os.path.join(new_save_directory, "Modelfile")
with open(modelfile_location, "w") as file:
file.write(modelfile)
pass
print(f"Unsloth: Saved Ollama Modelfile to {modelfile_location}")
pass
if fix_bos_token:
logger.warning(
"Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### We removed it in GGUF's chat template for you."
)
pass
if push_to_hub:
print("Unsloth: Uploading GGUF to Huggingface Hub...")
# If not needing full precision, skip the first
if not want_full_precision: all_file_locations = all_file_locations[1:]
for file_location in all_file_locations:
username = upload_to_huggingface(
self, save_directory, token,
"GGUF converted", "gguf", file_location, old_username, private,
)
link = f"{username}/{new_save_directory.lstrip('/.')}" \
if username not in new_save_directory else \
new_save_directory.lstrip('/.')
print(f"Saved GGUF to https://huggingface.co/{link}")
pass
# Save modelfile
if modelfile_location is not None:
username = upload_to_huggingface(
self, save_directory, token,
"GGUF converted", "gguf", modelfile_location, old_username, private,
)
print(f"Saved Ollama Modelfile to https://huggingface.co/{link}")
pass
pass
pass
def unsloth_push_to_hub_gguf(
self,
repo_id : str,
tokenizer = None,
quantization_method : str = "fast_quantized",
first_conversion : str = None,
use_temp_dir : Optional[bool] = None,
commit_message : Optional[str] = "Trained with Unsloth",
private : Optional[bool] = None,
token : Union[bool, str, None] = None,
max_shard_size : Union[int, str, None] = "5GB",
create_pr : bool = False,
safe_serialization : bool = True,
revision : str = None,
commit_description : str = "Upload model trained with Unsloth 2x faster",
tags : Optional[List[str]] = None,
temporary_location : str = "_unsloth_temporary_saved_buffers",
maximum_memory_usage : float = 0.85,
):
"""
Same as .push_to_hub(...) except 4bit weights are auto
converted to float16 then converted to GGUF / llama.cpp format.
Choose `quantization_method` to be one of:
"not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
"fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
"quantized" : "Recommended. Slow conversion. Fast inference, small files.",
"f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
"f16" : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
"q8_0" : "Fast conversion. High resource use, but generally acceptable.",
"q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
"q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
"q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
"q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
"q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
"q3_k_s" : "Uses Q3_K for all tensors",
"q4_0" : "Original quant method, 4-bit.",
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
"q4_k_s" : "Uses Q4_K for all tensors",
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
"q5_k_s" : "Uses Q5_K for all tensors",
"q6_k" : "Uses Q8_K for all tensors",
"""
if tokenizer is None:
raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
arguments = dict(locals())
arguments["model"] = self
arguments["tokenizer"] = tokenizer
arguments["save_directory"] = repo_id
arguments["push_to_hub"] = False # We save ourselves
arguments["save_method"] = "merged_16bit" # Must be 16bit
del arguments["self"]
del arguments["repo_id"]
del arguments["quantization_method"]
del arguments["first_conversion"]
# Fix tokenizer adding an extra BOS token at the front
fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer)
# Non blocking install GGUF first
if not os.path.exists("llama.cpp"):
if IS_KAGGLE_ENVIRONMENT:
# Kaggle is weird - no blocking installs, and no CUDA?
python_install = install_python_non_blocking(["gguf", "protobuf"])
python_install.wait()
install_llama_cpp_blocking(use_cuda = False)
new_save_directory, old_username = unsloth_save_model(**arguments)
makefile = None
else:
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory, old_username = unsloth_save_model(**arguments)
python_install.wait()
pass
else:
try:
new_save_directory, old_username = unsloth_save_model(**arguments)
makefile = None
except:
# Retry by recloning llama.cpp
if IS_KAGGLE_ENVIRONMENT:
# Kaggle is weird - no blocking installs, and no CUDA?
python_install = install_python_non_blocking(["gguf", "protobuf"])
python_install.wait()
install_llama_cpp_blocking(use_cuda = False)
new_save_directory, old_username = unsloth_save_model(**arguments)
makefile = None
else:
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory, old_username = unsloth_save_model(**arguments)
python_install.wait()
pass
pass
pass
# Use old chat template if the bos is removed
if fix_bos_token:
tokenizer.chat_template = old_chat_template
pass
for _ in range(3):
gc.collect()
model_dtype = self.config.torch_dtype
model_type = self.config.model_type
if type(model_dtype) is str:
assert(model_dtype == "float16" or model_dtype == "bfloat16")
elif model_dtype == torch.float16:
model_dtype = "float16"
elif model_dtype == torch.bfloat16:
model_dtype = "bfloat16"
else:
raise TypeError("Unsloth: Model dtype can only be float16 or bfloat16")
pass
is_sentencepiece_model = check_if_sentencepiece_model(self)
# Save to GGUF
all_file_locations, want_full_precision = save_to_gguf(
model_type, model_dtype, is_sentencepiece_model,
new_save_directory, quantization_method, first_conversion, makefile,
)
# Save Ollama modelfile
modelfile = create_ollama_modelfile(tokenizer, all_file_locations[0])
modelfile_location = None
if modelfile is not None:
modelfile_location = os.path.join(new_save_directory, "Modelfile")
with open(modelfile_location, "w") as file:
file.write(modelfile)
pass
print(f"Unsloth: Saved Ollama Modelfile to {modelfile_location}")
pass
# If not needing full precision, skip the first
if not want_full_precision: all_file_locations = all_file_locations[1:]
for file_location in all_file_locations:
print("Unsloth: Uploading GGUF to Huggingface Hub...")
username = upload_to_huggingface(
self, repo_id, token,
"GGUF converted", "gguf", file_location, old_username, private,
)
link = f"{username}/{new_save_directory.lstrip('/.')}" \
if username not in new_save_directory else \
new_save_directory.lstrip('/.')
print(f"Saved GGUF to https://huggingface.co/{link}")
pass
# Save modelfile
if modelfile_location is not None:
username = upload_to_huggingface(
self, repo_id, token,
"GGUF converted", "gguf", modelfile_location, old_username, private,
)
print(f"Saved Ollama Modelfile to https://huggingface.co/{link}")
pass
if fix_bos_token:
logger.warning(
"Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### We removed it in GGUF's chat template for you."
)
pass
pass
# Helper function to save LoRA adapters to a custom directory
def save_lora_to_custom_dir(model, tokenizer, save_directory):
# Create the custom directory if it doesn't exist
os.makedirs(save_directory, exist_ok=True)
# Call the unsloth_save_model function with the custom directory
unsloth_save_model(
model,
tokenizer,
save_directory=save_directory,
save_method="lora",
push_to_hub=False,
)
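# Illustrative usage (the directory name is an assumption):
#   save_lora_to_custom_dir(model, tokenizer, "lora_adapters")
# which saves only the LoRA adapters (save_method = "lora") without merging any weights.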
# Method added to the model class to convert LoRA to GGML and push to the Hugging Face Hub
def unsloth_convert_lora_to_ggml_and_push_to_hub(
self,
tokenizer,
repo_id: str,
use_temp_dir: Optional[bool] = None,
commit_message: Optional[str] = "Converted LoRA to GGML with Unsloth",
private: Optional[bool] = None,
token: Union[bool, str, None] = None,
create_pr: bool = False,
revision: str = None,
commit_description: str = "Convert LoRA to GGML format using Unsloth",
temporary_location: str = "_unsloth_temporary_saved_buffers",
maximum_memory_usage: float = 0.85,
):
if not os.path.exists("llama.cpp"):
if IS_KAGGLE_ENVIRONMENT:
python_install = install_python_non_blocking(["protobuf"])
python_install.wait()
install_llama_cpp_blocking(use_cuda=False)
makefile = None
else:
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
python_install.wait()
else:
makefile = None
for _ in range(3):
gc.collect()
lora_directory_push = "lora-to-ggml-push"
save_lora_to_custom_dir(self, tokenizer, lora_directory_push)
model_type = self.config.model_type
output_file = os.path.join(lora_directory_push, "ggml-adapter-model.bin")
print(f"Unsloth: Converting auto-saved LoRA adapters at {lora_directory_push} to GGML format.")
print(f"The output file will be {output_file}")
command = f"python3 llama.cpp/convert-lora-to-ggml.py {lora_directory_push} {output_file} llama"
try:
with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp:
for line in sp.stdout:
print(line, end="", flush=True)
for line in sp.stderr:
print(line, end="", flush=True)
sp.wait()
if sp.returncode != 0:
raise subprocess.CalledProcessError(sp.returncode, command)
except subprocess.CalledProcessError as e:
print(f"Error: Conversion failed with return code {e.returncode}")
return
print(f"Unsloth: Conversion completed! Output file: {output_file}")
print("Unsloth: Uploading GGML file to Hugging Face Hub...")
username = upload_to_huggingface(
self, repo_id, token,
"GGML converted LoRA", "ggml", output_file, None, private,
)
link = f"{repo_id.lstrip('/')}"
print("Unsloth: Done.")
print(f"Converted LoRA to GGML and uploaded to https://huggingface.co/{link}")
print("\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!")
def unsloth_convert_lora_to_ggml_and_save_locally(
self,
save_directory: str, # Added parameter for the folder name
tokenizer,
temporary_location: str = "_unsloth_temporary_saved_buffers",
maximum_memory_usage: float = 0.85,
):
if not os.path.exists("llama.cpp"):
if IS_KAGGLE_ENVIRONMENT:
python_install = install_python_non_blocking(["protobuf"])
python_install.wait()
install_llama_cpp_blocking(use_cuda=False)
makefile = None
else:
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
python_install.wait()
else:
makefile = None
for _ in range(3):
gc.collect()
# Use the provided save_directory for local saving
save_lora_to_custom_dir(self, tokenizer, save_directory)
model_type = self.config.model_type
output_file = os.path.join(save_directory, "ggml-adapter-model.bin")
print(f"Unsloth: Converting auto-saved LoRA adapters at {save_directory} to GGML format.")
print(f"The output file will be {output_file}")
command = f"python3 llama.cpp/convert-lora-to-ggml.py {save_directory} {output_file} llama"
try:
with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp:
for line in sp.stdout:
print(line, end="", flush=True)
for line in sp.stderr:
print(line, end="", flush=True)
sp.wait()
if sp.returncode != 0:
raise subprocess.CalledProcessError(sp.returncode, command)
except subprocess.CalledProcessError as e:
print(f"Error: Conversion failed with return code {e.returncode}")
return
print("Unsloth: Done.")
print(f"Unsloth: Conversion completed! Output file: {output_file}")
print("\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!")
def patch_saving_functions(model):
import inspect
import types
from typing import Callable, Optional, Union, List
# And now re add our saving methods!
if model.push_to_hub.__name__ == "unsloth_push_to_hub":
original_push_to_hub = model.original_push_to_hub
else:
original_push_to_hub = model.push_to_hub
pass
signature = str(inspect.signature(original_push_to_hub)).replace("NoneType", "None")
signature = signature[1:]
signature = re.sub("<function save at .+?>", "torch.save", signature)
docs = original_push_to_hub.__doc__.encode("utf-8").decode("utf-8")
push_to_hub_text = f'''def unsloth_push_to_hub(self, {signature}:
"""
{docs}
"""
arguments = dict(locals())
del arguments["self"]
if "tags" in arguments and arguments["tags"] is not None:
assert(isinstance(arguments["tags"], (list, tuple)))
arguments["tags"] = list(arguments["tags"]) + ["unsloth",]
elif "tags" in arguments:
arguments["tags"] = ["unsloth",]
elif hasattr(self, "add_model_tags"):
self.add_model_tags(["unsloth",])
if "commit_message" in arguments:
commit_message = arguments["commit_message"]
if commit_message is not None:
if not commit_message.endswith(" "): commit_message += " "
if "Unsloth" not in commit_message:
commit_message += "(Trained with Unsloth)"
else:
commit_message = "Upload model trained with Unsloth"
arguments["commit_message"] = commit_message
if "commit_description" in arguments:
commit_description = arguments["commit_description"]
if commit_description is not None:
if not commit_description.endswith(" "): commit_description += " "
if "Unsloth" not in commit_description:
commit_description += "(Trained with Unsloth 2x faster)"
else:
commit_description = "Upload model trained with Unsloth 2x faster"
arguments["commit_description"] = commit_description
# Update model tag
if hasattr(self, "config"):
_ = upload_to_huggingface(
self, arguments["repo_id"], arguments["token"],
"finetuned", "trl", file_location = None,
old_username = None, private = arguments["private"],
)
pass
try:
self.original_push_to_hub(**arguments)
except:
del arguments["tags"]
self.original_push_to_hub(**arguments)
pass
if hasattr(self, "config"):
print("Saved model to https://huggingface.co/" + arguments["repo_id"])
pass
'''
exec(push_to_hub_text, globals())
original_model = model
while True:
if original_model.push_to_hub.__name__ != "unsloth_push_to_hub":
original_model.original_push_to_hub = original_model.push_to_hub
original_model.push_to_hub = types.MethodType(unsloth_push_to_hub, original_model)
if hasattr(original_model, "add_model_tags"):
original_model.add_model_tags(["unsloth",])
pass
pass
if hasattr(original_model, "model"): original_model = original_model.model
else: break
pass
# Add saving methods to top level model
if hasattr(model, "config"):
# Counteract tokenizers
model.push_to_hub_merged = types.MethodType(unsloth_push_to_hub_merged, model)
model.save_pretrained_merged = types.MethodType(unsloth_save_pretrained_merged, model)
model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)
model.save_pretrained_gguf = types.MethodType(unsloth_save_pretrained_gguf, model)
model.push_to_hub_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_push_to_hub, model)
model.save_pretrained_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_save_locally, model)
pass
return model
pass
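# Illustrative usage (assumed to be called after the model is built / loaded):
#   model = patch_saving_functions(model)
# after which model.save_pretrained_gguf, model.push_to_hub_merged, model.push_to_hub_gguf, etc. are attached.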
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers import PreTrainedTokenizerFast
import re
import os
from transformers.models.llama.modeling_llama import logger
from peft import PeftModelForCausalLM
import torch
import itertools
import collections
import numpy as np
import gc
import subprocess
from unsloth_zoo.tokenizer_utils import (
mean_of_trained_tokens,
add_new_tokens,
fix_untrained_tokens,
)
from unsloth_zoo.training_utils import (
fix_zero_training_loss,
)
__all__ = [
"load_correct_tokenizer",
"fix_sentencepiece_tokenizer",
"check_tokenizer",
"add_new_tokens",
"fix_sentencepiece_gguf",
]
IGNORED_TOKENIZER_CHECKING = frozenset((
"CodeLlamaTokenizerFast",
"CodeLlamaTokenizer",
))
IGNORED_TOKENIZER_NAMES = [
# Qwen Coder did not train on tool calling. Math did!
"unsloth/Qwen2.5-Coder-1.5B-Instruct",
"unsloth/Qwen2.5-Coder-7B-Instruct",
]
IGNORED_TOKENIZER_NAMES = frozenset(
[x.lower() for x in IGNORED_TOKENIZER_NAMES] + \
[x.lower()+"-bnb-4bit" for x in IGNORED_TOKENIZER_NAMES]
)
# Check environments
keynames = "\n" + "\n".join(os.environ.keys())
IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames
IS_KAGGLE_ENVIRONMENT = "\nKAGGLE_" in keynames
del keynames
def try_fix_tokenizer(tokenizer, prepend = True):
if hasattr(tokenizer, "_tokenizer"):
converted_tokenizer = tokenizer._tokenizer
else:
converted_tokenizer = convert_slow_tokenizer(tokenizer)
pass
tokenizer_string = converted_tokenizer.to_str()
# Llama prepends "▁" (e.g. "▁apple"). Sometimes this is wrong!!
prepend_text = '{"type":"Prepend","prepend":"▁"},'
if not prepend and prepend_text in tokenizer_string:
tokenizer_string = tokenizer_string.replace(prepend_text, "", 1)
pass
dir_names = dir(tokenizer)
# Get eos_token, bos_token etc
token_names = [x for x in dir_names if x.endswith("_token") and x.count("_") == 1]
for token_name in token_names:
token = getattr(tokenizer, token_name, None)
if token is None: continue
token_id = getattr(tokenizer, token_name + "_id", None)
# Locate the token's id mapping in the string
find_text = f'"id":{token_id},"content":"'
start = tokenizer_string.find(find_text)
# find returns -1 when the id is not present; check before offsetting by len(find_text)
if start == -1: continue
start += len(find_text)
end = tokenizer_string.find('",', start)
bad_token = tokenizer_string[start : end]
# Check if token is the actual same one - if not, edit it
if bad_token != token:
bad_text = f'{find_text}{bad_token}",'
good_text = f'{find_text}{token}",'
tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)
# And replace vocab section
bad_text = f'"{bad_token}":{token_id},'
good_text = f'"{token}":{token_id},'
tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)
pass
pass
fixed_tokenizer = converted_tokenizer.from_str(tokenizer_string)
return fixed_tokenizer
pass
def get_sorted_dict(dictionary):
sorted_keys = sorted(dictionary.values())
inverted_dictionary = { value : key for key, value in dictionary.items() }
sorted_dictionary = {}
for key in sorted_keys:
value = inverted_dictionary[key]
sorted_dictionary[value] = key
return sorted_dictionary
pass
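# Illustrative behaviour (toy vocab is an assumption):
#   get_sorted_dict({"b": 2, "a": 1, "c": 0}) -> {"c": 0, "a": 1, "b": 2}
# i.e. the vocab is re-ordered by token id so two vocabs can be compared directly.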
def convert_to_fast_tokenizer(
slow_tokenizer,
temporary_location = "_unsloth_sentencepiece_temp",
):
is_fast = getattr(slow_tokenizer, "is_fast", False)
if is_fast: return slow_tokenizer
try:
tokenizer_name = slow_tokenizer.__class__.__name__
lowered_tokenizer_name = tokenizer_name.lower()
if lowered_tokenizer_name.endswith("tokenizer"):
class_name = lowered_tokenizer_name[:-len("tokenizer")]
FastTokenizer = eval(
f'__import__(f"transformers.models.{class_name}").{tokenizer_name}Fast'
)
else:
FastTokenizer = PreTrainedTokenizerFast
except:
FastTokenizer = PreTrainedTokenizerFast
pass
# Get all arguments (bos_token, etc)
docs = FastTokenizer.__doc__
docs = docs[docs.find("Args:"):]
args = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
args = [x for x in args if not x.endswith("_file")]
# Also some missing maybe!
docs = PreTrainedTokenizerFast.__doc__
docs = docs[docs.find("Args:"):]
args2 = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
args2 = [x for x in args2 if not x.endswith("_file")]
args = list(set(args + args2))
kwargs = {}
for arg in args: kwargs[arg] = getattr(slow_tokenizer, arg, None)
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = True)
fast_tokenizer = FastTokenizer( **kwargs )
# Check if they're similar!
sorted_slow_tokenizer = get_sorted_dict(slow_tokenizer.get_vocab())
sorted_fast_tokenizer = get_sorted_dict(fast_tokenizer.get_vocab())
check_vocab = (sorted_slow_tokenizer == sorted_fast_tokenizer)
check_special = (slow_tokenizer.all_special_tokens == fast_tokenizer.all_special_tokens)
# Failure so return slow_tokenizer
if not check_vocab or not check_special: return slow_tokenizer
# Now confirm if they match
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Maybe remove prepending of "▁" (e.g. "▁apple")?
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = False)
fast_tokenizer = FastTokenizer( **kwargs )
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Failure :(
return slow_tokenizer
pass
pass
# Also tokenizer.model is missing!
name = slow_tokenizer.name_or_path.replace("/", "_")
if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
pass
new_location = f"{temporary_location}/{name}"
slow_tokenizer.save_pretrained(new_location)
fast_tokenizer.save_pretrained(new_location)
# Now load it!
fast_tokenizer = AutoTokenizer.from_pretrained(new_location)
if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
return fast_tokenizer
return slow_tokenizer
pass
# Check Mistral chat template without BOS / EOS
mistral_template = \
"{% if messages[0]['role'] == 'system' %}"\
"{% if messages[1]['role'] == 'user' %}"\
"{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[2:] %}"\
"{% else %}"\
"{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% endif %}"\
"{% else %}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ message['content'] }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"
pass
# Check Llama chat template without BOS / EOS
llama_template = \
"{% if messages[0]['role'] == 'system' %}"\
"{% if messages[1]['role'] == 'user' %}"\
"{{ '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[2:] %}"\
"{% else %}"\
"{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% endif %}"\
"{% else %}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"{{ '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
"{% elif message['role'] == 'assistant' %}"\
"{{ ' ' + message['content'].strip() + ' ' }}"\
"{% else %}"\
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"
pass
def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Get eos_token, bos_token etc
dir_names = dir(slow_tokenizer)
special_tokens = list(filter(None, (
getattr(slow_tokenizer, x) for x in dir_names
if x.endswith("_token") and x.count("_") == 1
)))
all_special_tokens = list(set(special_tokens + slow_tokenizer.all_special_tokens))
# Check if chat template is enabled!
check_chat_template1 = True
check_chat_template2 = True
check_chat_template3 = True
"""
Weirdly Mistral tokenizers are actually correct??
Ie below will actually load mistral v1 and v3 incorrectly!
slow_chat_template = getattr(slow_tokenizer, "chat_template", None)
fast_chat_template = getattr(fast_tokenizer, "chat_template", None)
messages = [
{"role": "user", "content": " What is 2+2? "},
{"role": "assistant", "content": " It's 4. "},
]
# Check the tokenizer's own chat template
if slow_chat_template is not None and fast_chat_template is not None:
check_chat_template1 = \
slow_tokenizer.apply_chat_template(messages) == \
fast_tokenizer.apply_chat_template(messages)
pass
# Check Mistral chat template without BOS / EOS
slow_tokenizer.chat_template = mistral_template
fast_tokenizer.chat_template = mistral_template
check_chat_template2 = \
slow_tokenizer.apply_chat_template(messages) == \
fast_tokenizer.apply_chat_template(messages)
pass
# Check Llama chat template without BOS / EOS
slow_tokenizer.chat_template = llama_template
fast_tokenizer.chat_template = llama_template
check_chat_template3 = \
slow_tokenizer.apply_chat_template(messages) == \
fast_tokenizer.apply_chat_template(messages)
pass
# Combine them all and revert chat templates
slow_tokenizer.chat_template = slow_chat_template
fast_tokenizer.chat_template = fast_chat_template
"""
check_chat_template = check_chat_template1 and check_chat_template2 and check_chat_template3
# Try special tokens
try:
string = "\n".join(all_special_tokens) + \
"A quick brown fox jumps over the lazy dog!!\n\nHi</s>\n\n" + \
"".join(all_special_tokens)
check_special_tokens = \
slow_tokenizer(string).input_ids == \
fast_tokenizer(string).input_ids
return check_chat_template and check_special_tokens
except:
# For eg see https://github.com/unslothai/unsloth/issues/292
# Sometimes tokenizer has weird tokens, causing a combined tokenization to fail.
# [TODO] We temporarily disable this for CodeLlama tokenizers
if slow_tokenizer.__repr__().split("(", 1)[0] in IGNORED_TOKENIZER_CHECKING:
return check_chat_template
else:
return False
pass
pass
def fix_sentencepiece_tokenizer(
old_tokenizer,
new_tokenizer,
token_mapping,
temporary_location = "_unsloth_sentencepiece_temp",
):
# From https://github.com/google/sentencepiece/issues/121
# We need to manually edit the sentencepiece tokenizer!
from transformers.utils import sentencepiece_model_pb2
if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
pass
# First save the old tokenizer
old_tokenizer.save_pretrained(temporary_location)
# Check if tokenizer.model exists (some tokenizers do not ship a sentencepiece model)
if not os.path.isfile(f"{temporary_location}/tokenizer.model"):
return new_tokenizer
pass
tokenizer_file = sentencepiece_model_pb2.ModelProto()
tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read())
# Now save the new tokenizer
new_tokenizer.save_pretrained(temporary_location)
# Now correct the old tokenizer's .model file
for old_token, new_token in token_mapping.items():
ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
ids = ids[0]
if (len(ids) != 1):
# Skip this token!
print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!")
continue
pass
ids = ids[0]
# [TODO] Hack for Starling - try except
try:
tokenizer_piece = tokenizer_file.pieces[ids]
except:
continue
assert(tokenizer_piece.piece == old_token)
tokenizer_piece.piece = new_token
pass
# And now write it
with open(f"{temporary_location}/tokenizer.model", "wb") as file:
file.write(tokenizer_file.SerializeToString())
pass
# And load it!
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
temporary_location,
eos_token = new_tokenizer.eos_token,
pad_token = new_tokenizer.pad_token,
)
return tokenizer
pass
def fix_sentencepiece_gguf(saved_location):
"""
Fixes sentencepiece tokenizers which did not extend the vocabulary with
user defined tokens.
Inspiration from https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py
"""
from copy import deepcopy
from transformers.utils import sentencepiece_model_pb2
import json
from enum import IntEnum
class SentencePieceTokenTypes(IntEnum):
NORMAL = 1
UNKNOWN = 2
CONTROL = 3
USER_DEFINED = 4
UNUSED = 5
BYTE = 6
pass
# Load tokenizer.model
tokenizer_file = sentencepiece_model_pb2.ModelProto()
if not os.path.isfile(f"{saved_location}/tokenizer.model"): return
tokenizer_file.ParseFromString(open(f"{saved_location}/tokenizer.model", "rb").read())
sentence_piece_size = len(tokenizer_file.pieces)
# Load added_tokens_json
if not os.path.isfile(f"{saved_location}/added_tokens.json"): return
with open(f"{saved_location}/added_tokens.json", "r", encoding = "utf-8") as file:
added_tokens_json = json.load(file)
pass
if len(added_tokens_json) == 0: return
added_tokens_json = dict(sorted(added_tokens_json.items(), key = lambda item: item[1]))
new_size = sentence_piece_size + len(added_tokens_json)
# Confirm added_tokens_json is correct
added_tokens_ids = np.array(list(added_tokens_json.values()))
diff = np.diff(added_tokens_ids)
if (diff.min() != 1 or diff.max() != 1): return
if (added_tokens_ids.min() != sentence_piece_size): return
# Edit sentence piece tokens with added_tokens_json
logger.warning(
f"Unsloth: Extending {saved_location}/tokenizer.model with added_tokens.json.\n"\
f"Originally tokenizer.model is of size ({sentence_piece_size}).\n"\
f"But we need to extend to sentencepiece vocab size ({new_size})."
)
new_tokens = deepcopy(tokenizer_file.pieces[-len(added_tokens_ids):])
for new_token, added_token in zip(new_tokens, added_tokens_json.keys()):
new_token.piece = added_token.encode("utf-8")
new_token.score = -1000.0
new_token.type = SentencePieceTokenTypes.USER_DEFINED
pass
tokenizer_file.pieces.extend(new_tokens)
with open(f"{saved_location}/tokenizer.model", "wb") as file:
file.write(tokenizer_file.SerializeToString())
pass
# Add padding tokens
# actual_vocab_size = model.config.vocab_size
# padding = actual_vocab_size - len(tokenizer_file.pieces)
return
pass
def _load_correct_tokenizer(
tokenizer_name,
model_max_length = None,
padding_side = "right",
token = None,
trust_remote_code = False,
cache_dir = "huggingface_tokenizers_cache",
fix_tokenizer = True,
):
if IS_COLAB_ENVIRONMENT or IS_KAGGLE_ENVIRONMENT:
cache_dir = cache_dir
else:
cache_dir = None
pass
# Try loading the slow tokenizer. If it fails, then try Fast only
# Mainly to solve Deepseek models with no tokenizer.model file
slow_tokenizer = None
try:
slow_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
trust_remote_code = trust_remote_code,
# Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
use_fast = False,
legacy = False,
from_slow = True,
cache_dir = cache_dir,
)
except:
pass
# print(
# f"Unsloth: {tokenizer_name} has no tokenizer.model file.\n"\
# "Just informing you about this - this is not a critical error."
# )
pass
fast_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
trust_remote_code = trust_remote_code,
cache_dir = cache_dir,
)
if not fix_tokenizer or tokenizer_name in IGNORED_TOKENIZER_NAMES:
return fast_tokenizer
# Ignore Mistral ones - they're a bit weird to handle!
elif "mistral" in tokenizer_name.lower():
return fast_tokenizer
elif slow_tokenizer is not None:
if hasattr(fast_tokenizer, "add_bos_token") and hasattr(slow_tokenizer, "add_bos_token"):
fast_tokenizer.add_bos_token = slow_tokenizer.add_bos_token
if hasattr(fast_tokenizer, "add_eos_token") and hasattr(slow_tokenizer, "add_eos_token"):
fast_tokenizer.add_eos_token = slow_tokenizer.add_eos_token
# Confirm if slow and fast are equivalent!
if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
return fast_tokenizer
else:
logger.warning(f"Unsloth: Will load {tokenizer_name} as a legacy tokenizer.")
return convert_to_fast_tokenizer(slow_tokenizer)
pass
else:
return fast_tokenizer
pass
pass
def load_correct_tokenizer(
tokenizer_name,
model_max_length = None,
padding_side = "right",
token = None,
trust_remote_code = False,
cache_dir = "huggingface_tokenizers_cache",
fix_tokenizer = True,
):
tokenizer = _load_correct_tokenizer(
tokenizer_name = tokenizer_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
trust_remote_code = trust_remote_code,
cache_dir = cache_dir,
fix_tokenizer = fix_tokenizer,
)
### 1. Fixup tokenizer's chat_template
old_chat_template = getattr(tokenizer, "chat_template", None)
# Ignore Mistral-type models since they don't have an add_generation_prompt
if "mistral" in str(getattr(tokenizer, "name_or_path", "")).lower():
chat_template = old_chat_template
# Also check Llama-2 old style models
elif old_chat_template is not None and \
"[/INST]" in old_chat_template and "[INST]" in old_chat_template and \
"bos_token" in old_chat_template and "eos_token" in old_chat_template:
chat_template = old_chat_template
else:
chat_template = fix_chat_template(tokenizer)
if old_chat_template is not None and chat_template is None:
raise RuntimeError(
"Unsloth: Fixing chat template failed - please file a report immediately!"
)
pass
pass
tokenizer.chat_template = chat_template
return tokenizer
pass
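# Illustrative usage sketch (comments only); the model name is just an example repo:
# tokenizer = load_correct_tokenizer(
#     "unsloth/llama-2-7b-bnb-4bit",
#     model_max_length = 4096,
#     padding_side = "right",
# )
# print(tokenizer.chat_template)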
def _fix_chat_template(chat_template):
endfor = "{% endfor %}"
where = chat_template.find(endfor)
if where == -1: return chat_template
after_endfor = chat_template[where + len(endfor):]
if "{% if" not in after_endfor and "{% set " not in after_endfor and \
after_endfor.startswith("{{") and after_endfor.endswith("}}") and \
after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1:
after_endfor = "{% if add_generation_prompt %}" + after_endfor + "{% endif %}"
chat_template = chat_template[:where + len(endfor)] + after_endfor
pass
return chat_template
pass
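# Worked example (comments only, using a hypothetical minimal Jinja template): given
#     "{% for m in messages %}{{ m['content'] }}{% endfor %}{{ '### Response:' }}"
# the text after the first {% endfor %} is a single trailing {{ ... }} expression, so
# _fix_chat_template wraps it in the generation guard, producing
#     "{% for m in messages %}{{ m['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '### Response:' }}{% endif %}"
# Templates without such a lone trailing expression are returned unchanged.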
def fix_chat_template(tokenizer):
chat_template = getattr(tokenizer, "chat_template", None)
if chat_template is None: return None
### 1. Check if add_generation_prompt works
# Check for ShareGPT style first
is_sharegpt = None
try:
messages = [
{"role": "user", "content": "Who are you?"},
]
tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
is_sharegpt = False
except:
try:
messages = [
{"from": "human", "value": "Who are you?"},
]
tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
is_sharegpt = True
except:
is_sharegpt = None
pass
pass
# Not ShareGPT or HF style - just return
if is_sharegpt is None: return chat_template
# Tokenize
messages = [
{"role": "user", "content": "Who are you?"} \
if not is_sharegpt else \
{"from": "human", "value": "Who are you?"}
]
no = tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
yes = tokenizer.apply_chat_template(messages, add_generation_prompt = True, tokenize = False)
if no == yes:
# SAME?! That's not good! We check for add_generation_prompt
if "{% if add_generation_prompt %}" not in chat_template:
# Try fixing it by adding it
new_chat_template = _fix_chat_template(chat_template)
if "{% if add_generation_prompt %}" not in new_chat_template:
raise RuntimeError(
f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
"does not have a {% if add_generation_prompt %} for generation purposes.\n"\
"Please file a bug report immediately - thanks!"
)
else:
logger.warning_once(
"Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n"\
"This is not a bug, but please notify the Unsloth maintainers - thanks!"
)
chat_template = new_chat_template
pass
else:
raise RuntimeError(
f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
"has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\n"\
"Please file a bug report immediately - thanks!"
)
pass
pass
return chat_template
pass
def check_tokenizer(
model,
tokenizer,
model_name = "unsloth/llama-2-7b-bnb-4bit",
model_max_length = 4096,
padding_side = "right",
token = None,
_reload = True,
):
# Checks tokenizer for out of bounds ids.
# Mainly a fix for https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha
# where <sep> had token id=32002.
# See https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha/discussions/25
# Seems like the Fast tokenizer in Rust breaks things!
# We ignore some of them!
if tokenizer.__repr__().split("(", 1)[0] in IGNORED_TOKENIZER_CHECKING:
return tokenizer
pass
max_embedding_size = model.model.embed_tokens.weight.shape[0]
added_tokens_fast = tokenizer.added_tokens_decoder
added_tokens_fast = {index : str(value) for index, value in added_tokens_fast.items()}
sorted_keys = sorted(added_tokens_fast)
added_tokens_fast = {key : added_tokens_fast[key] for key in sorted_keys}
for j, index in enumerate(added_tokens_fast.keys()):
if index >= max_embedding_size:
bad_indices = list(added_tokens_fast.keys ())[j:]
bad_tokens = list(added_tokens_fast.values())[j:]
if not _reload:
# Try removing the token
added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
special_tokens = tokenizer.special_tokens_map
import itertools
special_tokens = frozenset(
itertools.chain.from_iterable(
[x] if type(x) is str else x for x in special_tokens.values()
)
)
can_be_removed1 = [x for x in bad_tokens if x not in special_tokens]
can_be_removed2 = [x for x in can_be_removed1 if x in tokenizer._added_tokens_encoder.keys()]
# Check if the extra tokens can in fact be removed!
can_be_removed = \
(len(can_be_removed1) == len(bad_tokens)) and \
(len(can_be_removed2) == len(bad_tokens))
# Check if sep_token or other generic token types are involved
remove_generic = False
try_mapper = []
if not can_be_removed:
names = dir(tokenizer)
names = (x for x in names if x.endswith("_token") and x.count("_") == 1)
generic_tokens = [(x, getattr(tokenizer, x, None)) for x in names]
try_removal = []
# Use `bad_token` here to avoid shadowing the HF auth `token` argument used later
for bad_token in bad_tokens:
for (name_token, check_token) in generic_tokens:
if check_token == bad_token:
try_removal.append(bad_token)
try_mapper.append(name_token)
pass
pass
pass
# Recheck!
can_be_removed = (len(try_removal) == len(bad_tokens))
if can_be_removed: remove_generic = True
can_be_removed1 = bad_tokens
pass
if can_be_removed:
# Yes it can be fixed!
for j, bad_token in enumerate(can_be_removed1):
remove_id = tokenizer._added_tokens_encoder[bad_token]
del tokenizer._added_tokens_decoder[remove_id]
del tokenizer._added_tokens_encoder[bad_token]
if remove_generic and (try_removal[j] == bad_token):
# Remove sep token for example
setattr(tokenizer, try_mapper[j], None)
setattr(tokenizer, try_mapper[j] + "_id", None)
pass
pass
# Confirm 1 more time!
if max(tokenizer.added_tokens_decoder.keys()) < max_embedding_size:
logger.warning_once(
f"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair it!\n"\
f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
"We removed these bad tokens. If you think this is incorrect, fix your tokenizer first."
)
return convert_to_fast_tokenizer(tokenizer)
pass
pass
# :( Failure
raise RuntimeError(
f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"\
f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
f"Fix your tokenizer since it'll perform out of bounds memory accesses."
)
pass
if IS_COLAB_ENVIRONMENT or IS_KAGGLE_ENVIRONMENT:
cache_dir = "huggingface_tokenizers_cache"
else:
cache_dir = None
pass
# Sometimes the slow tokenizer does not work (e.g. Deepseek models)
try:
# Try slow tokenizer which can fix things!
tokenizer = AutoTokenizer.from_pretrained(
model_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
# Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
use_fast = False,
legacy = False,
from_slow = True,
cache_dir = cache_dir,
)
return check_tokenizer(
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
_reload = False,
)
except:
# Tokenizer has out of bounds issues and we can't
# load the slow tokenizer version :(
logger.warning_once(
"Unsloth: Tokenizer is most likely buggy, and Unsloth failed to repair it.\n"\
"It will still work, but beware of out of bounds memory accesses.\n"\
"Please file an issue on the model owner's repo about this issue."
)
return tokenizer
pass
pass
pass
return convert_to_fast_tokenizer(tokenizer)
pass
def check_nvidia():
# Unsloth doesn't work yet on AMD devices - we're working on it!
output = np.array([0,])
try:
output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True)
output = re.findall(rb'([\d]{1,})[\s]{1,}M', output)
output = np.array([int(x.decode('utf-8'))/1024 for x in output])
except:
if not torch.cuda.is_available():
raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!")
return output
pass
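# Illustrative note (comments only) on the parsing above: for nvidia-smi output such as
#     b"memory.used [MiB]\n1234 MiB\n"
# the regex rb'([\d]{1,})[\s]{1,}M' captures b"1234", which becomes 1234/1024 ~= 1.21
# (GiB of used memory for that GPU); the result is an array with one entry per GPU.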
PRE_CHECK = check_nvidia()
import inspect
from inspect import getsource
import trl.trainer.sft_trainer
from trl.trainer.sft_trainer import *
from transformers.trainer import *
try:
from trl.trainer.sft_trainer import neftune_post_forward_hook
except:
def neftune_post_forward_hook(module, input, output):
"""
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
torch.nn.Embedding layers. This method is slightly adapted from the original source code
that can be found here: https://github.com/neelsjain/NEFTune
Simply add it to your model as follows:
```python
model = ...
model.embed_tokens.neftune_noise_alpha = 0.1
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
```
Args:
module (`torch.nn.Module`):
The embedding module where the hook is attached. Note that you need to set
`module.neftune_noise_alpha` to the desired noise alpha value.
input (`torch.Tensor`):
The input tensor to the model.
output (`torch.Tensor`):
The output tensor of the model (i.e. the embeddings).
"""
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
pass
pass
def patch_trl_tokenizer_processing_class(trainer_name):
# Newer TRL versions remove the `tokenizer` argument (replaced by `processing_class`).
# We add it back!
exec(f"from trl import {trainer_name}", globals())
if str(eval(f"{trainer_name}").__name__).startswith("Unsloth"): return None
parameters = eval(f"inspect.signature({trainer_name}).parameters")
if "tokenizer" in parameters: return None
args = {
key : \
value.default \
if type(value.default) is not str else \
f"'{value.default}'" \
for key, value in parameters.items()
}
args["tokenizer"] = None
new_args = args.copy()
del new_args["tokenizer"]
del new_args["processing_class"]
new_args = ",\n".join(f"{' '*12}{key} = {key}" for key in new_args) + \
f",\n{' '*12}processing_class = tokenizer if tokenizer else processing_class"
args = ",\n".join(f"{' '*8}{key} = {value}" for key, value in args.items())
args = f"def __init__(\n" + f"{' '*8}self,\n" + args + "):"
args += f"\n{' '*8}\n{' '*8}super().__init__(\n{new_args}\n{' '*8})"
new_class = f"""class Unsloth{trainer_name}({trainer_name}):\n{' '*4}{args}\n"""
return new_class
pass
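# Illustrative sketch (comments only) of the class generated above for a TRL trainer whose
# signature exposes only `processing_class`; the exact parameter list depends on the
# installed TRL version, so the names and the "..." below are placeholders:
#
# class UnslothSFTTrainer(SFTTrainer):
#     def __init__(self, model = None, args = None, ..., processing_class = None, tokenizer = None):
#         super().__init__(model = model, args = args, ...,
#                          processing_class = tokenizer if tokenizer else processing_class)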
def patch_sft_trainer_tokenizer():
"""
Patches TRL's trainer methods with extra Unsloth checks and fixes.
"""
for function_name, replacer in (
("_prepare_non_packed_dataloader", "def tokenize(element):",),
# ("_prepare_packed_dataloader", "if dataset_text_field is not None",),
):
function = getsource(eval(f"trl.trainer.sft_trainer.SFTTrainer.{function_name}"))
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
check_text = \
"\n"\
"if 'tokenizer' not in locals(): tokenizer = processing_class\n"\
"if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\
"if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\
"test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\
"chat_template = getattr(tokenizer, 'chat_template', None)\n"\
"chat_template = '' if chat_template is None else chat_template\n"\
"has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\
"if getattr(tokenizer, 'bos_token', None) is not None else False\n"\
"add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n"
check_text = check_text.split("\n")
check_text = "\n".join(" "*where + x for x in check_text)
function = function.replace(replacer, check_text + replacer)
exec(function, globals())
exec(f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}", globals())
pass
# Patch train with fix_untrained_tokens
for path_to_trainer in \
("sft_trainer.SFTTrainer", "dpo_trainer.DPOTrainer", "kto_trainer.KTOTrainer"):
function_name, replacer = "train", "if resume_from_checkpoint is False:"
function = getsource(eval(f"trl.trainer.{path_to_trainer}.{function_name}"))
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
check_text = \
"\n"\
"import subprocess, re, gc, numpy as np\n"\
"a = np.array([0,])\n"\
"try:\n"\
" a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"\
" a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)\n"\
" a = np.array([int(x.decode('utf-8'))/1024 for x in a])\n"\
"except:\n"\
" if not torch.cuda.is_available():\n"\
" raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')\n"\
"if ((a - PRE_CHECK) >= 1).sum() > 1:\n"\
" raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')\n"\
"for _ in range(3):\n"\
" gc.collect()\n"\
" torch.cuda.empty_cache()\n"\
"pass\n"\
"\n"\
"tokenizer = self.processing_class if hasattr(self, 'processing_class') else self.tokenizer\n"\
"fix_untrained_tokens(self.model, tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"\
"fix_zero_training_loss(self.model, tokenizer, self.train_dataset)\n\n"
# Warn on gradient accumulation steps if it's used
check_text += \
"\n"\
"try:\n"\
" gradient_accumulation_steps = self.args.gradient_accumulation_steps\n"\
" if type(gradient_accumulation_steps) is int and gradient_accumulation_steps > 1:\n"\
" from transformers import __version__ as transformers_version\n"\
" from packaging.version import Version\n"\
" if Version(transformers_version) <= Version('4.45.2'):\n"\
" print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\
" '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`')\n"\
"except:\n"\
" pass\n"\
"\n\n"
# Add NEFTune since it doesn't seem to work otherwise - we need to manually inject it
check_text += \
"\n"\
"if hasattr(self, 'neftune_hook_handle'):\n"\
" self.neftune_hook_handle.remove()\n"\
" if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\
"\n"\
"if getattr(self, 'neftune_noise_alpha', None) is not None:\n"\
" self.model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\
" self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"\
"pass\n"\
"\n"
# DPO also tokenizes non-numeric columns unexpectedly - delete them!
check_text += \
"\n"\
"column_names = set(self.train_dataset.column_names)\n"\
"check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\
" 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\
" 'prompt_input_ids', 'prompt_attention_mask']\n"\
"if all(x in column_names for x in check):\n"\
" self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\
"del check, column_names\n"\
"\n"
check_text = check_text.split("\n")
check_text = "\n".join(" "*where + x for x in check_text)
function = function.replace(replacer, check_text + replacer)
exec(function, globals())
exec(f"trl.trainer.{path_to_trainer}.{function_name} = {function_name}", globals())
pass
pass
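# Illustrative sketch (comments only) of the source-patching strategy above, using a toy
# string instead of real TRL source: the getsource() output is dedented, the extra checks
# are spliced in right before a known anchor line, then exec() rebinds the patched function:
#
# src     = "def train(self):\n    if resume_from_checkpoint is False:\n        ...\n"
# anchor  = "if resume_from_checkpoint is False:"
# checks  = "print('Unsloth pre-train checks')\n    "
# patched = src.replace(anchor, checks + anchor)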
# Fix TRL trainers with removed tokenizer args (got replaced with processing_class)
for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"):
trainer_text = patch_trl_tokenizer_processing_class(trainer_name)
if trainer_text is None: continue
try:
exec(trainer_text, globals())
except:
raise RuntimeError(f"Unsloth: Please file a bug report! Error patching {trainer_name}")
exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals())
pass
# Finally patch TRL tokenizer things
patch_sft_trainer_tokenizer()
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from trl import SFTTrainer
try:
from trl import SFTConfig as TrainingArguments
except:
from transformers import TrainingArguments
pass
from . import is_bfloat16_supported
from unsloth_zoo.training_utils import unsloth_train as _unsloth_train
from packaging.version import Version
# Unsloth gradient accumulation fix:
from transformers import __version__ as transformers_version
if Version(transformers_version) > Version("4.45.2"):
def unsloth_train(trainer):
return trainer.train()
pass
else:
def unsloth_train(trainer):
print(
"Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"\
"If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
'`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`'
)
return _unsloth_train(trainer)
pass
pass
__all__ = [
"UnslothTrainingArguments",
"UnslothTrainer",
"unsloth_train",
]
@dataclass
class UnslothTrainingArguments(TrainingArguments):
embedding_learning_rate : Optional[float] = field(
default = None,
metadata = {"help" : "Different learning rates for embeddings and lm_head."}
)
pass
def _create_unsloth_optimizer(
model,
optimizer_cls,
optimizer_kwargs,
embedding_lr = 5e-5,
):
lr = optimizer_kwargs["lr"]
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
param_groups = \
{
"non_embeddings" : {},
"embeddings" : {},
}
for name, param in model.named_parameters():
if not param.requires_grad: continue
if name.endswith("modules_to_save.default.weight"):
partial_name = name[:-len(".modules_to_save.default.weight")]
partial_name = partial_name[partial_name.rfind(".")+1:]
print(f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}.")
param_groups["embeddings"] [name] = param
else:
param_groups["non_embeddings"][name] = param
pass
pass
optimizer_grouped_parameters = [
{
"params" : list(param_groups["non_embeddings"].values()),
"weight_decay" : weight_decay,
"lr" : lr,
},
{
"params" : list(param_groups["embeddings"].values()),
"weight_decay" : weight_decay,
"lr" : embedding_lr,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
return optimizer
pass
class UnslothTrainer(SFTTrainer):
def create_optimizer(self):
embedding_learning_rate = getattr(self.args, "embedding_learning_rate", None)
if embedding_learning_rate is None: return super().create_optimizer()
if self.optimizer is None:
optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = _create_unsloth_optimizer(
self.model,
optimizer_cls,
optimizer_kwargs,
embedding_learning_rate,
)
pass
return self.optimizer
pass
pass
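# Illustrative usage sketch (comments only); `model`, `tokenizer` and `dataset` are assumed
# to be prepared elsewhere, and the hyperparameters below are placeholders:
#
# trainer = UnslothTrainer(
#     model = model,
#     tokenizer = tokenizer,
#     train_dataset = dataset,
#     args = UnslothTrainingArguments(
#         per_device_train_batch_size = 2,
#         learning_rate = 5e-5,
#         embedding_learning_rate = 5e-6,  # lower lr for embed_tokens / lm_head
#         output_dir = "outputs",
#     ),
# )
# trainer.train()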