"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "de52437cdb56243fdf44465d877a9dd2c548eb1a"
Unverified Commit ae24f424 authored by Younes Belkada, committed by GitHub

[`core`] Make AutoAWQ fused modules compatible with HF transformers (#244)

parent 09db054f
@@ -6,12 +6,22 @@ from torch.nn import functional as F
 from awq.modules.fused.cache import WindowedCache
 from awq.utils.fused_utils import get_attention_shapes

 try:
     import ft_inference_engine
     FT_INSTALLED = True
 except:
     FT_INSTALLED = False

+HF_NEW_CACHE_FORMAT = False
+
+import transformers
+
+# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
+HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
+if HF_NEW_CACHE_FORMAT:
+    from transformers.cache_utils import DynamicCache
+
 class RoPE(nn.Module):
     def __init__(self, hidden_size, n_heads, max_seq_len, device):
         super(RoPE, self).__init__()
@@ -223,4 +233,10 @@ class QuantAttentionFused(nn.Module):
             # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
             # about past key length
             past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
+
+            if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
+                new_cache = DynamicCache()
+                new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
+                past_key_value = new_cache
+
         return attn_output, attention_weight, past_key_value
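
For context, here is a minimal standalone sketch of the compatibility pattern the diff introduces: detect the new transformers cache format and wrap the dummy KV tensor in a `DynamicCache` so newer `generate()` code paths receive the cache type they expect. The `dummy_past_key_value` helper name and the default for `is_hf_transformers` are assumptions for illustration, not part of the commit.

```python
# Sketch only: mirrors the pattern added in this commit, outside the
# QuantAttentionFused class. Helper name is hypothetical.
import torch
import transformers

# transformers PR #26681 added transformers.cache_utils; its presence signals
# that the new Cache objects (e.g. DynamicCache) are available.
HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
if HF_NEW_CACHE_FORMAT:
    from transformers.cache_utils import DynamicCache


def dummy_past_key_value(start_pos: int, is_hf_transformers: bool = True):
    """Build a placeholder KV cache whose sequence dimension encodes start_pos,
    so transformers can infer the past key length without real keys/values."""
    past_key_value = [torch.zeros(1, 1, start_pos, 1)]
    if HF_NEW_CACHE_FORMAT and is_hf_transformers:
        # Wrap the dummy tensor in a DynamicCache for newer transformers versions.
        new_cache = DynamicCache()
        new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
        past_key_value = new_cache
    return past_key_value


if __name__ == "__main__":
    cache = dummy_past_key_value(start_pos=8)
    if HF_NEW_CACHE_FORMAT:
        # The cache reports a past length of 8, matching start_pos.
        print(type(cache).__name__, cache.get_seq_length())
    else:
        print(type(cache).__name__, cache[0].shape)
```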