Commit 7cf0d987 authored by Casper Hansen's avatar Casper Hansen
Browse files

Get correct devices

parent ab7d68e7
...@@ -102,7 +102,7 @@ class LlamaFuser: ...@@ -102,7 +102,7 @@ class LlamaFuser:
module.num_key_value_heads, module.num_key_value_heads,
qkv_layer, qkv_layer,
module.o_proj, module.o_proj,
qkv_layer.qweight.device, next(iter(qkv_layer.state_dict().values())).device,
self.model.config.max_new_tokens self.model.config.max_new_tokens
) )
set_module_name(self.model, name, attn) set_module_name(self.model, name, attn)
...@@ -119,7 +119,7 @@ class LlamaFuser: ...@@ -119,7 +119,7 @@ class LlamaFuser:
q_proj.in_features, q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features, q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.bias is not None, q_proj.bias is not None,
q_proj.qweight.device next(iter(module.state_dict().values())).device
) )
# replace buffers with real weights # replace buffers with real weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment