Commit 7cf0d987 authored by Casper Hansen
Browse files

Get correct devices

parent ab7d68e7
......@@ -102,7 +102,7 @@ class LlamaFuser:
module.num_key_value_heads,
qkv_layer,
module.o_proj,
- qkv_layer.qweight.device,
+ next(iter(qkv_layer.state_dict().values())).device,
self.model.config.max_new_tokens
)
set_module_name(self.model, name, attn)
......@@ -111,7 +111,7 @@ class LlamaFuser:
# get qkv and bias
q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
# create module
qkv_layer = WQLinear(
q_proj.w_bit,
......@@ -119,7 +119,7 @@ class LlamaFuser:
q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.bias is not None,
- q_proj.qweight.device
+ next(iter(module.state_dict().values())).device
)
# replace buffers with real weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment