Commit 66c3d7c9 authored by sdwldchl

rewrite mtp

parent 1f7b14ab
from typing import Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
def language_model_embedding_init_func(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
num_tokentypes: int = 0,
scatter_to_sequence_parallel: bool = True,
skip_weight_param_allocation: bool = False
):
"""Patch language model embeddings init."""
super(LanguageModelEmbedding, self).__init__(config=config)
self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
self.num_tokentypes = num_tokentypes
self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
self.reduce_scatter_embeddings = (
(not self.add_position_embedding)
and self.num_tokentypes <= 0
and self.config.sequence_parallel
and self.scatter_to_sequence_parallel
)
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
skip_weight_param_allocation=skip_weight_param_allocation
)
# Position embedding (serial).
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
self.max_sequence_length, self.config.hidden_size
)
# Initialize the position embeddings.
if self.config.perform_initialization:
self.config.init_method(self.position_embeddings.weight)
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.config.hidden_size
)
# Initialize the token-type embeddings.
if self.config.perform_initialization:
self.config.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)
def language_model_embedding_forward(self,
input_ids: Tensor,
position_ids: Tensor,
tokentype_ids: int = None,
weight: Tensor = None) -> Tensor:
"""Pacth forward pass of the embedding module.
Args:
input_ids (Tensor): The input tokens
position_ids (Tensor): The position id's used to calculate position embeddings
tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
set to True. Defaults to None
weight (Tensor): embedding weight
Returns:
Tensor: The output embeddings
"""
if weight is None:
if self.word_embeddings.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.word_embeddings.weight
word_embeddings = self.word_embeddings(input_ids, weight)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
else:
embeddings = word_embeddings
if not self.reduce_scatter_embeddings:
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
# [b s h] -> [s b h] (So that it can be added with embeddings)
tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
embeddings = embeddings + tokentype_embedding
else:
assert self.tokentype_embeddings is None
# If the input flag for fp32 residual connection is set, convert for float.
if self.config.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.config.sequence_parallel:
if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
# `scatter_to_sequence_parallel_region` returns a view, which prevents
# the original tensor from being garbage collected. Clone to facilitate GC.
# Has a small runtime cost (~0.5%).
if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
embeddings = embeddings.clone()
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
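

# Illustrative sketch, not asserted by this diff: the two functions above read like
# monkey patches for LanguageModelEmbedding, letting forward() accept an externally
# supplied embedding `weight` (e.g. one shared with the MTP head). Assuming that
# intent, a minimal way to install them would be:
def apply_language_model_embedding_patch():
    """Install the patched __init__/forward on LanguageModelEmbedding (sketch)."""
    LanguageModelEmbedding.__init__ = language_model_embedding_init_func
    LanguageModelEmbedding.forward = language_model_embedding_forward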
import warnings
from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer import ModuleSpec
from .multi_token_predictor import (
MultiTokenPredicationSubmodules,
MultiTokenPredictor
)
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TENorm
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_mtp_spec(transformer_layer, use_te=False):
"""
Multi Token Predication Layer Specification.
"""
use_te = use_te & HAVE_TE
mtp_spec = ModuleSpec(
module=MultiTokenPredictor,
submodules=MultiTokenPredicationSubmodules(
embedding=None,
enorm=TENorm if use_te else LNImpl,
hnorm=TENorm if use_te else LNImpl,
eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
transformer_layer=transformer_layer,
final_layernorm=TENorm if use_te else LNImpl,
output_layer=None,
)
)
return mtp_spec
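

# Usage sketch, for illustration only: building the MTP spec from a standard GPT
# decoder-layer spec. Assumptions: any TransformerLayer ModuleSpec is acceptable as
# `transformer_layer`, and the upstream GPT spec helper below is one convenient source.
def _example_mtp_spec():
    from megatron.core.models.gpt.gpt_layer_specs import (
        get_gpt_layer_with_transformer_engine_spec,
    )

    layer_spec = get_gpt_layer_with_transformer_engine_spec()
    return get_mtp_spec(layer_spec, use_te=True)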
import os
import logging
from dataclasses import dataclass
from typing import Union, Optional, Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel, InferenceParams
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
from ...tensor_parallel.random import CheckpointWithoutOutput
from ...tensor_parallel import FluxColumnParallelLinear
@dataclass
class MultiTokenPredicationSubmodules:
embedding: Union[ModuleSpec, type] = None
output_layer: Union[ModuleSpec, type] = None
eh_proj: Union[ModuleSpec, type] = None
enorm: Union[ModuleSpec, type] = None
hnorm: Union[ModuleSpec, type] = None
transformer_layer: Union[ModuleSpec, type] = None
final_layernorm: Union[ModuleSpec, type] = None
class MultiTokenPredictor(MegatronModule):
def __init__(
self,
config: TransformerConfig,
submodules: MultiTokenPredicationSubmodules,
vocab_size: int,
max_sequence_length: int,
layer_number: int = 1,
        hidden_dropout: Optional[float] = None,
pre_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
        recompute_mtp_norm: bool = False,
        recompute_mtp_layer: bool = False,
        add_output_layer_bias: bool = False
):
super().__init__(config=config)
self.config = config
self.submodules = submodules
self.layer_number = layer_number
self.hidden_dropout = hidden_dropout
self.hidden_size = self.config.hidden_size
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.position_embedding_type = position_embedding_type
self.recompute_layer_norm = recompute_mtp_norm
self.recompute_mtp_layer = recompute_mtp_layer
self.add_output_layer_bias = add_output_layer_bias
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
use_cpu_initialization=self.config.use_cpu_initialization,
)
self.enorm = build_module(
self.submodules.enorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.hnorm = build_module(
self.submodules.hnorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.eh_proj = build_module(
self.submodules.eh_proj,
self.hidden_size + self.hidden_size,
self.hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='eh',
)
self.transformer_layer = build_module(
self.submodules.transformer_layer,
config=self.config,
)
if self.submodules.final_layernorm:
self.final_layernorm = build_module(
self.submodules.final_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = None
if self.config.defer_embedding_wgrad_compute:
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
def forward(
self,
hidden_input_ids: Tensor,
embed_input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
embedding_layer=None,
output_layer=None,
output_weight=None
):
"""Forward function of the MTP module"""
# Decoder embedding.
decoder_input = embedding_layer(
input_ids=embed_input_ids,
position_ids=position_ids,
)
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
rotary_seq_len = decoder_input.size(0)
if self.config.sequence_parallel:
rotary_seq_len *= self.config.tensor_model_parallel_size
rotary_seq_len *= self.config.context_parallel_size
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
if self.recompute_layer_norm:
self.enorm_ckpt = CheckpointWithoutOutput()
enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
self.hnorm_ckpt = CheckpointWithoutOutput()
hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids)
else:
enorm_output = self.enorm(decoder_input)
hnorm_output = self.hnorm(hidden_input_ids)
# [s, b, h] -> [s, b, 2h]
hidden_states = torch.concat(
[hnorm_output,
enorm_output],
dim=-1
)
if self.recompute_layer_norm:
self.enorm_ckpt.discard_output()
self.hnorm_ckpt.discard_output()
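            # Recompute the discarded enorm/hnorm outputs during backward via
            # gradient hooks on the concatenated tensor.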
hidden_states.register_hook(self.enorm_ckpt.recompute)
hidden_states.register_hook(self.hnorm_ckpt.recompute)
# hidden_states -> [s, b, h]
hidden_states, _ = self.eh_proj(hidden_states)
if self.config.tensor_model_parallel_size > 1:
hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states)
if self.config.sequence_parallel:
hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states)
if self.recompute_mtp_layer:
hidden_states, context = tensor_parallel.checkpoint(
self.transformer_layer,
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
None,
None,
rotary_pos_emb,
inference_params,
packed_seq_params,
)
else:
hidden_states, _ = self.transformer_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
# Final layer norm.
if self.final_layernorm is not None:
if self.recompute_layer_norm:
self.finalnorm_ckpt = CheckpointWithoutOutput()
finalnorm_output = self.finalnorm_ckpt.checkpoint(self.final_layernorm, False, hidden_states)
else:
finalnorm_output = self.final_layernorm(hidden_states)
else:
finalnorm_output = hidden_states
logits, _ = output_layer(finalnorm_output, weight=output_weight)
if self.recompute_layer_norm:
self.finalnorm_ckpt.discard_output()
logits.register_hook(self.finalnorm_ckpt.recompute)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return hidden_states, loss
def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
"""Computes the language model loss (Cross entropy across vocabulary)
Args:
labels (Tensor): The labels of dimension [batch size, seq length]
logits (Tensor): The final logits returned by the output layer of the transformer model
Returns:
Tensor: Loss tensor of dimensions [batch size, sequence_length]
"""
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
if self.config.cross_entropy_loss_fusion:
loss = fused_vocab_parallel_cross_entropy(logits, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
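

# Construction sketch, for illustration only (argument values are placeholders and
# `mtp_spec` is assumed to be the ModuleSpec returned by `get_mtp_spec`):
def _example_build_mtp(config: TransformerConfig, mtp_spec: ModuleSpec,
                       vocab_size: int, max_sequence_length: int):
    # build_module expands the spec into MultiTokenPredictor(config, submodules, ...).
    return build_module(
        mtp_spec,
        config=config,
        vocab_size=vocab_size,
        max_sequence_length=max_sequence_length,
        position_embedding_type='rope',
    )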
@@ -30,33 +30,3 @@ def is_flux_min_version(version, check_equality=True):
if check_equality:
return get_flux_version() >= PkgVersion(version)
return get_flux_version() > PkgVersion(version)
def tensor_slide(
tensor: Optional[torch.Tensor],
num_slice: int,
dims: Union[int, List[int]] = -1,
step: int = 1,
return_first=False,
) -> List[Union[torch.Tensor, None]]:
"""通用滑动窗口函数,支持任意维度"""
if tensor is None:
# return `List[None]` to avoid NoneType Error
return [None] * (num_slice + 1)
if num_slice == 0:
return [tensor]
window_size = tensor.shape[-1] - num_slice
dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)
    # Slide the same window jointly across all requested dimensions.
slices = []
for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
slice_obj = [slice(None)] * tensor.dim()
for dim in dims:
slice_obj[dim] = slice(i, i + window_size)
slices.append(tensor[tuple(slice_obj)])
    if return_first:
        # Presumably only the first window is wanted when `return_first` is set.
        return slices[:1]
    return slices
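

# Usage sketch, for illustration only (the tensor below is a placeholder). With
# num_slice=2 on a length-8 sequence the window size is 8 - 2 = 6, so the call yields
# num_slice + 1 = 3 views, each shifted by one position along the last dimension;
# presumably the kind of shifted token/label view the multi-token-prediction path consumes.
def _example_tensor_slide():
    tokens = torch.arange(8).unsqueeze(0)       # shape [1, 8]
    views = tensor_slide(tokens, num_slice=2)   # three windows, each of shape [1, 6]
    assert len(views) == 3 and views[0].shape == (1, 6)
    return views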