Unverified commit 6e584070, authored by TechxGenus and committed by GitHub

[`BC`] Fix BC for AWQ quant (#29965)

Fix backward compatibility for AWQ quantization: fused AWQ decoder layers do not expose a `self_attn` attribute, so guard the static-cache check with `getattr` instead of a direct attribute access.
parent 46d63681
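
Why the one-line change works: when AWQ attention fusing is applied, the decoder layers no longer carry a `self_attn` submodule, so `self.layers[0].self_attn` raises an AttributeError before the static/dynamic cache check can run. With `getattr(self.layers[0], "self_attn", {})` the lookup falls back to an empty dict, `hasattr({}, "past_key_value")` is simply False, and fused models take the dynamic-cache branch. A minimal, self-contained sketch of the pattern; the layer classes below are illustrative stand-ins, not classes from the repo:

# Minimal sketch of the BC issue this commit fixes. The layer classes are
# illustrative stand-ins, not actual transformers/autoawq classes.

class StandardLayer:
    """Unquantized decoder layer: exposes a `self_attn` submodule."""
    def __init__(self):
        self.self_attn = type("Attn", (), {"past_key_value": None})()

class FusedAWQLayer:
    """AWQ-fused decoder layer: attention is fused, no `self_attn` attribute."""
    def __init__(self):
        self.attn = object()  # fused attention lives under a different name

def uses_static_cache_old(layer):
    # Old check: raises AttributeError on fused layers.
    return hasattr(layer.self_attn, "past_key_value")

def uses_static_cache_new(layer):
    # New check: getattr with a default falls back to an empty dict,
    # so hasattr simply returns False for fused layers.
    return hasattr(getattr(layer, "self_attn", {}), "past_key_value")

print(uses_static_cache_new(StandardLayer()))   # True  -> static cache path
print(uses_static_cache_new(FusedAWQLayer()))   # False -> dynamic cache path
# uses_static_cache_old(FusedAWQLayer())        # would raise AttributeError
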
@@ -963,7 +963,7 @@ class CohereModel(CoherePreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
@@ -971,7 +971,7 @@ class GemmaModel(GemmaPreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
@@ -1064,7 +1064,7 @@ class LlamaModel(LlamaPreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
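
For reference, the code path patched above is typically reached when generating with a fused AWQ checkpoint. A hedged usage sketch, assuming a recent transformers/autoawq install; the model id is a placeholder and the fuse settings are only examples:

# Hypothetical usage that exercises the patched check: fusing AWQ modules
# removes `self_attn` from the decoder layers, which previously crashed
# the causal-mask update. Model id and fuse settings are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

quant_config = AwqConfig(do_fuse=True, fuse_max_seq_len=512)
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-AWQ",  # any AWQ checkpoint; placeholder id
    quantization_config=quant_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-AWQ")
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=8)  # previously raised AttributeError
print(tokenizer.decode(out[0]))
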