Unverified Commit ac638b37 authored by tpoisonooo, Committed by GitHub

feat(deploy.py): support w pack qkv (#83)

* feat(deploy.py): support w pack qkv
parent e7d5e062
@@ -125,7 +125,7 @@ def export(model_name: str,
     # export config and save it to {out_dir}/config.ini
     vocab_size, bos_id, eos_id = tokenizer_info(tokenizer_path)
-    assert _vocab_size == vocab_size, \
+    assert _vocab_size >= vocab_size, \
         f'different vocab size {_vocab_size} vs {vocab_size}'
     cfg = dict(llama=dict(
         model_name=model_name,
@@ -323,14 +323,32 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
         if name not in _params and name.find('bias'):
             return None
         return _params[name].t()
+    w_pack = False
+    if 'model.layers.0.self_attn.W_pack.weight' in _params:
+        w_pack = True
     for i in range(1000):
         try:
             # attention weights
-            _qkvo = [f'model.layers.{i}.self_attn.{t}_proj' for t in 'qkvo']
             for suffix in _suffixes:
-                q, k, v, o = map(get_tensor_transposed,
-                                 map(('{}.' + suffix).format, _qkvo))
+                if w_pack:
+                    _qkvo = [f'model.layers.{i}.self_attn.{t}' for t in ['W_pack', 'o_proj']]
+                    qkv, o = map(get_tensor_transposed,
+                                 map(('{}.' + suffix).format, _qkvo))
+                    if qkv is None:
+                        continue
+                    _shape = qkv.shape[1] // 3
+                    _qkv = torch.split(qkv, [_shape, _shape, _shape], dim=1)
+                    q = _qkv[0]
+                    k = _qkv[1]
+                    v = _qkv[2]
+                else:
+                    _qkvo = [f'model.layers.{i}.self_attn.{t}_proj' for t in 'qkvo']
+                    q, k, v, o = map(get_tensor_transposed,
+                                     map(('{}.' + suffix).format, _qkvo))
                 if q is None:
                     continue
                 # q, k has different layout for fb & hf, convert to fb's
...
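For context: checkpoints that use a packed W_pack projection (the layout used by, e.g., Baichuan-style HF models) store the q, k and v weights concatenated along the output dimension of a single tensor, so the branch above recovers them by splitting the transposed weight into three equal chunks. A minimal, self-contained sketch of that split; the hidden size below is an illustrative assumption, not a value from the repo:

    import torch

    # Illustrative size only, not taken from the repo.
    hidden = 4096
    # A transposed packed projection, as get_tensor_transposed would return it:
    # q, k and v concatenated along dim 1 -> shape [hidden, 3 * hidden].
    qkv = torch.randn(hidden, 3 * hidden)

    # Split into three equal chunks to recover the individual projections,
    # mirroring the torch.split call added in the diff above.
    _shape = qkv.shape[1] // 3
    q, k, v = torch.split(qkv, [_shape, _shape, _shape], dim=1)
    assert q.shape == k.shape == v.shape == (hidden, hidden)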