Unverified commit 89f3d322 authored by Li Zhang, committed by GitHub

Support TP for w4a16 (#262)

parent 4a60b45d
```diff
@@ -157,18 +157,15 @@ def export(model_name: str,
         if key == 'w_qkv' and ext == 'bias':
             attn_bias = True
         copy = False
-        if key in ['w1', 'w3', 'w13']:
+        if key in ['w1', 'w3', 'w13', 'w_qkv']:
             split_dim = -1
             # TODO: move parameter extraction outside of the loop
             if key == 'w1':
                 inter_size = max(inter_size, param_data.shape[-1])
             elif key == 'w13':
                 inter_size = max(inter_size, param_data.shape[-1] // 2)
-        elif key == 'w_qkv':
-            split_dim = -2
         elif key in ['w2', 'wo']:
-            if ext in ['scales', 'zeros', 'bias']:
+            if ext in ['bias']:
                 copy = True
             else:
                 split_dim = 0
```
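For reference, `export()` either shards a tensor across tensor-parallel ranks along `split_dim` or replicates it when `copy` is set. The sketch below is a hedged illustration of that rule with a hypothetical helper (`split_for_tp` and the shapes are not part of this commit): `w1`/`w3`/`w13` and, after this change, the fused `w_qkv` are chunked along the last dimension, while `w2`/`wo` tensors are chunked along dim 0 except their bias, which is copied to every rank (their s4 scales/zeros are now split as well rather than copied).

```python
import torch

def split_for_tp(param: torch.Tensor, tp: int, split_dim: int, copy: bool):
    """Hypothetical sketch of export()'s sharding rule, not the repo code."""
    if copy:
        # e.g. the w2/wo bias: every rank keeps the full tensor
        return [param.clone() for _ in range(tp)]
    # e.g. w1/w3/w13/w_qkv use split_dim = -1 (column parallel),
    # w2/wo weights use split_dim = 0 (row parallel)
    return list(torch.chunk(param, tp, dim=split_dim))

shards = split_for_tp(torch.randn(4096, 11008), tp=2, split_dim=-1, copy=False)
print([tuple(s.shape) for s in shards])  # [(4096, 5504), (4096, 5504)]
```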
```diff
@@ -243,7 +240,10 @@ def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
     def reshape(x):
         return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)
 
-    return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
+    qkv = torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
+
+    # (input_dim, head_num + 2 * kv_head_num)
+    return qkv.view(q.size(0), -1)
 
 
 def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
```
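Flattening the fused tensor back to 2-D is what lets `export()` treat `w_qkv` like the other column-parallel weights (`split_dim = -1`): after the per-rank reshape and concat, each rank's q/k/v columns are contiguous. A minimal shape check with toy sizes (not the real head dimensions):

```python
import torch

tp = 2
q = torch.arange(8.).view(2, 4)   # (input_dim=2, 4 q columns)
k = torch.arange(4.).view(2, 2)   # (input_dim=2, 2 k columns)
v = torch.arange(4.).view(2, 2)   # (input_dim=2, 2 v columns)

def reshape(x):
    return x.view(x.size(0), tp, -1)

qkv = torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1).view(q.size(0), -1)
print(qkv.shape)  # torch.Size([2, 8])
# columns 0-3 are rank 0's q|k|v slice, columns 4-7 are rank 1's,
# so torch.chunk(qkv, tp, dim=-1) hands each rank its fused shard
```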
```diff
@@ -594,16 +594,16 @@ def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
     sys.path.append(osp.join(lmdeploy_dir, 'lib'))
     import _turbomind as _tm  # noqa: E402
 
-    def transpose_qk(src: torch.Tensor):
+    def transpose_qk_s4(src: torch.Tensor):
         assert src.is_contiguous()
         dst = torch.zeros_like(src)
         _tm.transpose_qk_s4_k_m8(src, dst,
                                  src.size(-1) * 8, src.size(0), group_size)
         return dst
 
-    def fuse_w1_w3(w1_qw: torch.Tensor, w1_qz: torch.Tensor,
-                   w1_s: torch.Tensor, w3_qw: torch.Tensor,
-                   w3_qz: torch.Tensor, w3_s: torch.Tensor):
+    def fuse_w1_w3_s4(w1_qw: torch.Tensor, w1_qz: torch.Tensor,
+                      w1_s: torch.Tensor, w3_qw: torch.Tensor,
+                      w3_qz: torch.Tensor, w3_s: torch.Tensor):
 
         def fuse(a: torch.Tensor, b: torch.Tensor):
             ab = torch.cat((a, b)).contiguous()
```
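The renames only make explicit that these helpers operate on the packed s4 (4-bit) layout. The idea behind fusing `w1` (gate) and `w3` (up) is the usual concatenation so a single GEMM feeds both halves of the SwiGLU activation; a rough fp16 analogue (not the s4 packing performed by `fuse_w1_w3_s4`, shapes are illustrative only):

```python
import torch

# fp16 analogue only: concatenate gate (w1) and up (w3) along the output dim
w1 = torch.randn(4096, 11008)
w3 = torch.randn(4096, 11008)
w13 = torch.cat((w1, w3), dim=-1)            # (4096, 22016)

x = torch.randn(1, 4096)
gate, up = x.matmul(w13).chunk(2, dim=-1)    # equals x @ w1 and x @ w3
```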
```diff
@@ -625,12 +625,16 @@ def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
         assert qz.is_contiguous()
         assert s.is_contiguous()
         _qw = torch.zeros_like(qw)
-        _sz = torch.zeros_like(s, dtype=torch.int32)
+        _sz = torch.zeros_like(s, dtype=torch.int32)  # half2
         _ws = torch.zeros_like(s)
         _tm.convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
                             qw.size(-1) * 8, qw.size(0), group_size)
         return _qw, _sz
 
+    def tp_m_s4(x: torch.Tensor, tp: int):
+        return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
+                                                            1).contiguous()
+
     attn_bias = False
     for i in range(num_layer):
```
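`tp_m_s4` is the new piece that makes the converted s4 qweights shardable: it regroups the packed tensor into `(m // 32, tp, -1, 128)` tiles and moves the `tp` axis to the innermost position. If I read it correctly (this interpretation is an assumption), a plain chunk along the last dimension, which is exactly what `export()` now does for `w_qkv` and `w13`, then hands each rank its tile of the layout expected by the s4 GEMM kernel. A toy shape check:

```python
import torch

def tp_m_s4(x: torch.Tensor, tp: int):
    return x.view(x.size(0) // 32, tp, -1,
                  128).permute(0, 2, 3, 1).contiguous()

# Shape check only; the values are not a real packed s4 weight.
x = torch.arange(64 * 256, dtype=torch.int32).view(64, 256)
y = tp_m_s4(x, tp=2)
print(y.shape)                        # torch.Size([2, 32, 128, 2])
r0, r1 = torch.chunk(y, 2, dim=-1)    # what split_dim = -1 yields per rank
print(r0.shape)                       # torch.Size([2, 32, 128, 1])
```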
```diff
@@ -661,10 +665,10 @@ def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
         except:  # noqa: E722
             pass
 
-        q_qw = transpose_qk(q_qw)
-        k_qw = transpose_qk(k_qw)
-        q_qz = transpose_qk(q_qz)
-        k_qz = transpose_qk(k_qz)
+        q_qw = transpose_qk_s4(q_qw)
+        k_qw = transpose_qk_s4(k_qw)
+        q_qz = transpose_qk_s4(q_qz)
+        k_qz = transpose_qk_s4(k_qz)
         q_s = permute(q_s)
         k_s = permute(k_s)
@@ -674,6 +678,8 @@ def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
         qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
 
+        qkv_qw = tp_m_s4(qkv_qw, tp)
+
         model_params[f'layers.{i}.attention.w_qkv.qweight'] = qkv_qw
         model_params[f'layers.{i}.attention.w_qkv.scales_zeros'] = qkv_sz
```
```diff
@@ -702,12 +708,14 @@ def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
         w2_s = get_tensor(f'model.layers.{i}.mlp.down_proj.scales')
         w3_s = get_tensor(f'model.layers.{i}.mlp.up_proj.scales')
 
-        w13_qw, w13_qz, w13_s = fuse_w1_w3(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
-                                           w3_s)
+        w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
+                                              w3_s)
 
         w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
         w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
 
+        w13_qw = tp_m_s4(w13_qw, tp)
+
         model_params[f'layers.{i}.feed_forward.w13.qweight'] = w13_qw
         model_params[f'layers.{i}.feed_forward.w13.scales_zeros'] = w13_sz
```