Commit e80663bb authored by Casper Hansen

Initialize with device

parent ac3e86df
@@ -85,7 +85,8 @@ class MptFuser:
                     self.model.config.n_heads,
                     module.attn.Wqkv,
                     module.attn.out_proj,
-                    module.ffn
+                    module.ffn,
+                    next(iter(module.state_dict().values())).device
                 )
                 set_module_name(self.model, name, block)
\ No newline at end of file
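The new constructor argument reads the device off the original block's weights instead of assuming `cuda:0`. A minimal sketch of that idiom, using a plain `nn.Linear` as a hypothetical stand-in for the MPT block being fused:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for an MPT decoder block whose weights already sit on
# some device; the real code passes the actual transformer block here.
module = nn.Linear(16, 16).to("cuda:0" if torch.cuda.is_available() else "cpu")

# The idiom from the hunk above: take the device of the first tensor in the
# module's state_dict rather than hard-coding "cuda:0".
device = next(iter(module.state_dict().values())).device
print(device)  # cuda:0 on a GPU machine, cpu otherwise
```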
@@ -2,14 +2,14 @@ import torch.nn as nn
 from awq.modules.fused.attn import QuantAttentionFused
 
 class MptBlock(nn.Module):
-    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp):
+    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, dev):
         super().__init__()
         self.n_heads = n_heads
         self.hidden_size = hidden_size
-        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev="cuda:0", max_seq_len=8096, use_alibi=True).to("cuda:0")
-        self.ffn = mpt_mlp.to("cuda:0")
-        self.norm_1 = nn.LayerNorm(hidden_size, eps=1e-6).half().to("cuda:0")
-        self.norm_2 = nn.LayerNorm(hidden_size, eps=1e-6).half().to("cuda:0")
+        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev="cuda:0", max_seq_len=8096, use_alibi=True)
+        self.ffn = mpt_mlp
+        self.norm_1 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)
+        self.norm_2 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)
 
     def forward(
         self, hidden_states, past_key_value, attn_bias, attention_mask, is_causal
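On the block side, modules handed in from the quantized model (`qkv_layer`, `o_proj`, `mpt_mlp`) keep their existing placement, and only the layers the block creates itself are moved to `dev`. A rough sketch of that pattern (`BlockSketch` is illustrative only, not the real `MptBlock`):

```python
import torch
import torch.nn as nn

class BlockSketch(nn.Module):
    """Illustrative only: mirrors how the fused block now handles placement."""
    def __init__(self, hidden_size, ffn, dev):
        super().__init__()
        # Passed-in module: already on the right device, no .to("cuda:0") needed.
        self.ffn = ffn
        # Modules created here follow the device inferred by the fuser.
        self.norm_1 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)
        self.norm_2 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)

dev = "cuda:0" if torch.cuda.is_available() else "cpu"
ffn = nn.Linear(32, 32).half().to(dev)
block = BlockSketch(32, ffn, dev)
print(block.norm_1.weight.device, block.ffn.weight.device)  # both on dev
```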