Unverified Commit a74a87cc authored by MPU王荣胜, committed by GitHub

fix merge lora error

parent 7139e128
""" """
In this mixin, I use a different implementation than lora.py In this mixin, I use a different implementation than sat/model/finetune/lora.py
I just use a fake linear layer to replace any model with lora mixin. I just use a fake linear layer to replace any model with lora mixin.
""" """
@@ -17,6 +17,20 @@ class HackLinear(nn.Linear):
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix+'bias'])
+
+class HackRowParallelLinear(RowParallelLinear):
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        if prefix + 'weight' in state_dict:
+            self.weight.data.copy_(state_dict[prefix+'weight'])
+        if prefix + 'bias' in state_dict:
+            self.bias.data.copy_(state_dict[prefix+'bias'])
+
+class HackColumnParallelLinear(ColumnParallelLinear):
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        if prefix + 'weight' in state_dict:
+            self.weight.data.copy_(state_dict[prefix+'weight'])
+        if prefix + 'bias' in state_dict:
+            self.bias.data.copy_(state_dict[prefix+'bias'])

try:
    from bitsandbytes.nn import LinearNF4
    def copy_nested_list(src, dst):
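An aside between hunks on what the Hack* overrides of _load_from_state_dict buy: they copy whatever 'weight'/'bias' tensors are present instead of enforcing an exact key match, so a checkpoint saved from the original, un-wrapped layer can still be loaded after the swap. A rough stand-alone illustration of that mechanism (the class name here is hypothetical, not one of this file's classes):

# Illustrative sketch only.
import torch
import torch.nn as nn

plain = nn.Linear(8, 8)
ckpt = plain.state_dict()                          # keys: 'weight', 'bias'

class HackLinearDemo(nn.Linear):                   # hypothetical stand-in for HackLinear
    def _load_from_state_dict(self, state_dict, prefix, *args):
        # copy matching tensors, ignore anything missing or extra
        if prefix + 'weight' in state_dict:
            self.weight.data.copy_(state_dict[prefix + 'weight'])
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix + 'bias'])

hacked = HackLinearDemo(8, 8)
hacked.load_state_dict(ckpt, strict=False)         # the pre-LoRA checkpoint still loads
print(torch.allclose(hacked.weight, plain.weight)) # True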
@@ -48,40 +62,14 @@ class HackParameterList(nn.ParameterList):
            if prefix + str(i) in state_dict:
                self[i].data.copy_(state_dict[prefix+str(i)])

-class LoraLinear(nn.Module):
-    def __init__(self, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., qlora=False):
-        super().__init__()
-        if lora_dropout and lora_dropout > 0:
-            self.lora_dropout = nn.Dropout(p=lora_dropout)
-        else:
-            self.lora_dropout = lambda x: x
-        self.r = r
-        self.lora_alpha = lora_alpha
-        self.scaling = self.lora_alpha / self.r
-        if qlora:
-            self.original = HackLinearNF4(in_dim, out_dim)
-        else:
-            self.original = HackLinear(in_dim, out_dim)
-        self.matrix_A = nn.Parameter(torch.empty((r, in_dim)))
-        self.matrix_B = nn.Parameter(torch.empty((out_dim, r)))
-        nn.init.kaiming_uniform_(self.matrix_A, a=math.sqrt(5))
-        nn.init.zeros_(self.matrix_B)
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
-        if prefix + 'weight' in state_dict:
-            # load from normal Linear
-            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-        else:
-            # load from LoraLinear
-            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-    def forward(self, x):
-        return self.original(x) + (self.lora_dropout(x) @ self.matrix_A.T @ self.matrix_B.T) * self.scaling
-
-class LoraQKV(nn.Module):
-    def __init__(self, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False, num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False):
+map_cls = {
+    nn.Linear: (HackLinear, {}),
+    ColumnParallelLinear: (HackColumnParallelLinear, {'gather_output': False}),
+    RowParallelLinear: (HackRowParallelLinear, {'input_is_parallel': True})
+}
+
+class LoraLinear(nn.Module):
+    def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False, num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False):
        """
        You can use safely with this layer, ONLY WHEN query_key_value output is query_key_value order.
        If you use a different order like ChatGLM
@@ -95,15 +83,20 @@ class LoraQKV(nn.Module):
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
-            self.original = HackLinearNF4(in_dim, out_dim)
+            try:
+                self.original = HackLinearNF4(in_dim, out_dim)
+            except:
+                raise Exception('Build 4bit layer failed. You need to install the latest bitsandbytes. Try `pip install bitsandbytes`. If you still meet error after installation, try running `from bitsandbytes.nn import LinearNF4` with python and fix the error.')
        else:
-            self.original = HackLinear(in_dim, out_dim)
-        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(3)])
-        self.matrix_B = HackParameterList([nn.Parameter(torch.empty((out_dim // 3, r))) for _ in range(3)])
-        for i in range(3):
+            base_cls, kwargs = map_cls[original_cls]
+            self.original = base_cls(in_dim, out_dim, **kwargs)
+        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(partition)])
+        self.matrix_B = HackParameterList([nn.Parameter(torch.empty((out_dim // partition, r))) for _ in range(partition)])
+        for i in range(partition):
            nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
            nn.init.zeros_(self.matrix_B[i])
        self.head_first = head_first
+        self.partition = partition
        if head_first:
            assert num_attention_heads is not None and hidden_size_per_attention_head is not None, "You should set num_attention_heads and hidden_size_per_attention_head if you use head_first=True!"
            self.num_attention_heads = num_attention_heads
@@ -121,14 +114,14 @@ class LoraQKV(nn.Module):
    def forward(self, x):
        mixed_raw_layer = self.original(x)
        lora_outputs = []
-        for i in range(3):
+        for i in range(self.partition):
            lora_outputs.append((self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling)
        if self.head_first:
            new_tensor_shape = lora_outputs[0].size()[:-1] + (
                self.num_attention_heads,
                self.hidden_size_per_attention_head,
            )
-            for i in range(3):
+            for i in range(self.partition):
                lora_outputs[i] = lora_outputs[i].view(*new_tensor_shape)
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1).view(*mixed_raw_layer.size())
        else:
@@ -137,33 +130,35 @@ class LoraQKV(nn.Module):
        return mixed_raw_layer
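To make the partitioned forward above concrete, a small shape-only sketch (illustrative, not from this file): with partition=3, as for a fused query_key_value projection, each partition owns its own (A_i, B_i) pair and the three low-rank updates are concatenated along the last dimension so the result lines up with the fused layer's output; head_first additionally reshapes per attention head before concatenating.

# Illustrative shape check only.
import torch

hidden, r, partition = 8, 2, 3
x = torch.randn(4, hidden)                               # (batch, hidden)
A = [torch.randn(r, hidden) for _ in range(partition)]   # one A_i per partition
B = [torch.zeros(hidden, r) for _ in range(partition)]   # out_dim // partition = hidden here

lora_outputs = [(x @ A_i.T @ B_i.T) for A_i, B_i in zip(A, B)]   # each (batch, hidden)
update = torch.cat(lora_outputs, -1)                             # (batch, 3 * hidden)
print(update.shape)                                              # torch.Size([4, 24])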

-def replace_linear_with_lora(lin, base_cls, r, *args, **kw_args):
+def replace_linear_with_lora(lin, partition, r, *args, **kw_args):
    # not supported for linear without bias for now
    out_dim, in_dim = lin.weight.shape
-    return base_cls(in_dim, out_dim, r, *args, **kw_args)
+    original_cls = type(lin)
+    del lin
+    return LoraLinear(original_cls, partition, in_dim, out_dim, r, *args, **kw_args)

def merge_linear_lora(lin):
-    out_dim, in_dim = lin.original.weight.shape
-    new_lin = nn.Linear(in_dim, out_dim)
-    new_lin.bias.data = lin.original.bias.data
-    new_lin.weight.data = lin.original.weight.data + (lin.matrix_A.data.T.float() @ lin.matrix_B.data.T.float() * lin.scaling).T.to(lin.original.weight.data.dtype)
-    return new_lin
-
-def merge_qkv_lora(lin):
-    out_dim, in_dim = lin.original.weight.shape
-    new_lin = nn.Linear(in_dim, out_dim)
+    if lin.original.weight.data.dtype is not torch.uint8:
+        weight = lin.original.weight
+        out_dim, in_dim = weight.shape
+        new_lin = nn.Linear(in_dim, out_dim)
+    else:
+        import bitsandbytes.functional as F
+        weight = F.dequantize_fp4(lin.original.weight.data, lin.original.weight.quant_state).to(lin.original.bias.data.dtype)
+        out_dim, in_dim = weight.shape
+        new_lin = HackLinearNF4(in_dim, out_dim)
    new_lin.bias.data = lin.original.bias.data
    new_qkv = []
-    for i in range(3):
+    for i in range(lin.partition):
        new_qkv.append(lin.matrix_A[i].data.T.float() @ lin.matrix_B[i].data.T.float() * lin.scaling)
    if lin.head_first:
        ini_shape = new_qkv[0].shape
        new_qkv = [x.view(ini_shape[0], lin.num_attention_heads, -1) for x in new_qkv]
-        new_qkv = torch.cat(new_qkv, -1).view(ini_shape[0], 3*ini_shape[1])
+        new_qkv = torch.cat(new_qkv, -1).view(ini_shape[0], lin.partition*ini_shape[1])
    else:
        new_qkv = torch.cat(new_qkv, -1)
-    new_lin.weight.data = lin.original.weight.data + new_qkv.T.to(lin.original.weight.data.dtype)
-    return new_lin
+    new_lin.weight.data = weight + new_qkv.T.to(lin.original.bias.data.dtype)
+    return new_lin.cuda() if torch.cuda.is_available() else new_lin
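The merge above folds each low-rank update into the dense weight. A small sanity-check sketch of the underlying identity (illustrative: single partition, non-quantized): W_merged = W + scaling * (B @ A) reproduces the un-merged LoRA output exactly, which is the property merge_linear_lora relies on per partition before concatenating.

# Illustrative check only.
import torch

in_dim, out_dim, r, scaling = 8, 8, 2, 0.5
W = torch.randn(out_dim, in_dim)
b = torch.randn(out_dim)
A = torch.randn(r, in_dim)
B = torch.randn(out_dim, r)
x = torch.randn(4, in_dim)

lora_out   = x @ W.T + b + (x @ A.T @ B.T) * scaling        # un-merged LoRA forward
merged_W   = W + (A.T.float() @ B.T.float() * scaling).T     # same transform as in merge_linear_lora
merged_out = x @ merged_W.T + b                              # plain Linear with the merged weight
print(torch.allclose(lora_out, merged_out, atol=1e-5))       # True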

class LoraMixin(BaseMixin):
    def __init__(self,
@@ -175,7 +170,8 @@ class LoraMixin(BaseMixin):
                head_first = False,
                num_attention_heads = None,
                hidden_size_per_attention_head = None,
-                qlora = False):
+                qlora = False,
+                cross_attention = True):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
@@ -190,16 +186,18 @@ class LoraMixin(BaseMixin):
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.qlora = qlora
+        self.cross_attention = cross_attention

    def reinit(self, parent_model):
-        """
-        only support self-attention part
-        not supported for cross-attention for now
-        """
        for i in self.layer_range:
-            print(f'replacing layer {i} with lora')
-            parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].attention.dense, LoraLinear, self.r, self.lora_alpha, self.lora_dropout, self.qlora)
-            parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(parent_model.transformer.layers[i].attention.query_key_value, LoraQKV, self.r, self.lora_alpha, self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads, hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora)
+            print(f'replacing layer {i} attention with lora')
+            parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
+            parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(parent_model.transformer.layers[i].attention.query_key_value, 3, self.r, self.lora_alpha, self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads, hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora)
+            if self.cross_attention and parent_model.transformer.layers[i].is_decoder:
+                print(f'replacing layer {i} cross attention with lora')
+                parent_model.transformer.layers[i].cross_attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
+                parent_model.transformer.layers[i].cross_attention.query = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.query, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
+                parent_model.transformer.layers[i].cross_attention.key_value = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.key_value, 2, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
        if self.qlora:
            print('replacing chatglm linear layer with 4bit')
            def replace_linear_with_nf4(model, name=None, cache={}):
@@ -229,9 +227,14 @@ class LoraMixin(BaseMixin):
    def merge_lora(self):
        for i in self.layer_range:
-            print(f'merge layer {i} lora back to linear')
+            print(f'merge layer {i} lora attention back to linear')
            self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
-            self.transformer.layers[i].attention.query_key_value = merge_qkv_lora(self.transformer.layers[i].attention.query_key_value)
+            self.transformer.layers[i].attention.query_key_value = merge_linear_lora(self.transformer.layers[i].attention.query_key_value)
+            if self.transformer.layers[i].is_decoder:
+                print(f'merge layer {i} lora cross attention back to linear')
+                self.transformer.layers[i].cross_attention.dense = merge_linear_lora(self.transformer.layers[i].cross_attention.dense)
+                self.transformer.layers[i].cross_attention.query = merge_linear_lora(self.transformer.layers[i].cross_attention.query)
+                self.transformer.layers[i].cross_attention.key_value = merge_linear_lora(self.transformer.layers[i].cross_attention.key_value)
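For orientation, a hypothetical end-to-end sketch of how this mixin is meant to be used. Everything not shown in this diff is an assumption, not the library's confirmed API: the import path, the helper functions, the truncated leading constructor arguments, and how SAT wires mixins.

# Hypothetical usage sketch.
from sat.model.finetune.lora2 import LoraMixin         # assumed module path for this file

model, args = build_model_and_args()                   # hypothetical setup helper
model.add_mixin('lora', LoraMixin(args.num_layers,     # leading positional argument guessed; not shown in this hunk
                                  r=8, lora_alpha=16,
                                  qlora=False, cross_attention=True))
# LoraMixin.reinit(model) performs the replace_linear_with_lora swaps shown above;
# when exactly SAT invokes reinit is outside this diff.
finetune(model)                                        # hypothetical finetuning loop
model.mixins['lora'].merge_lora()                      # assumes mixins are kept in a dict; folds LoRA back into plain linears
save_checkpoint(model)                                 # hypothetical export of a merged, LoRA-free checkpoint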

if __name__ == '__main__':
    class Model(nn.Module):
...