Commit 0aa4a596 authored by Casper Hansen's avatar Casper Hansen
Browse files

Fix model references

parent 1df0136e
...@@ -6,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM): ...@@ -6,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_position_embeddings" max_new_tokens_key = "max_position_embeddings"
@staticmethod @staticmethod
def fuse_layers(awq_model: BaseAWQForCausalLM): def fuse_layers(model: LlamaForCausalLM):
fuser = LlamaFuser(awq_model) fuser = LlamaFuser(model)
fuser.fuse_attention() fuser.fuse_attention()
fuser.fuse_rmsnorm() fuser.fuse_rmsnorm()
fuser.fuse_mlp() fuser.fuse_mlp()
...@@ -75,9 +75,8 @@ from awq.modules.fused_attn import QuantLlamaAttention ...@@ -75,9 +75,8 @@ from awq.modules.fused_attn import QuantLlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
class LlamaFuser: class LlamaFuser:
def __init__(self, awq_model: BaseAWQForCausalLM): def __init__(self, model):
self.awq_model = awq_model self.model = model
self.model = awq_model.model
self.attention_modules: List[Tuple[str, LlamaAttention]] = [ self.attention_modules: List[Tuple[str, LlamaAttention]] = [
(name, module) for name, module in self.model.named_modules() (name, module) for name, module in self.model.named_modules()
...@@ -103,7 +102,7 @@ class LlamaFuser: ...@@ -103,7 +102,7 @@ class LlamaFuser:
qkv_layer, qkv_layer,
module.o_proj, module.o_proj,
qkv_layer.qweight.device, qkv_layer.qweight.device,
self.awq_model.model.config.max_new_tokens self.model.config.max_new_tokens
) )
set_module_name(self.model, name, attn) set_module_name(self.model, name, attn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment