Unverified Commit 74d0fe44 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`core`] Add `is_hf_transformers` flag (#195)

parent 3b362c0d
@@ -100,6 +100,7 @@ class QuantAttentionFused(nn.Module):
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
self.max_seq_len = max_seq_len
self.is_hf_transformers = False
# attention shapes for self attention
self.attention_shapes = get_attention_shapes(
@@ -138,7 +139,8 @@ class QuantAttentionFused(nn.Module):
# In case we re-generate, we need to refresh the starting position
# to 0. We detect it by checking if `past_key_values` is set to None,
# which indicates that we are on the first step of `generate()`.
# This is only applicable for `transformers` integration
if self.is_hf_transformers and "past_key_value" in kwargs and kwargs["past_key_value"] is None:
    self.start_pos = 0
xqkv = self.qkv_proj(hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment