Unverified Commit 4d806dba authored by jiqing-feng, committed by GitHub

Fix bug of _prepare_4d_attention_mask (#27847)

* use _prepare_4d_attention_mask

* fix comment
parent 75336c17
@@ -29,7 +29,11 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _prepare_4d_attention_mask,
+    _prepare_4d_causal_attention_mask,
+)
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
@@ -78,9 +82,9 @@ def _get_unpad_data(attention_mask):
 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
     warnings.warn(
-        "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask"
+        "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
     )
-    return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+    return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
 def _make_causal_mask(
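For context (not part of the commit): the diff switches the deprecated `_expand_mask` wrapper from a method call on `AttentionMaskConverter` to the module-level `_prepare_4d_attention_mask` helper it now imports. The sketch below, assuming transformers >= 4.36 and PyTorch are installed, shows what that helper does with a 2D padding mask.

```python
# Minimal sketch, not part of the commit: illustrates the module-level helper
# the diff imports. Assumes transformers >= 4.36 and torch are installed.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

# 2D padding mask (batch=1, src_len=4): 1 = attend, 0 = padded.
mask_2d = torch.tensor([[1, 1, 1, 0]])

# Expands to [batch, 1, tgt_len, src_len]; padded positions are filled with the
# dtype's minimum value and attended positions with 0, ready to be added to
# attention scores.
mask_4d = _prepare_4d_attention_mask(mask_2d, dtype=torch.float32, tgt_len=4)
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4])
```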