Unverified Commit 97cd0cd5 authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

[shardformer] fix llama error when transformers upgraded. (#5055)

* fix-llama

* Update llama.py
parent 3e021547
import warnings
# NOTE(review): the narrower `from typing import List, Optional, Tuple` that
# previously appeared here was fully subsumed by the line below (it adds Union)
# and has been dropped as a duplicate.
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
......@@ -13,6 +13,11 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
# Feature-detect the transformers version: newer releases expose the
# standalone `_prepare_4d_causal_attention_mask` helper in modeling_llama,
# while older ones only have the `_prepare_decoder_attention_mask` method.
LATEST_VERSION = True
try:
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
except ImportError:
    # Older transformers: fall back to the legacy model-method code path.
    LATEST_VERSION = False
class LlamaPipelineForwards:
"""
......@@ -97,9 +102,14 @@ class LlamaPipelineForwards:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
if LATEST_VERSION:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
if self.gradient_checkpointing and self.training:
if use_cache:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment