Commit 7cf0d987 authored by Casper Hansen

Get correct devices

parent ab7d68e7
@@ -102,7 +102,7 @@ class LlamaFuser:
             module.num_key_value_heads,
             qkv_layer,
             module.o_proj,
-            qkv_layer.qweight.device,
+            next(iter(qkv_layer.state_dict().values())).device,
             self.model.config.max_new_tokens
         )
         set_module_name(self.model, name, attn)
@@ -111,7 +111,7 @@ class LlamaFuser:
         # get qkv and bias
         q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

         # create module
         qkv_layer = WQLinear(
             q_proj.w_bit,
@@ -119,7 +119,7 @@ class LlamaFuser:
             q_proj.in_features,
             q_proj.out_features + k_proj.out_features + v_proj.out_features,
             q_proj.bias is not None,
-            q_proj.qweight.device
+            next(iter(module.state_dict().values())).device
         )

         # replace buffers with real weights
...
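The commit replaces hard-coded attribute lookups (qkv_layer.qweight.device and q_proj.qweight.device) with the device of whichever tensor the module's state_dict() yields first, so the fuser no longer assumes a particular tensor name is present. A minimal sketch of that pattern, using a plain torch.nn.Linear as a stand-in for WQLinear (the stand-in is an assumption for illustration only):

import torch
import torch.nn as nn

# Hypothetical stand-in for a quantized WQLinear layer; the real layer stores
# tensors such as qweight, qzeros and scales as buffers instead of .weight.
layer = nn.Linear(16, 16)
if torch.cuda.is_available():
    layer = layer.cuda()

# Read the device from whichever tensor the state dict yields first,
# rather than assuming a specific attribute like .qweight exists.
device = next(iter(layer.state_dict().values())).device
print(device)  # cuda:0 when a GPU is available, otherwise cpu

The same one-liner works for any nn.Module whose parameters and buffers live on a single device; for a module spread across devices it only reports the placement of the first tensor.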