Unverified commit 8eb26eb2, authored by Casper and committed by GitHub

Merge pull request #69 from VikParuchuri/main

Use typing classes over base types
parents 386fede8 4abfefc9
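
For context on the change: built-in generics such as list[nn.Linear] or dict[str, nn.Linear] are only subscriptable in annotations on Python 3.9+, while the typing aliases List/Dict also work on older interpreters. The sketch below is hypothetical code (not part of this diff) that illustrates the failure mode the PR presumably avoids; the function name and signature are made up for illustration.

from typing import Dict, List
import torch.nn as nn

def collect_linears_typing(named_linears: Dict[str, nn.Linear]) -> List[nn.Linear]:
    # typing.Dict / typing.List are subscriptable on Python 3.7+, so this
    # definition imports cleanly on older interpreters.
    return list(named_linears.values())

# The equivalent signature written with built-in generics only works on 3.9+:
#
#     def collect_linears_builtin(named_linears: dict[str, nn.Linear]) -> list[nn.Linear]:
#         ...
#
# On Python 3.8 the annotation is evaluated when the function is defined and
# raises: TypeError: 'type' object is not subscriptable

On Python 3.9 and newer both spellings behave the same, so the typing imports are purely a backwards-compatibility choice.
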
@@ -4,7 +4,7 @@ import json
 import torch
 import torch.nn as nn
 from tqdm import tqdm
-from typing import List, Union
+from typing import List, Union, Dict
 from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
@@ -23,7 +23,7 @@ class BaseAWQForCausalLM(nn.Module):
         self.model_type:str = model_type
         self.is_quantized:bool = is_quantized
         self.search_result = None
-        self.quant_config:dict = quant_config
+        self.quant_config: Dict = quant_config
     def to(self, device: str):
         return self.model.to(device)
......
 from .base import BaseAWQForCausalLM
+from typing import Dict
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention
 class FalconAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "FalconDecoderLayer"
     @staticmethod
-    def fuse_layers(model: FalconForCausalLM, quant_config:dict):
+    def fuse_layers(model: FalconForCausalLM, quant_config: Dict):
         fuser = FalconFuser(model)
         # TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
......
 from .base import BaseAWQForCausalLM
+from typing import Dict
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -6,7 +7,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"
     @staticmethod
-    def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
+    def fuse_layers(model: LlamaForCausalLM, quant_config: Dict):
         fuser = LlamaFuser(model, quant_config)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
......
 from .base import BaseAWQForCausalLM
+from typing import Dict
 from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
 class MptAWQForCausalLM(BaseAWQForCausalLM):
@@ -6,7 +7,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_seq_len"
     @staticmethod
-    def fuse_layers(model: MptForCausalLM, quant_config:dict):
+    def fuse_layers(model: MptForCausalLM, quant_config: Dict):
         fuser = MptFuser(model)
         fuser.fuse_transformer()
......
 import torch
 import torch.nn as nn
+from typing import List
 from awq.modules.fused.block import MPTBlock, FalconDecoderLayer
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -8,7 +9,7 @@ class MPTModel(nn.Module):
         super().__init__()
         self.vocab_size = vocab_size
         self.wte = wte
-        self.blocks: list[MPTBlock] = nn.ModuleList(blocks)
+        self.blocks: List[MPTBlock] = nn.ModuleList(blocks)
         self.norm_f = norm_f
         self.attn_uses_sequence_id = False
         self.prefix_lm = False
@@ -36,7 +37,7 @@ class FalconModel(nn.Module):
         super().__init__()
         self.vocab_size = vocab_size
         self.word_embeddings = word_embeddings
-        self.blocks: list[FalconDecoderLayer] = nn.ModuleList(blocks)
+        self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks)
         self.ln_f = ln_f
         self.attn_uses_sequence_id = False
         self.prefix_lm = False
......
@@ -3,6 +3,7 @@ import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
+from typing import Dict, List
 from collections import defaultdict
 from awq.utils.utils import clear_memory
 from awq.utils.calib_data import get_calib_dataset
@@ -62,7 +63,7 @@
             clear_memory()
             # [STEP 2]: Compute and apply scale list
-            module_config: list[dict] = self.awq_model.get_layers_for_scaling(
+            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                 self.modules[i], input_feat, self.module_kwargs
             )
             scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
@@ -78,7 +79,7 @@
             self._apply_quant(self.modules[i], named_linears)
             clear_memory()
-    def _apply_quant(self, module, named_linears: dict[str, nn.Linear]):
+    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
         for name, linear_layer in named_linears.items():
             # NOTE: small regression in perplexity if linear layer uses .cpu().float()
             linear_layer = linear_layer.cuda().half()
@@ -111,7 +112,7 @@
         clear_memory()
     @torch.no_grad()
-    def _search_best_scale(self, module, prev_op, layers: list[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
+    def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
         if module2inspect is None:
             assert len(layers) == 1
             module2inspect = layers[0]
@@ -148,7 +149,7 @@
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
-    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear],
+    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear],
                             fp16_output, kwargs={}):
         """
         Compute loss and select best scales
......
 import torch
 import torch.nn as nn
-from typing import Tuple
+from typing import Tuple, List
 from awq.modules.act import ScaledActivation
 from awq.utils.module import get_op_by_name, set_op_by_name
 from transformers.models.bloom.modeling_bloom import BloomGelu
@@ -62,7 +62,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         scales.cpu()
 @torch.no_grad()
-def scale_ln_fcs(ln: nn.Linear, fcs: list[nn.Linear], scales: torch.Tensor):
+def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
     if not isinstance(fcs, list):
         fcs = [fcs]
......