Commit 1aa8aebd authored by Casper Hansen

Remove xformers

parent dd41a223
@@ -7,8 +7,6 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: FalconForCausalLM, quant_config:dict):
         fuser = FalconFuser(model)
-        # fuser.fuse_attention()
-        # fuser.fuse_layernorm()
 
     @staticmethod
     def get_model_layers(model: FalconForCausalLM):
@@ -65,11 +63,9 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
         return layers
 
 import torch
-import xformers
 from torch.nn import LayerNorm
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
-from xformers.triton.layer_norm import FusedLayerNorm
 from awq.modules.fused.attn import QuantAttentionFused
 
 class FalconFuser:
@@ -97,15 +93,3 @@ class FalconFuser:
                 self.model.config.max_new_tokens
             )
             set_module_name(self.model, name, attn)
\ No newline at end of file
-
-    def fuse_layernorm(self):
-        xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True
-        for name, module in self.layernorm_modules:
-            norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)
-            # copy weights and bias
-            with torch.no_grad():
-                norm.weight = module.weight
-                norm.bias = module.bias
-            set_module_name(self.model, name, norm)
\ No newline at end of file
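
The deleted fuse_layernorm walked the model's LayerNorm modules and swapped each one for an xformers Triton FusedLayerNorm carrying the same weights. Below is a minimal sketch of that module-swap pattern using only plain PyTorch, so it runs without xformers; the set_module_name helper is a self-contained stand-in for awq.utils.utils.set_module_name, and nn.LayerNorm stands in for the fused kernel.

import torch
import torch.nn as nn

def set_module_name(model: nn.Module, name: str, new_module: nn.Module):
    # Resolve the parent module from the dotted path and rebind the child,
    # e.g. "transformer.h.0.ln_attn" -> parent "transformer.h.0", child "ln_attn".
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)

def fuse_layernorms(model: nn.Module):
    # Snapshot the targets first so the module tree is not mutated mid-iteration.
    targets = [(name, m) for name, m in model.named_modules()
               if isinstance(m, nn.LayerNorm)]
    for name, module in targets:
        # The removed code constructed xformers' FusedLayerNorm here;
        # plain nn.LayerNorm keeps the sketch dependency-free.
        norm = nn.LayerNorm(module.normalized_shape, eps=module.eps)
        norm = norm.to(module.weight.device)
        # Copy weights and bias, mirroring the deleted implementation.
        with torch.no_grad():
            norm.weight.copy_(module.weight)
            norm.bias.copy_(module.bias)
        set_module_name(model, name, norm)

The original also flipped xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled to force the fp16 Triton path; that toggle has no plain-PyTorch equivalent and is omitted here.
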
@@ -45,8 +45,7 @@ requirements = [
     "attributedict",
     "protobuf",
     "torchvision",
-    "tabulate",
-    "xformers"
+    "tabulate"
 ]
 
 def get_include_dirs():
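
Since "xformers" was the final entry in the requirements list, dropping it also removes the trailing comma after "tabulate". After this change the package should import on a machine without xformers installed; a quick, hypothetical sanity check (assuming this repo's awq package is installed in the current environment):

import importlib.util

# Confirm xformers is absent from the environment, then import the package.
assert importlib.util.find_spec("xformers") is None, "xformers still installed"
import awq  # must no longer raise ImportError without xformers present
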