"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b7042a141c74d9f0c0658f828acffc2b9196d0a7"
Unverified commit 8eb26eb2 authored by Casper, committed by GitHub

Merge pull request #69 from VikParuchuri/main

Use typing classes over base types
parents 386fede8 4abfefc9
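For background: subscripting the built-in containers in annotations (`dict[str, ...]`, `list[...]`) is only valid at runtime on Python 3.9+ (PEP 585), so switching to the `typing` equivalents (`Dict`, `List`) keeps these signatures importable on older interpreters. A minimal sketch of the difference; the helper below is illustrative only, not taken from this repository:

```python
from typing import Dict, List


def shapes_report(shapes: Dict[str, List[int]]) -> List[str]:
    # Illustrative helper: format a name -> shape mapping for printing.
    return [f"{name}: {dims}" for name, dims in shapes.items()]


# The same signature written with built-in generics works only on Python 3.9+
# (PEP 585); on 3.7/3.8, evaluating the annotation at definition time raises
# "TypeError: 'type' object is not subscriptable":
#
# def shapes_report(shapes: dict[str, list[int]]) -> list[str]: ...

print(shapes_report({"model.embed_tokens.weight": [32000, 4096]}))
```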
@@ -4,7 +4,7 @@ import json
 import torch
 import torch.nn as nn
 from tqdm import tqdm
-from typing import List, Union
+from typing import List, Union, Dict
 from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
@@ -23,7 +23,7 @@ class BaseAWQForCausalLM(nn.Module):
         self.model_type:str = model_type
         self.is_quantized:bool = is_quantized
         self.search_result = None
-        self.quant_config:dict = quant_config
+        self.quant_config: Dict = quant_config

     def to(self, device: str):
         return self.model.to(device)
...
 from .base import BaseAWQForCausalLM
+from typing import Dict
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention

 class FalconAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "FalconDecoderLayer"

     @staticmethod
-    def fuse_layers(model: FalconForCausalLM, quant_config:dict):
+    def fuse_layers(model: FalconForCausalLM, quant_config: Dict):
         fuser = FalconFuser(model)

         # TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
...
 from .base import BaseAWQForCausalLM
+from typing import Dict
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -6,7 +7,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"

     @staticmethod
-    def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
+    def fuse_layers(model: LlamaForCausalLM, quant_config: Dict):
         fuser = LlamaFuser(model, quant_config)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
...
 from .base import BaseAWQForCausalLM
+from typing import Dict
 from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM

 class MptAWQForCausalLM(BaseAWQForCausalLM):
@@ -6,7 +7,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_seq_len"

     @staticmethod
-    def fuse_layers(model: MptForCausalLM, quant_config:dict):
+    def fuse_layers(model: MptForCausalLM, quant_config: Dict):
         fuser = MptFuser(model)
         fuser.fuse_transformer()
...
 import torch
 import torch.nn as nn
+from typing import List
 from awq.modules.fused.block import MPTBlock, FalconDecoderLayer
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -8,7 +9,7 @@ class MPTModel(nn.Module):
         super().__init__()
         self.vocab_size = vocab_size
         self.wte = wte
-        self.blocks: list[MPTBlock] = nn.ModuleList(blocks)
+        self.blocks: List[MPTBlock] = nn.ModuleList(blocks)
         self.norm_f = norm_f
         self.attn_uses_sequence_id = False
         self.prefix_lm = False
@@ -36,7 +37,7 @@ class FalconModel(nn.Module):
         super().__init__()
         self.vocab_size = vocab_size
         self.word_embeddings = word_embeddings
-        self.blocks: list[FalconDecoderLayer] = nn.ModuleList(blocks)
+        self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks)
         self.ln_f = ln_f
         self.attn_uses_sequence_id = False
         self.prefix_lm = False
...
@@ -3,6 +3,7 @@ import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
+from typing import Dict, List
 from collections import defaultdict
 from awq.utils.utils import clear_memory
 from awq.utils.calib_data import get_calib_dataset
@@ -62,7 +63,7 @@ class AwqQuantizer:
             clear_memory()

             # [STEP 2]: Compute and apply scale list
-            module_config: list[dict] = self.awq_model.get_layers_for_scaling(
+            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                 self.modules[i], input_feat, self.module_kwargs
             )
             scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
@@ -78,7 +79,7 @@ class AwqQuantizer:
             self._apply_quant(self.modules[i], named_linears)
             clear_memory()

-    def _apply_quant(self, module, named_linears: dict[str, nn.Linear]):
+    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
         for name, linear_layer in named_linears.items():
             # NOTE: small regression in perplexity if linear layer uses .cpu().float()
             linear_layer = linear_layer.cuda().half()
@@ -111,7 +112,7 @@ class AwqQuantizer:
             clear_memory()

     @torch.no_grad()
-    def _search_best_scale(self, module, prev_op, layers: list[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
+    def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
         if module2inspect is None:
             assert len(layers) == 1
             module2inspect = layers[0]
@@ -148,7 +149,7 @@ class AwqQuantizer:
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)

-    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear],
+    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear],
                             fp16_output, kwargs={}):
         """
         Compute loss and select best scales
...
 import torch
 import torch.nn as nn
-from typing import Tuple
+from typing import Tuple, List
 from awq.modules.act import ScaledActivation
 from awq.utils.module import get_op_by_name, set_op_by_name
 from transformers.models.bloom.modeling_bloom import BloomGelu
@@ -62,7 +62,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         scales.cpu()

 @torch.no_grad()
-def scale_ln_fcs(ln: nn.Linear, fcs: list[nn.Linear], scales: torch.Tensor):
+def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
     if not isinstance(fcs, list):
         fcs = [fcs]
...