Commit 7cf0d987 authored by Casper Hansen
Browse files

Get correct devices

parent ab7d68e7
......@@ -102,7 +102,7 @@ class LlamaFuser:
module.num_key_value_heads,
qkv_layer,
module.o_proj,
- qkv_layer.qweight.device,
+ next(iter(qkv_layer.state_dict().values())).device,
self.model.config.max_new_tokens
)
set_module_name(self.model, name, attn)
......@@ -119,7 +119,7 @@ class LlamaFuser:
q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.bias is not None,
- q_proj.qweight.device
+ next(iter(module.state_dict().values())).device
)
# replace buffers with real weights
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment