Unverified Commit 97cd0cd5 authored by flybird11111, committed by GitHub

[shardformer] fix llama error when transformers upgraded. (#5055)

* fix-llama

* Update llama.py
parent 3e021547
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
...
@@ -13,6 +13,11 @@ from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager

+try:
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+    LATEST_VERSION = True
+except ImportError:
+    LATEST_VERSION = False

 class LlamaPipelineForwards:
     """
@@ -97,9 +102,14 @@ class LlamaPipelineForwards:
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-        )
+        if LATEST_VERSION:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )

         if self.gradient_checkpointing and self.training:
             if use_cache:
...
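For context, here is a minimal, self-contained sketch of the feature-detection pattern this commit applies. It is not the full ColossalAI forward pass; build_causal_mask is a hypothetical helper name used only for illustration, and the import path and argument order follow the diff above. Newer transformers releases provide _prepare_4d_causal_attention_mask as a module-level function, while older releases only expose the private method self._prepare_decoder_attention_mask on LlamaModel.

# Sketch of the try/except version guard used in the diff above (assumption:
# run against an installed transformers; `build_causal_mask` is hypothetical).
try:
    # Newer transformers releases expose the mask builder as a module-level function.
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

    LATEST_VERSION = True
except ImportError:
    LATEST_VERSION = False


def build_causal_mask(model, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """Dispatch to whichever causal-mask builder the installed transformers provides."""
    if LATEST_VERSION:
        return _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
    # Older releases keep the builder as a private method on the model instance.
    return model._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )

Detecting the new helper with a try/except import, rather than comparing version strings, keeps the shardformer forward working across transformers upgrades without pinning a specific release.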