Commit 1aa8aebd authored by Casper Hansen
Browse files

Remove xformers

parent dd41a223
...@@ -7,8 +7,6 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM): ...@@ -7,8 +7,6 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod @staticmethod
def fuse_layers(model: FalconForCausalLM, quant_config:dict): def fuse_layers(model: FalconForCausalLM, quant_config:dict):
fuser = FalconFuser(model) fuser = FalconFuser(model)
# fuser.fuse_attention()
# fuser.fuse_layernorm()
@staticmethod @staticmethod
def get_model_layers(model: FalconForCausalLM): def get_model_layers(model: FalconForCausalLM):
...@@ -65,11 +63,9 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM): ...@@ -65,11 +63,9 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
return layers return layers
import torch import torch
import xformers
from torch.nn import LayerNorm from torch.nn import LayerNorm
from typing import List, Tuple from typing import List, Tuple
from awq.utils.utils import set_module_name from awq.utils.utils import set_module_name
from xformers.triton.layer_norm import FusedLayerNorm
from awq.modules.fused.attn import QuantAttentionFused from awq.modules.fused.attn import QuantAttentionFused
class FalconFuser: class FalconFuser:
...@@ -96,16 +92,4 @@ class FalconFuser: ...@@ -96,16 +92,4 @@ class FalconFuser:
next(iter(qkv_layer.state_dict().values())).device, next(iter(qkv_layer.state_dict().values())).device,
self.model.config.max_new_tokens self.model.config.max_new_tokens
) )
set_module_name(self.model, name, attn) set_module_name(self.model, name, attn)
\ No newline at end of file
def fuse_layernorm(self):
xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True
for name, module in self.layernorm_modules:
norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)
# copy weights and bias
with torch.no_grad():
norm.weight = module.weight
norm.bias = module.bias
set_module_name(self.model, name, norm)
\ No newline at end of file
...@@ -45,8 +45,7 @@ requirements = [ ...@@ -45,8 +45,7 @@ requirements = [
"attributedict", "attributedict",
"protobuf", "protobuf",
"torchvision", "torchvision",
"tabulate", "tabulate"
"xformers"
] ]
def get_include_dirs(): def get_include_dirs():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment