# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import logging
from typing import List, Literal, Optional, Tuple
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.enums import AttnMaskType, ModelType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
class T5LMHead(MegatronModule):
"""Masked LM head for T5
Args:
config (TransformerConfig): transformer config
        parallel_output (bool): whether the output logits are kept split across tensor-parallel ranks or gathered.
vocab_size (int): vocabulary size
pre_process (bool): Include embedding layer
share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are
shared.
"""
def __init__(
self,
config: TransformerConfig,
parallel_output: bool,
vocab_size: int,
pre_process: bool = True,
share_embeddings_and_output_weights: bool = False,
):
super(T5LMHead, self).__init__(config=config)
self.parallel_output = parallel_output
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
vocab_size,
config=config,
init_method=config.init_method,
bias=share_embeddings_and_output_weights,
skip_bias_add=not share_embeddings_and_output_weights,
gather_output=not self.parallel_output,
skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
)
def forward(self, hidden_states: Tensor, word_embeddings_weight: Tensor) -> Tensor:
"""Forward pass.
Args:
hidden_states (Tensor): output hidden states from decoder
word_embeddings_weight (Tensor): word embedding weight
Returns:
Tensor: logits tensor
"""
logits, _ = self.output_layer(hidden_states, weight=word_embeddings_weight)
return logits
class T5Model(LanguageModule):
"""T5 Language model.
Args:
config (TransformerConfig): transformer config
transformer_encoder_layer_spec (ModuleSpec): transformer layer customization specs for encoder
transformer_decoder_layer_spec (ModuleSpec): transformer layer customization specs for decoder
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool): Include embedding layer (used with pipeline parallelism)
post_process (bool): Include an output layer (used with pipeline parallelism)
        fp16_lm_cross_entropy (bool, optional): If True, compute the cross-entropy loss in fp16. Defaults to False.
parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are
shared. Defaults to False.
position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
            Defaults to 'learned_absolute'.
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
seq_len_interpolation_factor (float): scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
transformer_encoder_layer_spec: ModuleSpec,
transformer_decoder_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[float] = None,
):
super(T5Model, self).__init__(config=config)
self.config: TransformerConfig = config
self.transformer_encoder_layer_spec: ModuleSpec = transformer_encoder_layer_spec
self.transformer_decoder_layer_spec: ModuleSpec = transformer_decoder_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = True
self.add_decoder = True
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
self.model_type = ModelType.encoder_and_decoder
# Embeddings.
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=self.position_embedding_type,
)
# Rotary Position Embeddings
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
)
# Transformer encoder
encoder_spec, decoder_spec = (
self.transformer_encoder_layer_spec,
self.transformer_decoder_layer_spec,
)
self.encoder = TransformerBlock(
config=self.config,
spec=encoder_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Transformer decoder
self.decoder = TransformerBlock(
config=self.config,
spec=decoder_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
self.lm_head = T5LMHead(
config,
parallel_output,
self.vocab_size,
self.pre_process,
self.share_embeddings_and_output_weights,
)
self.output_layer = self.lm_head.output_layer
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
def forward(
self,
encoder_input_ids: Tensor,
decoder_input_ids: Tensor,
encoder_attn_mask: Tensor,
decoder_attn_mask: Tensor,
encoder_decoder_attn_mask: Tensor,
lm_labels: Tensor = None,
inference_params: InferenceParams = None,
) -> Tensor:
"""Forward pass.
Args:
encoder_input_ids (Tensor): input ids for encoder
decoder_input_ids (Tensor): input ids for decoder
encoder_attn_mask (Tensor): self-attention mask for encoder
decoder_attn_mask (Tensor): self-attention mask for decoder
encoder_decoder_attn_mask (Tensor): cross-attention mask between encoder and decoder
            lm_labels (Tensor): labels for the decoder output
            inference_params (InferenceParams): parameters used during inference
        Returns:
            Tensor: loss tensor when lm_labels is provided; otherwise the decoder
                logits (transposed to [b, s, h]), or the decoder hidden states on
                non-final pipeline stages
"""
(
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
) = t5_extended_attention_mask(
[encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]
)
encoder_position_ids = t5_position_ids(encoder_input_ids)
decoder_position_ids = t5_position_ids(decoder_input_ids)
## Encoder forward
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(
input_ids=encoder_input_ids, position_ids=encoder_position_ids
)
else:
# intermediate stage of pipeline
encoder_input = None
# Rotary positional embeddings
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.encoder, encoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run encoder.
encoder_hidden_states = self.encoder(
hidden_states=encoder_input,
attention_mask=encoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
## Decoder forward
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(
input_ids=decoder_input_ids, position_ids=decoder_position_ids
)
else:
# intermediate stage of pipeline
            decoder_input = None  # TODO: should this take encoder_hidden_states instead?
# Rotary positional embeddings
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run decoder.
decoder_hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=decoder_attn_mask,
context=encoder_hidden_states,
context_mask=encoder_decoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
# Return if not post_process
if not self.post_process:
return decoder_hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight)
if lm_labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(lm_labels, logits)
return loss
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert (
len(input_tensor) == 1
), 'input_tensor should only be length 1 for stage with both encoder and decoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert (
len(input_tensor) == 1
), 'input_tensor should only be length 1 for stage with only encoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception('input_tensor must have either length 1 or 2')
else:
raise Exception('Stage must have at least either encoder or decoder')
def shared_embedding_or_output_weight(self) -> Tensor:
"""Function to share the input embeddings and output logit weights."""
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.lm_head.output_layer.weight
return None
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
assert not sharded_offsets, "Unexpected sharded offsets"
sharded_state_dict = {}
if self.pre_process:
embedding_prefix = f'{prefix}embedding.'
embedding_sharded_state_dict = self.embedding.sharded_state_dict(
prefix=embedding_prefix, metadata=metadata
)
sharded_state_dict.update(embedding_sharded_state_dict)
encoder_prefix = f'{prefix}encoder.'
encoder_sharded_state_dict = self.encoder.sharded_state_dict(
prefix=encoder_prefix, metadata=metadata
)
sharded_state_dict.update(encoder_sharded_state_dict)
decoder_prefix = f'{prefix}decoder.'
decoder_sharded_state_dict = self.decoder.sharded_state_dict(
prefix=decoder_prefix, metadata=metadata
)
sharded_state_dict.update(decoder_sharded_state_dict)
if self.post_process:
output_layer_prefix = f'{prefix}output_layer.'
output_layer_weight_key = f'{output_layer_prefix}weight'
output_layer_bias_key = f'{output_layer_prefix}bias'
if self.share_embeddings_and_output_weights:
if not self.pre_process:
# when sharing embeddings with last stage, we need to use the weights from the first stage
# on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight
tensor = self.shared_embedding_or_output_weight()
first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
dp_rank = parallel_state.get_data_parallel_rank()
dp_size = parallel_state.get_data_parallel_world_size()
last_stage_word_emb_replica_id = (
dp_rank + dp_size
) # copy of first stage embedding
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=tensor,
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_weight_key] = sharded_output_layer_tensor
# output_layer.weight is shared, but we still need to process output_layer.bias
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=self.lm_head.output_layer.bias,
key=output_layer_bias_key,
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_bias_key] = sharded_output_layer_tensor
else:
output_layer_state_dict = self.output_layer.state_dict(
prefix=output_layer_prefix, keep_vars=True
)
output_layer_tensor = output_layer_state_dict[output_layer_weight_key]
# independent output layer
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=output_layer_tensor,
key=output_layer_weight_key,
replica_id=parallel_state.get_data_parallel_rank(),
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_weight_key] = sharded_output_layer_tensor
return sharded_state_dict
def t5_extended_attention_mask(attention_mask_list: List[Tensor]) -> List[Tensor]:
def attn_mask_postprocess(attn_mask):
# [b, 1, s, s]
extended_attention_mask = attn_mask.unsqueeze(1)
return extended_attention_mask
return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]
def t5_position_ids(token_ids: Tensor) -> Tensor:
"""Calculate position ids from token ids
Args:
token_ids (Tensor): input tokens
Returns:
Tensor: position ids
"""
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
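

# Illustrative usage sketch for the two helpers above. Shapes and values are
# arbitrary; it assumes nothing beyond torch and the functions defined in this
# module. Each [b, s_q, s_k] mask gains a broadcastable head dimension, and
# position ids are a simple [0 .. s-1] range broadcast over the batch.
def _demo_t5_mask_and_position_helpers():
    batch_size, enc_len, dec_len = 2, 8, 4
    encoder_ids = torch.randint(0, 100, (batch_size, enc_len))
    decoder_ids = torch.randint(0, 100, (batch_size, dec_len))
    enc_mask = torch.ones(batch_size, enc_len, enc_len, dtype=torch.bool)
    dec_mask = torch.ones(batch_size, dec_len, dec_len, dtype=torch.bool)
    enc_dec_mask = torch.ones(batch_size, dec_len, enc_len, dtype=torch.bool)
    enc_mask, dec_mask, enc_dec_mask = t5_extended_attention_mask(
        [enc_mask, dec_mask, enc_dec_mask]
    )
    assert enc_mask.shape == (batch_size, 1, enc_len, enc_len)  # [b, 1, s, s]
    assert t5_position_ids(encoder_ids).shape == (batch_size, enc_len)
    assert t5_position_ids(decoder_ids).shape == (batch_size, dec_len)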
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import (
CrossAttention,
CrossAttentionSubmodules,
SelfAttention,
SelfAttentionSubmodules,
)
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
"""T5 encoder TE spec (uses Transformer Engine components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
)
def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
"""T5 decoder TE spec (uses Transformer Engine components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_cross_attn_layernorm=TENorm,
cross_attention=ModuleSpec(
module=CrossAttention,
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
cross_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
)
def encoder_model_with_local_spec() -> ModuleSpec:
"""T5 encoder local spec (uses Megatron-Core components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=FusedLayerNorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=FusedLayerNorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def decoder_model_with_local_spec() -> ModuleSpec:
"""T5 decoder local spec (uses Megatron-Core components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=FusedLayerNorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_cross_attn_layernorm=FusedLayerNorm,
cross_attention=ModuleSpec(
module=CrossAttention,
submodules=CrossAttentionSubmodules(
linear_q=ColumnParallelLinear,
linear_kv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
),
cross_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=FusedLayerNorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def get_t5_encoder_with_transformer_engine_block_spec(
num_layers: int,
) -> TransformerBlockSubmodules:
"""T5 encoder block spec for Transformer Engine
Args:
        num_layers (int): number of encoder layers
"""
layer_spec = encoder_model_with_transformer_engine_default_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers)
return block_spec
def get_t5_decoder_with_transformer_engine_block_spec(
num_layers: int,
) -> TransformerBlockSubmodules:
"""T5 decoder block spec for Transformer Engine
Args:
        num_layers (int): number of decoder layers
"""
layer_spec = decoder_model_with_transformer_engine_default_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers)
return block_spec
def get_t5_encoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules:
"""T5 encoder block spec for local (uses Megatron-Core components)
Args:
num_layers (int): number of encoder layers
"""
layer_spec = encoder_model_with_local_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers)
return block_spec
def get_t5_decoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules:
"""T5 decoder block spec for local (uses Megatron-Core components)
Args:
num_layers (int): number of decoder layers
"""
layer_spec = decoder_model_with_local_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers)
return block_spec
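

# Illustrative sketch of how the block-spec helpers above can be exercised.
# They only assemble ModuleSpec / TransformerBlockSubmodules dataclasses, so no
# GPU or distributed initialization is needed; it does assume megatron.core is
# importable and that `layer_specs` is the first field of
# TransformerBlockSubmodules, as implied by the positional construction above.
def _demo_build_local_t5_block_specs(num_encoder_layers: int = 2, num_decoder_layers: int = 2):
    encoder_block_spec = get_t5_encoder_with_local_block_spec(num_encoder_layers)
    decoder_block_spec = get_t5_decoder_with_local_block_spec(num_decoder_layers)
    # The same per-layer ModuleSpec is repeated once per layer in the block.
    assert len(encoder_block_spec.layer_specs) == num_encoder_layers
    assert len(decoder_block_spec.layer_specs) == num_decoder_layers
    return encoder_block_spec, decoder_block_spec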
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TERowParallelLinear,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
bert_layer_with_transformer_engine_spec = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
)
# Use this spec for an implementation using only modules in megatron core
bert_layer_local_spec = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=FusedLayerNorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=FusedLayerNorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
import torch
from torch import Tensor
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import get_linear_layer
class BertLMHead(MegatronModule):
"""Masked LM head for Bert.
Args:
hidden_size: hidden size
config (TransformerConfig): TransformerConfig object
"""
def __init__(
self, hidden_size: int, config: TransformerConfig,
):
super().__init__(config=config)
        # TODO: Should we switch this to TE?
self.dense = get_linear_layer(
hidden_size, hidden_size, config.init_method, config.perform_initialization
)
setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)
self.layer_norm = FusedLayerNorm(
config=config, hidden_size=hidden_size, eps=config.layernorm_epsilon,
)
self.gelu = torch.nn.functional.gelu
def forward(self, hidden_states: Tensor) -> Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
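

# Minimal standalone sketch of the transform BertLMHead applies: dense
# projection, GELU, then LayerNorm. Plain torch.nn modules stand in here for
# Megatron's get_linear_layer and FusedLayerNorm; shapes are illustrative.
def _demo_bert_lm_head_math(hidden_size: int = 8):
    dense = torch.nn.Linear(hidden_size, hidden_size)
    layer_norm = torch.nn.LayerNorm(hidden_size)
    hidden_states = torch.randn(4, 2, hidden_size)  # [s, b, h]
    out = layer_norm(torch.nn.functional.gelu(dense(hidden_states)))
    assert out.shape == hidden_states.shape
    return out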
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
from collections import OrderedDict
from typing import Dict, Literal, Optional
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.bert.bert_lm_head import BertLMHead
from megatron.core.models.bert.pooler import Pooler
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.enums import AttnMaskType, ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import get_linear_layer
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
class BertModel(LanguageModule):
"""Transformer language model.
Args:
config (TransformerConfig): transformer config
        num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise.
transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool): Include embedding layer (used with pipeline parallelism)
post_process (bool): Include an output layer (used with pipeline parallelism)
parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
            Defaults to 'learned_absolute'.
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
"""
def __init__(
self,
config: TransformerConfig,
num_tokentypes: int,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[float] = None,
add_binary_head=True,
return_embeddings=False,
):
super(BertModel, self).__init__(config=config)
        if return_embeddings:
            assert post_process and add_binary_head
assert (
os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO') == '0'
or os.getenv('NVTE_FLASH_ATTN') == '0'
), "Bert currently does not support flash attention. Please set env variable NVTE_FLASH_ATTN=0 or set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0"
self.config: TransformerConfig = config
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
self.add_binary_head = add_binary_head
self.return_embeddings = return_embeddings
# megatron core pipelining currently depends on model type
self.model_type = ModelType.encoder_or_decoder
# Embeddings.
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
num_tokentypes=num_tokentypes,
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
)
# Transformer.
self.encoder = TransformerBlock(
config=self.config,
spec=self.transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
# TODO: Make sure you are passing in the mpu_vocab_size properly
self.lm_head = BertLMHead(config.hidden_size, config,)
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=True,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
)
self.binary_head = None
if self.add_binary_head:
            # TODO: Should we switch this to TE?
self.binary_head = get_linear_layer(
config.hidden_size, 2, config.init_method, config.perform_initialization
)
self.pooler = Pooler(
config.hidden_size, config.init_method, config, config.sequence_parallel
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
def bert_extended_attention_mask(self, attention_mask: Tensor) -> Tensor:
"""Creates the extended attention mask
        Converts the 2D padding mask of shape [batch size, seq len] to [batch size, 1, seq len, seq len] and makes it binary
Args:
attention_mask (Tensor): The input attention mask
Returns:
Tensor: The extended binary attention mask
"""
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Convert attention mask to binary:
extended_attention_mask = extended_attention_mask < 0.5
return extended_attention_mask
def bert_position_ids(self, token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.encoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
tokentype_ids: Tensor = None,
lm_labels: Tensor = None,
inference_params=None,
):
"""Forward function of BERT model
Forward function of the BERT Model This function passes the input tensors
through the embedding layer, and then the encoder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
extended_attention_mask = self.bert_extended_attention_mask(attention_mask)
        if parallel_state.is_pipeline_first_stage():
            position_ids = self.bert_position_ids(input_ids)
else:
position_ids = None
input_ids = None
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(
input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids
)
else:
# intermediate stage of pipeline
# encoder will get hidden_states from encoder.input_tensor
encoder_input = None
        # Rotary positional embeddings (TODO: why not move this into the BERT/GPT embedding module?)
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.encoder, encoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run encoder.
hidden_states = self.encoder(
hidden_states=encoder_input,
attention_mask=extended_attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
if not self.post_process:
return hidden_states
if self.add_binary_head:
pooled_output = self.pooler(hidden_states, 0)
if self.return_embeddings:
embeddings = torch.transpose(hidden_states, 0, 1)
masks = torch.sum(attention_mask, dim=1)
# Collect masked embeddings.
output = torch.zeros(
size=(embeddings.shape[0], embeddings.shape[2]),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
output[i, :] = torch.mean(embedding[1 : mask - 1], dim=0)
return output
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states)
logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
binary_logits = None
if self.binary_head is not None:
binary_logits = self.binary_head(pooled_output)
if lm_labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous(), binary_logits
loss = self.compute_language_model_loss(lm_labels, logits)
return loss, binary_logits
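

# Standalone sketch of the mask construction in
# BertModel.bert_extended_attention_mask above, using plain PyTorch only. A 2D
# padding mask (1 for real tokens, 0 for padding) is outer-multiplied into
# [b, s, s], given a head dimension, and inverted so True marks masked positions.
def _demo_bert_extended_attention_mask():
    padding_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # [b, s]
    attention_mask_bss = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2)  # [b, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1) < 0.5  # [b, 1, s, s]
    assert extended_attention_mask.shape == (2, 1, 4, 4)
    return extended_attention_mask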
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import get_linear_layer
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Args:
        hidden_size (int): The hidden size
init_method (callable): weight initialization method for the linear layer. bias is set to zero.
config (TransformerConfig): The transformer configuration
        sequence_parallel (bool): Whether sequence parallelism is used. Defaults to False.
"""
def __init__(
self,
hidden_size: int,
init_method: callable,
config: TransformerConfig,
sequence_parallel: bool = False,
):
super(Pooler, self).__init__(config)
        # TODO: Should we switch this to TE?
self.dense = get_linear_layer(
hidden_size, hidden_size, init_method, config.perform_initialization
)
self.sequence_parallel = sequence_parallel
def forward(self, hidden_states: Tensor, sequence_index=0):
# hidden_states: [s, b, h]
# sequence_index: index of the token to pool.
# gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes
if self.sequence_parallel:
hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
hidden_states, tensor_parallel_output_grad=False
)
pooled = hidden_states[sequence_index, :, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
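

# Minimal standalone sketch of what the pooler computes, ignoring sequence
# parallelism: select one token along the sequence dimension of a [s, b, h]
# tensor, apply a linear layer, then tanh. A plain torch.nn.Linear stands in
# for Megatron's get_linear_layer.
def _demo_pooler_math(seq_len: int = 6, batch_size: int = 2, hidden_size: int = 8):
    hidden_states = torch.randn(seq_len, batch_size, hidden_size)  # [s, b, h]
    dense = torch.nn.Linear(hidden_size, hidden_size)
    pooled = torch.tanh(dense(hidden_states[0, :, :]))  # pool the first token
    assert pooled.shape == (batch_size, hidden_size)
    return pooled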
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
class LanguageModelEmbedding(MegatronModule):
"""Language model embeddings.
Args:
config (TransformerConfig): config object with all necessary configs for TransformerBlock
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This
is used for positional embedding
        position_embedding_type (str): Position embedding type. Options: 'learned_absolute', 'rope', 'none'. Defaults to 'learned_absolute'.
        num_tokentypes (int): Set to 0 without a binary head, and 2 with a binary head. Defaults to 0.
"""
def __init__(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
num_tokentypes: int = 0,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
self.num_tokentypes = num_tokentypes
self.reduce_scatter_embeddings = (
(not self.add_position_embedding)
and self.num_tokentypes <= 0
and self.config.sequence_parallel
)
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
)
# Position embedding (serial).
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
self.max_sequence_length, self.config.hidden_size
)
# Initialize the position embeddings.
if self.config.perform_initialization:
self.config.init_method(self.position_embeddings.weight)
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.config.hidden_size
)
# Initialize the token-type embeddings.
if self.config.perform_initialization:
self.config.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
    def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None) -> Tensor:
"""Forward pass of the embedding module.
Args:
input_ids (Tensor): The input tokens
position_ids (Tensor): The position id's used to calculate position embeddings
            tokentype_ids (Tensor): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None.
Returns:
Tensor: The output embeddings
"""
word_embeddings = self.word_embeddings(input_ids)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
else:
embeddings = word_embeddings
if not self.reduce_scatter_embeddings:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
# [b s h] -> [s b h] (So that it can be added with embeddings)
tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
embeddings = embeddings + tokentype_embedding
else:
assert self.tokentype_embeddings is None
# If the input flag for fp32 residual connection is set, convert for float.
if self.config.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.config.sequence_parallel:
if not self.reduce_scatter_embeddings:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
# `scatter_to_sequence_parallel_region` returns a view, which prevents
# the original tensor from being garbage collected. Clone to facilitate GC.
# Has a small runtime cost (~0.5%).
if self.config.clone_scatter_output_in_embedding:
embeddings = embeddings.clone()
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
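

# Standalone sketch of the data layout used by the forward pass above, with
# plain torch.nn.Embedding standing in for VocabParallelEmbedding and with
# parallelism and dropout ignored: word and position embeddings are summed in
# [b, s, h] and then transposed to the [s, b, h] layout the transformer expects.
def _demo_embedding_layout(vocab_size: int = 100, max_len: int = 16, hidden: int = 8):
    word = torch.nn.Embedding(vocab_size, hidden)
    pos = torch.nn.Embedding(max_len, hidden)
    input_ids = torch.randint(0, vocab_size, (2, 5))                  # [b, s]
    position_ids = torch.arange(5).unsqueeze(0).expand_as(input_ids)  # [b, s]
    embeddings = word(input_ids) + pos(position_ids)                  # [b, s, h]
    embeddings = embeddings.transpose(0, 1).contiguous()              # [s, b, h]
    assert embeddings.shape == (5, 2, hidden)
    return embeddings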
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_block import TransformerBlock
import logging
import torch
from torch import Tensor, nn
from megatron.core import parallel_state
logger = logging.getLogger(__name__)
try:
from apex.transformer.functional import (
fused_apply_rotary_pos_emb,
fused_apply_rotary_pos_emb_thd,
)
HAVE_APPLY_ROPE_FUSION = True
except ImportError:
HAVE_APPLY_ROPE_FUSION = False
__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']
def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim):
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
cp_idx = torch.tensor(
[cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
).cuda(non_blocking=True)
pos_emb = pos_emb.view(
*pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :]
)
pos_emb = pos_emb.index_select(seq_dim, cp_idx)
pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
return pos_emb
class RotaryEmbedding(nn.Module):
"""Rotary Embedding for language model.
Args:
kv_channels (int): Projection weights dimension in multi-head attention. Obtained from transformer config
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None
rotary_base (int, optional): Base period for rotary position embeddings. Defaults to 10000.
"""
def __init__(
self,
kv_channels: int,
rotary_percent: float,
rotary_interleaved: bool = False,
seq_len_interpolation_factor: float = None,
rotary_base: int = 10000,
) -> None:
super().__init__()
dim = kv_channels
if rotary_percent < 1.0:
dim = int(dim * rotary_percent)
self.rotary_interleaved = rotary_interleaved
self.seq_len_interpolation_factor = seq_len_interpolation_factor
self.inv_freq = 1.0 / (
rotary_base
** (
torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
/ dim
)
)
def forward(self, max_seq_len: int, offset: int = 0) -> Tensor:
"""Forward pass of RoPE embedding.
Args:
max_seq_len (int): Maximum size of sequence
            offset (int, optional): Offset added to the position indices. Defaults to 0.
Returns:
Tensor: Embeddings after applying RoPE.
"""
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset
)
if self.seq_len_interpolation_factor is not None:
seq *= 1 / self.seq_len_interpolation_factor
freqs = torch.outer(seq, self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
if not self.rotary_interleaved:
emb = torch.cat((freqs, freqs), dim=-1)
else:
emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
freqs.shape[0], -1
)
# emb [seq_length, .., dim]
emb = emb[:, None, None, :]
if parallel_state.get_context_parallel_world_size() > 1:
            # slice rotary_pos_emb along sequence dimension and select the partition of the current CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0)
return emb
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
state_dict.pop(f'{prefix}inv_freq', None)
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def get_rotary_seq_len(
self,
inference_params,
transformer: TransformerBlock,
transformer_input: Tensor,
transformer_config: TransformerConfig,
) -> float:
"""Function to get the rotary sequence length.
Args:
inference_params : Used during Inference time
transformer (TransformerBlock): The transformer block (decoder/encoder) used by the model
            transformer_input (Tensor): Input tensor to the transformer block
transformer_config (TransformerConfig): Transformer config used by the model
Returns:
float: The rotary sequence length
"""
if inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
if transformer.input_tensor is not None:
rotary_seq_len = transformer.input_tensor.size(0)
else:
rotary_seq_len = transformer_input.size(0)
if transformer_config.sequence_parallel:
rotary_seq_len *= transformer_config.tensor_model_parallel_size
rotary_seq_len *= transformer_config.context_parallel_size
return rotary_seq_len
def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor:
"""Change sign so the last dimension becomes [-odd, +even]
Args:
x (Tensor): Input tensor
Returns:
Tensor: Tensor rotated half
"""
if not rotary_interleaved:
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x_new = torch.stack((-x2, x1), dim=-1)
return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
def apply_rotary_pos_emb_bshd(t: Tensor, freqs: Tensor, rotary_interleaved: bool = False) -> Tensor:
"""Apply rotary positional embedding to input tensor T.
check https://kexue.fm/archives/8265 for detailed formulas
Args:
t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
Returns:
Tensor: The input tensor after applying RoPE
"""
rot_dim = freqs.shape[-1]
    # Ideally t_pass is empty, so the rotary position embedding is applied to the whole tensor t.
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
return torch.cat((t, t_pass), dim=-1)
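

# Illustrative CPU-only sketch of the unfused `bshd` path above. The frequency
# tensor is built the same way as RotaryEmbedding.forward, but without the CUDA
# placement; shapes and values are arbitrary. Because RoPE only rotates pairs
# of components, per-position vector norms are preserved.
def _demo_apply_rotary_pos_emb_bshd(seq_len: int = 8, dim: int = 16):
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)[:, None, None, :]  # [s, 1, 1, dim]
    t = torch.randn(seq_len, 2, 4, dim)                        # [s, b, h, d]
    rotated = apply_rotary_pos_emb_bshd(t, emb)
    assert rotated.shape == t.shape
    assert torch.allclose(rotated.norm(dim=-1), t.norm(dim=-1), atol=1e-5)
    return rotated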
def apply_rotary_pos_emb_thd(
t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False
) -> Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
Args:
t (Tensor): Input tensor T is of shape [t, h, d]
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
apply_rotary_pos_emb_bshd(x.unsqueeze(1), freqs[: x.size(0)])
for x in torch.split(t, seqlens)
]
).squeeze(1)
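

# Illustrative CPU-only sketch of the packed-sequence (`thd`) path above: RoPE
# is applied per sub-sequence, with positions restarting at every cu_seqlens
# boundary. The frequency construction mirrors the sketch above.
def _demo_apply_rotary_pos_emb_thd(dim: int = 16):
    max_s = 5
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = torch.outer(torch.arange(max_s, dtype=torch.float32), inv_freq)
    freqs = torch.cat((freqs, freqs), dim=-1)[:, None, None, :]  # [max_s, 1, 1, d]
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)      # two sequences: lengths 3 and 5
    t = torch.randn(8, 4, dim)                                   # [t, h, d]
    out = apply_rotary_pos_emb_thd(t, cu_seqlens, freqs)
    assert out.shape == t.shape
    return out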
def apply_rotary_pos_emb(
t: Tensor, freqs: Tensor, config: TransformerConfig, cu_seqlens: Optional[Tensor] = None,
):
"""
Reroute to the appropriate apply_rotary_pos_emb function depending on
fused/unfused kernels, or bshd (conventional) / thd (packed seq) format
"""
if config.apply_rope_fusion and not HAVE_APPLY_ROPE_FUSION:
# setting apply_rope_fusion in config to False so that subsequent queries to this config also return False
config.apply_rope_fusion = False
if not getattr(apply_rotary_pos_emb, "printed_fused_warning", False):
logger.warning(
"Setting apply_rope_fusion to false because its implementation"
" is not included in Apex. Try upgrading to the latest version"
)
apply_rotary_pos_emb.printed_fused_warning = True
if config.apply_rope_fusion:
if cu_seqlens is None:
return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
else:
return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs)
else:
if cu_seqlens is None:
return apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved)
else:
return apply_rotary_pos_emb_thd(
t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved
)