Unverified Commit ff841900 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`BC 4.37 -> 4.38`] for Llama family, memory and speed (#29753)

* attempt to fix

* the actual fix that works with compilation!

* this?

* temporary update

* nit?

* dispatcg to memory efficient?

* update both models that have static cache support

* fix copies fix compile

* make sure fix

* fix cohere and gemma

* fix beams?

* nit

* slipped through the cracks

* nit

* nits

* update

* fix-copies

* skip failing tests

* nits
parent 8dd4ce6f
...@@ -274,9 +274,7 @@ class CohereAttention(nn.Module): ...@@ -274,9 +274,7 @@ class CohereAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask attn_weights = attn_weights + causal_mask
# upcast attention to fp32 # upcast attention to fp32
...@@ -559,8 +557,9 @@ class CohereSdpaAttention(CohereAttention): ...@@ -559,8 +557,9 @@ class CohereSdpaAttention(CohereAttention):
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask causal_mask = attention_mask
if attention_mask is not None and cache_position is not None: # if attention_mask is not None and cache_position is not None:
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]] if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577. # Reference: https://github.com/pytorch/pytorch/issues/112577.
...@@ -692,7 +691,7 @@ class CoherePreTrainedModel(PreTrainedModel): ...@@ -692,7 +691,7 @@ class CoherePreTrainedModel(PreTrainedModel):
base_model_prefix = "model" base_model_prefix = "model"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["CohereDecoderLayer"] _no_split_modules = ["CohereDecoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"] _skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True _supports_cache_class = True
...@@ -715,12 +714,6 @@ class CoherePreTrainedModel(PreTrainedModel): ...@@ -715,12 +714,6 @@ class CoherePreTrainedModel(PreTrainedModel):
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
) )
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
causal_mask = torch.full(
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
for layer in self.model.layers: for layer in self.model.layers:
device = layer.input_layernorm.weight.device device = layer.input_layernorm.weight.device
if hasattr(self.config, "_pre_quantization_dtype"): if hasattr(self.config, "_pre_quantization_dtype"):
...@@ -899,7 +892,7 @@ class CohereModel(CoherePreTrainedModel): ...@@ -899,7 +892,7 @@ class CohereModel(CoherePreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -967,25 +960,27 @@ class CohereModel(CoherePreTrainedModel): ...@@ -967,25 +960,27 @@ class CohereModel(CoherePreTrainedModel):
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens): def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask
return None return None
batch_size, seq_length = input_tensor.shape[:2] dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
device = input_tensor.device
# support going beyond cached `max_position_embedding`
if seq_length > self.causal_mask.shape[-1]:
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# We use the current dtype to avoid any overflows
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype sequence_length = input_tensor.shape[1]
causal_mask = causal_mask.expand(batch_size, 1, -1, -1) if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None: if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2: if attention_mask.dim() == 2:
...@@ -995,8 +990,8 @@ class CohereModel(CoherePreTrainedModel): ...@@ -995,8 +990,8 @@ class CohereModel(CoherePreTrainedModel):
elif attention_mask.dim() == 4: elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only. # cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]: if attention_mask.shape[-2] < cache_position[0] + sequence_length:
offset = past_seen_tokens offset = cache_position[0]
else: else:
offset = 0 offset = 0
mask_shape = attention_mask.shape mask_shape = attention_mask.shape
......
...@@ -279,10 +279,7 @@ class GemmaAttention(nn.Module): ...@@ -279,10 +279,7 @@ class GemmaAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it if attention_mask is not None: # no matter the length, we just slice it
if cache_position is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
else:
causal_mask = attention_mask
attn_weights = attn_weights + causal_mask attn_weights = attn_weights + causal_mask
# upcast attention to fp32 # upcast attention to fp32
...@@ -563,8 +560,8 @@ class GemmaSdpaAttention(GemmaAttention): ...@@ -563,8 +560,8 @@ class GemmaSdpaAttention(GemmaAttention):
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask causal_mask = attention_mask
if attention_mask is not None and cache_position is not None: if attention_mask is not None:
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]] causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577. # Reference: https://github.com/pytorch/pytorch/issues/112577.
...@@ -836,12 +833,6 @@ class GemmaModel(GemmaPreTrainedModel): ...@@ -836,12 +833,6 @@ class GemmaModel(GemmaPreTrainedModel):
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
causal_mask = torch.full(
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
...@@ -901,7 +892,7 @@ class GemmaModel(GemmaPreTrainedModel): ...@@ -901,7 +892,7 @@ class GemmaModel(GemmaPreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -975,26 +966,27 @@ class GemmaModel(GemmaPreTrainedModel): ...@@ -975,26 +966,27 @@ class GemmaModel(GemmaPreTrainedModel):
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens): def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask
return None return None
batch_size, seq_length = input_tensor.shape[:2] dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
device = input_tensor.device
# support going beyond cached `max_position_embedding`
if seq_length > self.causal_mask.shape[-1]:
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# We use the current dtype to avoid any overflows
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
)
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
causal_mask = causal_mask.expand(batch_size, 1, -1, -1) if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None: if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2: if attention_mask.dim() == 2:
...@@ -1004,8 +996,8 @@ class GemmaModel(GemmaPreTrainedModel): ...@@ -1004,8 +996,8 @@ class GemmaModel(GemmaPreTrainedModel):
elif attention_mask.dim() == 4: elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only. # cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]: if attention_mask.shape[-2] < cache_position[0] + sequence_length:
offset = past_seen_tokens offset = cache_position[0]
else: else:
offset = 0 offset = 0
mask_shape = attention_mask.shape mask_shape = attention_mask.shape
......
...@@ -371,9 +371,7 @@ class LlamaAttention(nn.Module): ...@@ -371,9 +371,7 @@ class LlamaAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask attn_weights = attn_weights + causal_mask
# upcast attention to fp32 # upcast attention to fp32
...@@ -658,8 +656,9 @@ class LlamaSdpaAttention(LlamaAttention): ...@@ -658,8 +656,9 @@ class LlamaSdpaAttention(LlamaAttention):
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask causal_mask = attention_mask
if attention_mask is not None and cache_position is not None: # if attention_mask is not None and cache_position is not None:
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]] if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577. # Reference: https://github.com/pytorch/pytorch/issues/112577.
...@@ -792,7 +791,7 @@ class LlamaPreTrainedModel(PreTrainedModel): ...@@ -792,7 +791,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
base_model_prefix = "model" base_model_prefix = "model"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"] _no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"] _skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True _supports_cache_class = True
...@@ -815,12 +814,6 @@ class LlamaPreTrainedModel(PreTrainedModel): ...@@ -815,12 +814,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
) )
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
causal_mask = torch.full(
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
for layer in self.model.layers: for layer in self.model.layers:
device = layer.input_layernorm.weight.device device = layer.input_layernorm.weight.device
if hasattr(self.config, "_pre_quantization_dtype"): if hasattr(self.config, "_pre_quantization_dtype"):
...@@ -934,12 +927,6 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -934,12 +927,6 @@ class LlamaModel(LlamaPreTrainedModel):
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
causal_mask = torch.full(
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
...@@ -1000,7 +987,7 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1000,7 +987,7 @@ class LlamaModel(LlamaPreTrainedModel):
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -1068,25 +1055,27 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1068,25 +1055,27 @@ class LlamaModel(LlamaPreTrainedModel):
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens): def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask
return None return None
batch_size, seq_length = input_tensor.shape[:2] dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
device = input_tensor.device
# support going beyond cached `max_position_embedding`
if seq_length > self.causal_mask.shape[-1]:
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# We use the current dtype to avoid any overflows
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype sequence_length = input_tensor.shape[1]
causal_mask = causal_mask.expand(batch_size, 1, -1, -1) if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None: if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2: if attention_mask.dim() == 2:
...@@ -1096,8 +1085,8 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1096,8 +1085,8 @@ class LlamaModel(LlamaPreTrainedModel):
elif attention_mask.dim() == 4: elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only. # cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]: if attention_mask.shape[-2] < cache_position[0] + sequence_length:
offset = past_seen_tokens offset = cache_position[0]
else: else:
offset = 0 offset = 0
mask_shape = attention_mask.shape mask_shape = attention_mask.shape
......
...@@ -283,7 +283,9 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix ...@@ -283,7 +283,9 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
) )
test_headmasking = False test_headmasking = False
test_pruning = False test_pruning = False
fx_compatible = True fx_compatible = (
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
)
# Need to use `0.8` instead of `0.9` for `test_cpu_offload` # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer # This is because we are hitting edge cases with the causal_mask buffer
......
...@@ -300,7 +300,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -300,7 +300,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
) )
test_headmasking = False test_headmasking = False
test_pruning = False test_pruning = False
fx_compatible = True fx_compatible = (
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
)
# Need to use `0.8` instead of `0.9` for `test_cpu_offload` # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer # This is because we are hitting edge cases with the causal_mask buffer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment