Unverified Commit 0d0a5820 authored by Wang Binluo, committed by GitHub

[shardformer] update transformers (#5583)

* flash_attention forward upgrade

* llama_model_forward

* remove useless comment

* update the requirements.txt

* add the transformers version requirements

* remove the LATEST VERSION try

* [shardformer] update bloom model (#5518)

* update bloom model

* remove the version restriction

* [shardformer] update_falcon (#5520)

* [shardformer] update mistral model (#5511)

* [shardformer] update gpt2 (#5502)

* [shardformer] update gptj model (#5503)

* [shardformer] update opt (#5522)

* [shardformer] update t5 model (#5524)

* [shardformer] update whisper model (#5529)

* [shardformer] update vit model (#5530)

* update vit model

* remove the output_hidden_states

* [shardformer] fix llama modeling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [zero] support multiple (partial) backward passes (#5596)

* [zero] support multiple (partial) backward passes

* [misc] update requirements

* fix conflicts

* [doc] fix ColossalMoE readme (#5599)

* fix readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* merge with main

* llama_model_forward

* remove useless comment

* remove the LATEST VERSION try

* [shardformer] update bloom model (#5518)

* update bloom model

* remove the version restriction

* [shardformer] update mistral model (#5511)

* [shardformer] update opt (#5522)

* [shardformer] update whisper model (#5529)

* [shardformer] fix llama modeling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [hotfix] Fix examples no pad token & auto parallel codegen bug (#5606)

* fix no pad token bug

* fixed some auto parallel codegen bugs, but they might not run on torch 2.1

---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] fix pipeline grad ckpt (#5620)

* [shardformer] fix pipeline grad ckpt

* [shardformer] fix whisper (#5628)

* [test] fix llama model test

* fix the opt upgrade (#5634)

* [shardformer] fix attn replacement (#5636)

* [shardformer] update flashattention replacement (#5637)

* update transformers

update transformers

fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [test] fix llama test (#5638)

* [gemini] fix buffer cast (#5639)

* Fix shardformer upgrade (#5640)

* fix llama model

* fix the mistral

* fix the shardformer model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [shardformer]support pipeline parallelism for mistral. (#5642)

* [shardformer] fix attn replacement (#5636)

* [shardformer] update flashattention replacement (#5637)

* update transformers

update transformers

fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Support LLaMA-3 CPT and ST (#5619)

* support LLaMA-3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Run pre-commit

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [example] update llama example (#5626)

* [plugin] support dp inside for hybrid parallel

* [example] update llama benchmark

* [example] update llama benchmark

* [example] update llama readme

* [example] update llama readme

* [example] llama3 (#5631)

* release llama3

* [release] llama3

* [release] llama3

* [release] llama3

* [release] llama3

* [test] fix llama test (#5638)

* [gemini] fix buffer cast (#5639)

* support pp for mistral

* fix

* fix

fix

fix

* fix

---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
parent f4c5aafe
@@ -6,6 +6,7 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -205,12 +206,13 @@ class BloomPipelineForwards:
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()

         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config and shard_config.enable_sequence_parallelism:
@@ -227,21 +229,15 @@ class BloomPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -1002,11 +998,13 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()

         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1018,21 +1016,15 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
......
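For reference, a minimal sketch of what the new shared helper produces (illustration only, not part of the diff; assumes transformers>=4.36, where transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask lives):

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_len, hidden_size = 2, 5, 0, 8
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)  # 2D padding mask

# Returns an additive float mask of shape [batch, 1, seq, seq + past]:
# 0.0 where attention is allowed, a large negative value where it is masked.
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,  # only dtype/device are taken from this tensor
    past_key_values_length=past_len,
)
# BLOOM blocks consume a boolean mask (True = masked), hence the extra .bool()
# cast the diff adds after the call.
causal_mask = causal_mask.bool()
print(causal_mask.shape)  # torch.Size([2, 1, 5, 5])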
+import math
+import warnings
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -99,11 +106,17 @@ def get_tp_falcon_decoder_layer_forward():
         hidden_states: torch.Tensor,
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        **kwargs,
     ):
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         residual = hidden_states

         if self.config.new_decoder_architecture:
@@ -117,10 +130,12 @@ def get_tp_falcon_decoder_layer_forward():
             attention_layernorm_out,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            **kwargs,
         )

         attention_output = attn_outputs[0]
@@ -154,87 +169,6 @@ def get_tp_falcon_decoder_layer_forward():
     return forward

-def get_falcon_flash_attention_forward():
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-    from transformers.models.falcon.modeling_falcon import FalconAttention
-
-    def forward(
-        self: FalconAttention,
-        hidden_states: torch.Tensor,
-        alibi: Optional[torch.Tensor],
-        attention_mask: torch.Tensor,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        use_cache: bool = False,
-        output_attentions: bool = False,
-    ):
-        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
-        num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
-        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
-        batch_size, query_length, _, _ = query_layer.shape
-
-        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
-        key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * num_kv_heads,
-            query_length,
-            self.head_dim,
-        )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
-
-        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=1)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
-
-        _, kv_length, _ = key_layer.shape
-        if use_cache:
-            present = (key_layer, value_layer)
-        else:
-            present = None
-
-        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
-
-        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
-        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
-        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
-
-        if alibi is not None:
-            attention_mask_float = (
-                attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
-            )
-
-        batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
-        tgt_len = key_layer_.size()[1]
-        attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
-        context_layer = me_attention(
-            query_layer_,
-            key_layer_,
-            value_layer_,
-            attn_bias=attention_mask_float,
-            scale=self.inv_norm_factor,
-            p=self.attention_dropout.p,
-        )
-        batch_size, seq_length, _, _ = context_layer.shape
-        context_layer = context_layer.reshape(batch_size, seq_length, -1)
-
-        output_tensor = self.dense(context_layer)
-        return output_tensor, present
-
-    return forward

 class FalconPipelineForwards:
     """
     This class serves as a micro library for falcon pipeline forwards.
@@ -246,6 +180,7 @@ class FalconPipelineForwards:
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -274,17 +209,6 @@ class FalconPipelineForwards:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
-        else:
-            past_key_values = self._convert_to_rw_cache(past_key_values)
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape batch_size x num_heads x N x N
-        # head_mask has shape n_layer x batch x num_heads x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
         # case: First stage of training
         if stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
@@ -295,16 +219,22 @@ class FalconPipelineForwards:
                 batch_size, seq_length, _ = inputs_embeds.shape
             else:
                 raise ValueError("You have to specify either input_ids or inputs_embeds")

             if inputs_embeds is None:
                 inputs_embeds = self.word_embeddings(input_ids)
             hidden_states = inputs_embeds
         else:
             input_shape = hidden_states.shape[:-1]
             batch_size, seq_length = input_shape

+        if past_key_values is None:
+            past_key_values = tuple([None] * len(self.h))
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
@@ -312,23 +242,81 @@ class FalconPipelineForwards:
         # Compute alibi tensor: check build_alibi_tensor documentation
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[1]  # 1 because RW-cache, not standard format
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
-        else:
-            attention_mask = attention_mask.to(hidden_states.device)
+            past_key_values_length = past_key_values[0][0].shape[-2]

         if self.use_alibi:
-            alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+            mask = (
+                torch.ones(
+                    (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
+                )
+                if attention_mask is None
+                else attention_mask
+            )
+            alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
         else:
             alibi = None

-        causal_mask = self._prepare_attn_mask(
-            attention_mask,
-            input_shape=(batch_size, seq_length),
-            past_key_values_length=past_key_values_length,
-        )
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            if alibi is None:
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
+            elif head_mask is None:
+                alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
+
+                attention_mask_2d = attention_mask
+                # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                )
+
+                # We take care to integrate alibi bias in the attention_mask here.
+                if attention_mask_2d is None:
+                    attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
+                else:
+                    attention_mask = torch.masked_fill(
+                        alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+                        attention_mask < -1,
+                        torch.finfo(alibi.dtype).min,
+                    )
+
+                    # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+                    # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+                    if seq_length > 1:
+                        attention_mask = AttentionMaskConverter._unmask_unattended(
+                            attention_mask, attention_mask_2d, unmasked_value=0.0
+                        )
+            else:
+                # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape batch_size x num_heads x N x N
+        # head_mask has shape n_layer x batch x num_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
         start_idx, end_idx = stage_index[0], stage_index[1]
         for i, (block, layer_past) in enumerate(
             zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
@@ -337,31 +325,23 @@ class FalconPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
-                    causal_mask,
+                    attention_mask,
+                    position_ids,
                     head_mask[i],
+                    layer_past,
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
-                    attention_mask=causal_mask,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -382,9 +362,6 @@ class FalconPipelineForwards:
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-        if presents is not None:
-            presents = self._convert_cache_to_standard_format(presents, batch_size)
-
         if stage_manager.is_last_stage():
             if not return_dict:
                 return tuple(
......
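The rewritten Falcon forward mirrors upstream transformers: it picks one of three mask layouts depending on the attention backend instead of always building its own causal mask (which is also why the bespoke xformers attention above is dropped). A hedged sketch of the three paths, assuming the transformers>=4.36 mask utilities; build_mask and its arguments are illustrative names, not shardformer code:

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

def build_mask(attn_impl, attention_mask, batch_size, seq_length, inputs_embeds, past_len):
    if attn_impl == "flash_attention_2":
        # FlashAttention takes the raw 2D padding mask, or None when nothing is padded.
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    if attn_impl == "sdpa":
        # SDPA may get None back for a purely causal mask and then rely on is_causal=True.
        return _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_len
        )
    # Eager attention always needs the explicit 4D additive mask.
    return _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_len
    )

x = torch.zeros(2, 5, 8)
full = torch.ones(2, 5, dtype=torch.long)
print(build_mask("sdpa", full, 2, 5, x, 0))         # typically None: fully causal, no padding
print(build_mask("eager", full, 2, 5, x, 0).shape)  # torch.Size([2, 1, 5, 5])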
@@ -177,11 +177,9 @@ class GPT2PipelineForwards:
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         if stage_manager.is_first_stage():
-            if position_ids is not None:
-                position_ids = position_ids.view(-1, input_shape[-1])
-            else:
+            if position_ids is None:
                 position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
-                position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+                position_ids = position_ids.unsqueeze(0)

             if inputs_embeds is None:
                 inputs_embeds = self.wte(input_ids)
@@ -239,22 +237,16 @@ class GPT2PipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     None,
                     attention_mask,
                     head_mask[i],
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
......
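Every modeling file in this commit replaces the closure-based torch.utils.checkpoint.checkpoint call with self._gradient_checkpointing_func, which transformers>=4.35 installs on the model when gradient_checkpointing_enable() is called; flags such as use_cache and output_attentions are now passed positionally instead of being captured by a closure. A rough stand-in for what that attribute does (an assumption: the exact wrapper and its use_reentrant default vary by transformers version):

from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Assumption: HF wraps torch.utils.checkpoint.checkpoint roughly like this.
gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

block = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
# Checkpointing block.__call__ (not a bare closure) keeps module hooks intact.
y = gradient_checkpointing_func(block.__call__, x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])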
@@ -148,11 +148,9 @@ class GPTJPipelineForwards:
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         # position id to be assigned not just for the first stage for attn input
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, seq_length)
-        else:
+        if position_ids is None:
             position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
         if stage_manager.is_first_stage():
             if inputs_embeds is None:
                 inputs_embeds = self.wte(input_ids)
@@ -201,21 +199,15 @@ class GPTJPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     None,
                     attention_mask,
                     position_ids,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -627,7 +619,9 @@ def get_gptj_flash_attention_forward():
             value = torch.cat((past_value, value), dim=-2)

         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
+            # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None
......
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -16,6 +17,8 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
     apply_rotary_pos_emb,
     repeat_kv,
 )
@@ -31,13 +34,6 @@ from colossalai.shardformer.shard import ShardConfig
 from ..layer import ColoAttention, cross_entropy_1d

-try:
-    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
-
-    LATEST_VERSION = True
-except ImportError:
-    LATEST_VERSION = False

 class LlamaPipelineForwards:
     """
@@ -75,13 +71,13 @@ class LlamaPipelineForwards:
         # retrieve input_ids and inputs_embeds
         if stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
-                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
-                batch_size, seq_length = input_ids.shape
+                batch_size, seq_length = input_ids.shape[:2]
             elif inputs_embeds is not None:
-                batch_size, seq_length, _ = inputs_embeds.shape
+                batch_size, seq_length, _ = inputs_embeds.shape[:2]
             else:
-                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)
@@ -111,11 +107,12 @@ class LlamaPipelineForwards:
         if position_ids is None:
             position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)

         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
@@ -123,20 +120,32 @@ class LlamaPipelineForwards:
             # in this case, attention_mask is a dict rather than a tensor
             mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
             attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
+                mask_shape,
+                hidden_states.dtype,
+                hidden_states.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
             )
         else:
-            if attention_mask is None:
-                attention_mask = torch.ones(
-                    (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
-                )
-            if LATEST_VERSION:
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            if self._use_flash_attention_2:
+                # 2d mask is passed through the layers
+                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._use_sdpa and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
                 )
             else:
-                attention_mask = self._prepare_decoder_attention_mask(
-                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+                # 4d mask is passed through the layers
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    hidden_states,
+                    past_key_values_length,
                 )

         if self.gradient_checkpointing and self.training:
@@ -149,7 +158,7 @@ class LlamaPipelineForwards:
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
+        next_decoder_cache = None

         start_idx, end_idx = stage_index[0], stage_index[1]
         num_ckpt_layers = 0
@@ -160,7 +169,7 @@ class LlamaPipelineForwards:
             num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
                 stage=stage_manager.stage,
                 num_layers=end_idx - start_idx,
-                model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
+                model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
             )
             assert num_ckpt_layers <= end_idx - start_idx
@@ -168,30 +177,22 @@ class LlamaPipelineForwards:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            past_key_value = past_key_values[idx] if past_key_values is not None else None
             if idx - start_idx < num_ckpt_layers:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
-                    None,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_value,
+                    past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                 )
@@ -199,7 +200,7 @@ class LlamaPipelineForwards:
             hidden_states = layer_outputs[0]

             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
@@ -212,7 +213,16 @@ class LlamaPipelineForwards:
         next_cache = next_decoder_cache if use_cache else None
         if stage_manager.is_last_stage():
             if not return_dict:
-                return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+                return tuple(
+                    v
+                    for v in [
+                        hidden_states,
+                        next_cache,
+                        all_hidden_states,
+                        all_self_attns,
+                    ]
+                    if v is not None
+                )
             return BaseModelOutputWithPast(
                 last_hidden_state=hidden_states,
                 past_key_values=next_cache,
@@ -458,23 +468,25 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    llama_version = 2
     try:
         from transformers.models.llama.modeling_llama import repeat_kv
     except:
         warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
-        llama_version = 1

     def forward(
         self: LlamaAttention,
         hidden_states: torch.Tensor,
         attention_mask: Optional[dict] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         bsz, q_len, _ = hidden_states.size()

         if sp_mode in ["split_gather", "ring"]:
@@ -498,19 +510,21 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        if llama_version == 2:
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -573,7 +587,10 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
             )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
@@ -587,7 +604,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
             # in this case, attention_mask is a dict rather than a tensor
             mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
             attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
+                mask_shape,
+                hidden_states.dtype,
+                hidden_states.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
             )

         if self.gradient_checkpointing and self.training:
@@ -918,7 +939,10 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
            )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
@@ -934,10 +958,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
         if attention_mask is None:
             attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+                (batch_size, seq_length_with_past),
+                dtype=torch.bool,
+                device=inputs_embeds.device,
             )
-        attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
         )
......
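The llama attention forward above also migrates from tuple-shaped KV caches to the Cache abstraction introduced in transformers 4.36: get_usable_length() replaces the manual shape arithmetic and update() replaces the torch.cat pair. A small sketch with DynamicCache (shapes are illustrative):

import torch
from transformers.cache_utils import DynamicCache

layer_idx = 0
cache = DynamicCache()
key_states = torch.randn(1, 4, 3, 16)    # [bsz, num_heads, q_len, head_dim]
value_states = torch.randn(1, 4, 3, 16)

# Old: kv_seq_len += past_key_value[0].shape[-2]
kv_seq_len = key_states.shape[-2] + cache.get_usable_length(key_states.shape[-2], layer_idx)
# Old: key_states = torch.cat([past_key_value[0], key_states], dim=2) (same for values)
key_states, value_states = cache.update(key_states, value_states, layer_idx)
print(kv_seq_len, key_states.shape)  # 3 torch.Size([1, 4, 3, 16]) on the first step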
This diff is collapsed.
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union

 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,7 +43,7 @@ def _get_attention_mask(
             is_causal=True,
         )
     else:
-        attention_mask = self.decoder._prepare_decoder_attention_mask(
+        attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             (batch_size, seq_length),
             hidden_states,
@@ -57,6 +58,20 @@ class OPTPipelineForwards:
     under pipeline setting.
     """

+    @staticmethod
+    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+        """
+        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+        """
+        bsz, src_len = mask.size()
+        tgt_len = tgt_len if tgt_len is not None else src_len
+
+        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+        inverted_mask = 1.0 - expanded_mask
+
+        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
     @staticmethod
     def opt_model_forward(
         self: OPTModel,
@@ -112,7 +127,7 @@ class OPTPipelineForwards:
                 inputs_embeds = decoder.project_in(inputs_embeds)
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             inputs_embeds.dtype
+            hidden_states = inputs_embeds
         else:
             if hidden_states is None:
                 raise ValueError("hidden_states shouldn't be None for intermediate stages.")
@@ -125,13 +140,26 @@ class OPTPipelineForwards:
         # required mask seq length can be calculated via length of past
         mask_seq_length = past_key_values_length + seq_length
         # embed positions
+        if self.decoder._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            attention_mask = (
+                torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+                if attention_mask is None
+                else attention_mask
+            )
+        else:
+            # 4d mask is passed through the layers
             if attention_mask is None:
-                attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
+                attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
             elif attention_mask.shape[1] != mask_seq_length:
                 raise ValueError(
                     f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                     f"{mask_seq_length} (sum of the lengths of current and past inputs)"
                 )
+            causal_attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, hidden_states, past_key_values_length
+            )

         if stage_manager.is_first_stage():
             causal_attention_mask = _get_attention_mask(
@@ -205,20 +233,14 @@ class OPTPipelineForwards:
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if decoder.gradient_checkpointing and decoder.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     causal_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
......
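A worked example of the _expand_mask convention added above: a 2D padding mask (1 = real token, 0 = padding) becomes an additive 4D bias that is 0.0 on kept positions and dtype-min on padded ones.

import torch

mask = torch.tensor([[1, 1, 0]])  # one sequence, last position is padding
dtype = torch.float32
expanded = mask[:, None, None, :].expand(1, 1, 3, 3).to(dtype)
inverted = 1.0 - expanded
bias = inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)
print(bias[0, 0, 0])  # tensor([0., 0., -3.4028e+38])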
@@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple, Union

 import torch
 from torch.nn import CrossEntropyLoss
-from torch.utils.checkpoint import checkpoint
 from transformers.modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -118,16 +117,13 @@ class T5PipelineForwards:
         # required mask seq length can be calculated via length of past
         mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

-        if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
-        if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
-            encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
-
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)

+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
+
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
@@ -138,7 +134,7 @@ class T5PipelineForwards:
             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
@@ -162,15 +158,8 @@ class T5PipelineForwards:
                 torch.cuda.set_device(hidden_states.device)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-
-                    return custom_forward
-
-                layer_outputs = checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.forward,
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
@@ -180,6 +169,8 @@ class T5PipelineForwards:
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
......
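The T5 hunks lean on the generic ModuleUtilsMixin helpers (get_extended_attention_mask, invert_attention_mask) rather than building masks by hand. A hedged sketch of what those helpers return, using a deliberately tiny, illustrative T5 config (not the configuration used by shardformer):

import torch
from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5Stack

config = T5Config(d_model=8, d_kv=4, d_ff=16, num_layers=1, num_heads=2, vocab_size=10)
stack = T5Stack(config)  # encoder stack (is_decoder=False by default)

attention_mask = torch.tensor([[1, 1, 0]])
# Both return additive biases with dtype-min at masked positions,
# broadcastable over heads: shape [bsz, 1, 1, seq] for an encoder.
ext = stack.get_extended_attention_mask(attention_mask, attention_mask.shape)
inv = stack.invert_attention_mask(attention_mask)
print(ext.shape, inv.shape)  # torch.Size([1, 1, 1, 3]) torch.Size([1, 1, 1, 3])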
@@ -14,6 +14,8 @@ def _encoder_forward(
     end_idx: int,
     hidden_states: torch.Tensor,
     head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
+    output_hidden_states: bool = False,
     return_dict: bool = True,
     stage_manager: PipelineStageManager = None,
 ) -> Union[tuple, BaseModelOutput]:
@@ -23,20 +25,14 @@ def _encoder_forward(
         layer_head_mask = head_mask[i] if head_mask is not None else None

         if encoder.gradient_checkpointing and encoder.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs, False)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(layer_module),
+            layer_outputs = encoder._gradient_checkpointing_func(
+                layer_module.__call__,
                 hidden_states,
                 layer_head_mask,
+                output_attentions,
             )
         else:
-            layer_outputs = layer_module(hidden_states, layer_head_mask, False)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

         hidden_states = layer_outputs[0]
     if not stage_manager.is_last_stage():
@@ -114,6 +110,8 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
             end_idx=stage_index[1],
             hidden_states=hidden_states,
             head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             stage_manager=stage_manager,
         )
......
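The ViT change threads output_attentions through the checkpointed path instead of hardcoding False, so attention maps are returned identically with and without gradient checkpointing. A toy demonstration of why the flag must ride through checkpoint (ToyLayer is an illustrative stand-in, not shardformer code):

import torch
import torch.nn as nn
import torch.utils.checkpoint as ckpt

class ToyLayer(nn.Module):
    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        attn = torch.softmax(hidden_states @ hidden_states.transpose(-1, -2), dim=-1)
        out = attn @ hidden_states
        return (out, attn) if output_attentions else (out,)

layer = ToyLayer()
x = torch.randn(1, 4, 8, requires_grad=True)
# Passing the flag positionally keeps checkpointed and plain calls in sync.
outs = ckpt.checkpoint(layer.__call__, x, None, True, use_reentrant=False)
print(len(outs))  # 2: hidden states plus attention weights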
@@ -5,6 +5,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -35,6 +39,8 @@ def _get_attention_mask(
     hidden_states: torch.Tensor,
     past_key_values_length: int,
     attention_mask: Optional[torch.FloatTensor],
+    head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
 ):
     batch_size, seq_length = hidden_states.shape[:2]
     mask_seq_length = past_key_values_length + seq_length
@@ -47,11 +53,19 @@ def _get_attention_mask(
             is_causal=True,
         )
     else:
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            hidden_states,
-            past_key_values_length,
-        )
+        input_shape = (batch_size, seq_length)
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and head_mask is None and not output_attentions:
+            # output_attentions=True & head_mask can not be supported when using SDPA.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask, input_shape, hidden_states, past_key_values_length
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, hidden_states, past_key_values_length
+            )
     return attention_mask
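Note: the branch above mirrors what upstream transformers >= 4.36 does per backend — flash attention takes a 2d padding mask (or `None` when nothing is masked), SDPA can take `None` in the fully-visible case, and eager attention needs the full 4d additive mask. A small sketch using the same `transformers.modeling_attn_mask_utils` helpers imported in the hunk (private helpers, so they may move in later releases; the toy shapes are assumptions):

```python
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch, seq, past = 2, 4, 0
hidden_states = torch.randn(batch, seq, 8)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])  # second row is padded

# Eager path: full 4d additive mask, shape [batch, 1, seq, past + seq].
mask_eager = _prepare_4d_causal_attention_mask(
    attention_mask, (batch, seq), hidden_states, past
)
print(mask_eager.shape)  # torch.Size([2, 1, 4, 4])

# SDPA path: with no padding it can return None, letting
# scaled_dot_product_attention handle causality in its fused kernel.
mask_sdpa = _prepare_4d_causal_attention_mask_for_sdpa(
    torch.ones(batch, seq, dtype=torch.long), (batch, seq), hidden_states, past
)
print(mask_sdpa)  # often None on transformers 4.36: mask is all-visible
```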
@@ -539,18 +553,12 @@ class WhisperPipelineForwards:
                     layer_outputs = (None, None)
                 else:
                     if self.gradient_checkpointing and self.training:
-
-                        def create_custom_forward(module):
-                            def custom_forward(*inputs):
-                                return module(*inputs, output_attentions)
-
-                            return custom_forward
-
-                        layer_outputs = torch.utils.checkpoint.checkpoint(
-                            create_custom_forward(encoder_layer),
+                        layer_outputs = self._gradient_checkpointing_func(
+                            encoder_layer.__call__,
                             hidden_states,
                             None,
                             (head_mask[idx] if head_mask is not None else None),
+                            output_attentions,
                         )
                     else:
                         layer_outputs = encoder_layer(
@@ -702,20 +710,16 @@ class WhisperPipelineForwards:
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)

+            attention_mask = _get_attention_mask(
+                self, shard_config, inputs_embeds, past_key_values_length, attention_mask
+            )
+
             # embed positions
             if input_ids is not None:
                 positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
             else:
                 positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)

-            attention_mask = _get_attention_mask(
-                self,
-                shard_config,
-                inputs_embeds,
-                past_key_values_length,
-                attention_mask,
-            )
-
             hidden_states = inputs_embeds + positions
             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -732,7 +736,6 @@ class WhisperPipelineForwards:
                     "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder."
                 )
-            input_shape = hidden_states.size()[:-1]
             attention_mask = _get_attention_mask(
                 self,
                 shard_config,
@@ -756,16 +759,8 @@ class WhisperPipelineForwards:
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -773,6 +768,8 @@ class WhisperPipelineForwards:
                     head_mask[idx] if head_mask is not None else None,
                     (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                     None,  # past_key_value
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
...
@@ -24,12 +24,6 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 class BloomPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Bloom model should run on a transformers version not greater than 4.33.0."

     def config_sanity_check(self):
         pass
...
@@ -7,12 +7,7 @@ from torch.nn import Module
 import colossalai.shardformer.layer as col_nn

-from ..modeling.falcon import (
-    FalconPipelineForwards,
-    build_falcon_alibi_tensor_fn,
-    get_falcon_flash_attention_forward,
-    get_tp_falcon_decoder_layer_forward,
-)
+from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["FalconPolicy"]
@@ -21,12 +16,6 @@ __all__ = ["FalconPolicy"]
 class FalconPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Falcon model should run on a transformers version not greater than 4.33.0."

     def config_sanity_check(self):
         pass
@@ -36,7 +25,7 @@ class FalconPolicy(Policy):
         return self.model

     def module_policy(self):
-        from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel
+        from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel

         if not self.model.config.new_decoder_architecture and self.model.config.multi_query:
             warnings.warn(
@@ -147,11 +136,8 @@ class FalconPolicy(Policy):
             )

         if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(
-                description={"forward": get_falcon_flash_attention_forward()},
-                policy=policy,
-                target_key=FalconAttention,
-            )
+            warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.")

         return policy

     def postprocess(self):
...
@@ -35,13 +35,20 @@ class GPT2Policy(Policy):
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model

     def module_policy(self):
         from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model

+        ATTN_IMPLEMENTATION = {
+            "eager": GPT2Attention,
+        }
+
         policy = {}

+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -186,7 +193,7 @@ class GPT2Policy(Policy):
                     "forward": get_gpt2_flash_attention_forward(),
                 },
                 policy=policy,
-                target_key=GPT2Attention,
+                target_key=attn_cls,
             )
             if not self.shard_config.pipeline_stage_manager:
                 policy[GPT2Model].method_replacement = {
...
@@ -30,13 +30,20 @@ class GPTJPolicy(Policy):
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model

     def module_policy(self):
         from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel

+        ATTN_IMPLEMENTATION = {
+            "eager": GPTJAttention,
+        }
+
         policy = {}

+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -160,7 +167,7 @@ class GPTJPolicy(Policy):
                     "forward": get_gptj_flash_attention_forward(),
                 },
                 policy=policy,
-                target_key=GPTJAttention,
+                target_key=attn_cls,
             )
             if not self.shard_config.pipeline_stage_manager:
                 self.append_or_create_method_replacement(
...
@@ -36,13 +36,26 @@ class LlamaPolicy(Policy):
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
+        from transformers.models.llama.modeling_llama import (
+            LlamaAttention,
+            LlamaDecoderLayer,
+            LlamaFlashAttention2,
+            LlamaModel,
+            LlamaSdpaAttention,
+        )
+
+        ATTN_IMPLEMENTATION = {
+            "eager": LlamaAttention,
+            "flash_attention_2": LlamaFlashAttention2,
+            "sdpa": LlamaSdpaAttention,
+        }

         policy = {}
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]

         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -93,7 +106,7 @@ class LlamaPolicy(Policy):
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
         elif sp_mode == "all_to_all":
             decoder_attribute_replacement = {
@@ -102,7 +115,7 @@ class LlamaPolicy(Policy):
             if getattr(self.model.config, "num_key_value_heads", False):
                 decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size

-            policy[LlamaAttention] = ModulePolicyDescription(
+            policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
             self.append_or_create_method_replacement(
@@ -110,7 +123,7 @@ class LlamaPolicy(Policy):
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
             self.append_or_create_method_replacement(
                 description={
@@ -221,7 +234,7 @@ class LlamaPolicy(Policy):
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
         if self.pipeline_stage_manager is None:
             # replace llama model forward method
...
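Note: the `ATTN_IMPLEMENTATION` tables added to the GPT-2, GPT-J, Llama, Mistral, and OPT policies all follow one pattern: transformers >= 4.36 instantiates a different attention class per `config._attn_implementation`, so each policy must target whichever class the loaded model actually contains instead of hard-coding the eager variant. A hedged sketch of the lookup — the class names are real transformers 4.36 symbols, but setting `_attn_implementation` by hand is only for illustration (normally transformers sets it while building the model):

```python
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
)

ATTN_IMPLEMENTATION = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}

config = LlamaConfig()
config._attn_implementation = "sdpa"  # illustrative; usually set by transformers

attn_cls = ATTN_IMPLEMENTATION[config._attn_implementation]
assert attn_cls is LlamaSdpaAttention  # the class the policy will patch
```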
 import warnings
-from typing import Dict, Union
+from functools import partial
+from typing import Callable, Dict, List, Union

 import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module

 from colossalai.shardformer.layer import (
     FusedRMSNorm,
@@ -13,7 +16,11 @@ from colossalai.shardformer.layer import (
     VocabParallelLMHead1D,
 )

-from ..modeling.mistral import get_mistral_flash_attention_forward
+from ..modeling.mistral import (
+    MistralForwards,
+    get_mistral_flash_attention_forward,
+    get_mistral_model_forward_for_flash_attn,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"]
@@ -25,13 +32,26 @@ class MistralPolicy(Policy):
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
+        from transformers.models.mistral.modeling_mistral import (
+            MistralAttention,
+            MistralDecoderLayer,
+            MistralFlashAttention2,
+            MistralModel,
+        )
+
+        ATTN_IMPLEMENTATION = {
+            "eager": MistralAttention,
+            "flash_attention_2": MistralFlashAttention2,
+        }

         policy = {}
+        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]

         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -127,10 +147,19 @@ class MistralPolicy(Policy):
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_mistral_flash_attention_forward(),
+                    "forward": get_mistral_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=MistralAttention,
+                target_key=attn_cls,
+            )
+            if self.pipeline_stage_manager is None:
+                # replace llama model forward method
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_mistral_model_forward_for_flash_attn(self.shard_config),
+                    },
+                    policy=policy,
+                    target_key=MistralModel,
+                )

         return policy
@@ -138,16 +167,92 @@ class MistralPolicy(Policy):
     def postprocess(self):
         return self.model
+    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+        """If under pipeline parallel setting, replacing the original forward method of huggingface
+        to customized forward method, and add this changing to policy."""
+        if self.pipeline_stage_manager is None:
+            return
+
+        stage_manager = self.pipeline_stage_manager
+        if self.model.__class__.__name__ == "MistralModel":
+            module = self.model
+        else:
+            module = self.model.model
+
+        if stage_manager.is_interleave:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
+            method_replacement = {
+                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+            }
+        else:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
+            method_replacement = {
+                "forward": partial(
+                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                )
+            }
+
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        assert self.pipeline_stage_manager is not None
+
+        if self.model.__class__.__name__ == "MistralModel":
+            module = self.model
+        else:
+            module = self.model.model
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+        if stage_manager.is_interleave:
+            assert stage_manager.num_model_chunks is not None
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                held_layers.append(module.embed_tokens)
+            for start_idx, end_idx in stage_indices:
+                held_layers.extend(module.layers[start_idx:end_idx])
+            if stage_manager.is_last_stage(ignore_chunk=True):
+                held_layers.append(module.norm)
+        else:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            if stage_manager.is_first_stage():
+                held_layers.append(module.embed_tokens)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+            held_layers.extend(module.layers[start_idx:end_idx])
+            if stage_manager.is_last_stage():
+                held_layers.append(module.norm)
+
+        return held_layers
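Note: `get_held_layers` above leans on two stage-manager helpers to decide which decoder layers a stage owns. A self-contained sketch of what they compute — these are assumed, simplified re-implementations, and the real `PipelineStageManager` versions may differ in tie-breaking:

```python
from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    """Spread layers over stages as evenly as possible."""
    base, remainder = divmod(num_layers, num_stages)
    # earlier stages absorb one extra layer each until the remainder runs out
    return [base + (1 if stage < remainder else 0) for stage in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    """Half-open [start, end) range of layers owned by `stage`."""
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

layers_per_stage = distribute_layers(32, 4)        # [8, 8, 8, 8]
start, end = get_stage_index(layers_per_stage, 1)  # (8, 16)
# module.layers[start:end] is what stage 1 holds; stage 0 additionally holds
# embed_tokens, and the last stage holds the final norm (plus lm_head/score).
print(layers_per_stage, (start, end))
```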
 class MistralModelPolicy(MistralPolicy):
     def __init__(self) -> None:
         super().__init__()

     def module_policy(self):
+        policy = super().module_policy()
+        from transformers.models.mistral.modeling_mistral import MistralModel
+
         if self.pipeline_stage_manager:
-            warnings.warn("Mistral doesn't support pipeline parallelism now.")
+            self.set_pipeline_forward(
+                model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy
+            )

-        return super().module_policy()
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        held_layers = super().get_held_layers()
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mistral model"""
+        return []
 class MistralForCausalLMPolicy(MistralPolicy):
@@ -155,8 +260,6 @@ class MistralForCausalLMPolicy(MistralPolicy):
         from transformers import MistralForCausalLM

         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            warnings.warn("Mistral doesn't support pipeline parallelism now.")

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
@@ -189,8 +292,38 @@ class MistralForCausalLMPolicy(MistralPolicy):

             policy.update(new_item)

+        if self.pipeline_stage_manager:
+            # set None as default
+            self.set_pipeline_forward(
+                model_cls=MistralForCausalLM, new_forward=MistralForwards.mistral_for_causal_lm_forward, policy=policy
+            )
+
         return policy

+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        mistral_model = self.model.model
+        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+            if (
+                id(mistral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                and self.pipeline_stage_manager.num_stages > 1
+            ):
+                # tie weights
+                return [
+                    {
+                        0: mistral_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
+        return []
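Note: the identity check in `get_shared_params` above is the standard way to detect tied input/output embeddings, which must be gradient-synchronized between the first and last pipeline stages. A toy illustration — the stage count here is a made-up example, not anything ColossalAI fixes:

```python
import torch.nn as nn

embed_tokens = nn.Embedding(1000, 64)
lm_head = nn.Linear(64, 1000, bias=False)
lm_head.weight = embed_tokens.weight  # tie weights, as many causal LMs do

# Same tensor object => the parameter lives on both ends of the pipeline.
assert id(embed_tokens.weight) == id(lm_head.weight)

num_stages = 4  # hypothetical pipeline depth
shared_params = [{0: embed_tokens.weight, num_stages - 1: lm_head.weight}]
```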
 class MistralForSequenceClassificationPolicy(MistralPolicy):
     def module_policy(self):
@@ -209,9 +342,26 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
                     ]
                 )
             }
+            policy.update(new_item)

         if self.pipeline_stage_manager:
-            warnings.warn("Mistral doesn't support pipeline parallelism now.")
+            # set None as default
+            self.set_pipeline_forward(
+                model_cls=MistralForSequenceClassification,
+                new_forward=MistralForwards.mistral_for_sequence_classification_forward,
+                policy=policy,
+            )

-        policy.update(new_item)
         return policy

+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in llama for sequence classification model"""
+        return []
@@ -38,26 +38,27 @@ __all__ = [
 class OPTPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        # TODO: remove this version check when transformers>=4.36.0
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The OPT model should run on a transformers version not greater than 4.33.0."

     def config_sanity_check(self):
         pass

     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model

     def module_policy(self):
-        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
+        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer, OptFlashAttention2
+
+        ATTN_IMPLEMENTATION = {
+            "eager": OPTAttention,
+            "flash_attention_2": OptFlashAttention2,
+        }

         policy = {}
+        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]

         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -88,7 +89,7 @@ class OPTPolicy(Policy):
                 ]
             )

-            policy[OPTAttention] = ModulePolicyDescription(
+            policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement={
                     "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                     "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -158,7 +159,7 @@ class OPTPolicy(Policy):
                     "forward": get_opt_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=OPTAttention,
+                target_key=attn_cls,
             )
             if not self.shard_config.pipeline_stage_manager:
                 self.append_or_create_method_replacement(
...
+import warnings
+
 import colossalai.shardformer.layer as col_nn

-from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
+from ..modeling.sam import forward_fn
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["SamPolicy", "SamModelPolicy"]
@@ -15,7 +17,6 @@ class SamPolicy(Policy):
     def module_policy(self):
         from transformers.models.sam.modeling_sam import (
-            SamAttention,
             SamTwoWayAttentionBlock,
             SamTwoWayTransformer,
             SamVisionAttention,
@@ -210,20 +211,21 @@ class SamPolicy(Policy):
         # use flash attention
         if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_sam_flash_attention_forward(),
-                },
-                policy=policy,
-                target_key=SamAttention,
-            )
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_sam_vision_flash_attention_forward(),
-                },
-                policy=policy,
-                target_key=SamVisionAttention,
-            )
+            warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.")
+            # self.append_or_create_method_replacement(
+            #     description={
+            #         "forward": get_sam_flash_attention_forward(),
+            #     },
+            #     policy=policy,
+            #     target_key=SamAttention,
+            # )
+            # self.append_or_create_method_replacement(
+            #     description={
+            #         "forward": get_sam_vision_flash_attention_forward(),
+            #     },
+            #     policy=policy,
+            #     target_key=SamVisionAttention,
+            # )

         return policy
...
@@ -29,13 +29,6 @@ __all__ = [
 class WhisperPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        # TODO: remove this version check when transformers>=4.36.0
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Whisper model should run on a transformers version not greater than 4.33.0."

     def config_sanity_check(self):
         pass
@@ -55,6 +48,8 @@ class WhisperPolicy(Policy):
             WhisperDecoderLayer,
             WhisperEncoder,
             WhisperEncoderLayer,
+            WhisperFlashAttention2,
+            WhisperSdpaAttention,
         )

         policy = {}
@@ -249,6 +244,20 @@ class WhisperPolicy(Policy):
                 policy=policy,
                 target_key=WhisperAttention,
             )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_whisper_flash_attention_forward(),
+                },
+                policy=policy,
+                target_key=WhisperFlashAttention2,
+            )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_whisper_flash_attention_forward(),
+                },
+                policy=policy,
+                target_key=WhisperSdpaAttention,
+            )
         if not self.shard_config.pipeline_stage_manager:
             self.append_or_create_method_replacement(
                 description={
...
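Note: `append_or_create_method_replacement` ultimately rebinds `forward` on modules matching the target class. The sketch below shows only the bare mechanism, with a toy module standing in for `WhisperSdpaAttention`; ColossalAI's real implementation routes this through its policy/sharder machinery rather than a manual loop:

```python
import types
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def forward(self, x):
        return x

def replacement_forward(self, x):
    # stand-in for the function returned by get_whisper_flash_attention_forward()
    return x + 1

model = nn.Sequential(ToyAttention(), nn.Identity())
for module in model.modules():
    if isinstance(module, ToyAttention):  # the "target_key" match
        # bind the new forward to this instance, shadowing the class method
        module.forward = types.MethodType(replacement_forward, module)

print(model(torch.zeros(2)))  # tensor([1., 1.]): the patched forward ran
```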
@@ -840,6 +840,7 @@ class GeminiDDP(ModelWrapper):
         for buffer in self.module.buffers():
             if isinstance(buffer, LazyTensor):
                 buffer.materialize()
+        for buffer in self.module.buffers():
             buffer.data = buffer.to(get_accelerator().get_current_device())
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.to(self.mixed_precision)
...
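Note: this one-line fix splits lazy-buffer materialization from the device/dtype cast into two passes, so the cast never touches a buffer that has not been materialized yet. A toy sketch of the pattern — `LazyBuffer` is an assumed stand-in for ColossalAI's `LazyTensor`, not its real API:

```python
import torch

class LazyBuffer:
    """Toy lazy tensor: holds only a shape until materialize() allocates data."""
    def __init__(self, shape):
        self.shape = shape
        self.data = None

    def materialize(self):
        self.data = torch.zeros(self.shape)

buffers = [LazyBuffer((2, 2)), torch.ones(3)]

# Pass 1: materialize everything, mutate nothing else.
for buf in buffers:
    if isinstance(buf, LazyBuffer):
        buf.materialize()

# Pass 2: every buffer now has real data, so casting is safe for all of them.
for buf in buffers:
    buf.data = buf.data.to(torch.float16)

print([b.data.dtype for b in buffers])  # [torch.float16, torch.float16]
```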